Models
TabularModel
build
classmethod
build(
preproc: TabularPreproc,
size: str | Size | TabularModelSize,
block: str | None = None,
dropout: float | None = 0.12,
) -> TabularModel
Tabular model to generate synthetic tabular relational data.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
preproc | `TabularPreproc` | A | required |
size | `str \| Size \| TabularModelSize` | The size configuration of the model. Could be either a | required |
block | `str \| None` | The block type. The possible values depend on whether the data is single-table or multi-table. For single-table data, either 'free' (default), 'causal', or 'lstm'. For multi-table data, either 'free' (default) or 'lstm'. | `None` |
dropout | `float \| None` | The dropout probability. | `0.12` |
Returns:
Type | Description |
---|---|
`TabularModel` | A |
generate
generate(
n_samples: int | None = None,
ctx: dict[str, DataFrame] | None = None,
batch_size: int = 0,
max_block_size: int = 0,
temp: float = 1.0,
) -> RelationalData
Generate synthetic relational data.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
n_samples | `int \| None` | Desired number of samples in the root table. Must be given if and only if | `None` |
ctx | `dict[str, DataFrame] \| None` | The context from which to start a conditional generation. If provided, | `None` |
batch_size | `int` | Batch size used during generation. If 0, all data is generated in a single batch. | `0` |
max_block_size | `int` | Maximum length for each generated sample. Active only for multi-table datasets and for generation from the root table (denoted above as 1.). If 0, no limit is enforced. | `0` |
temp | `float` | Temperature parameter for sampling. | `1.0` |
Returns:
Type | Description |
---|---|
`RelationalData` | A |
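The `temp` parameter is described above only as a "temperature parameter for sampling". As a rough illustration of what such a parameter conventionally does, the sketch below rescales raw scores before a softmax. It is a sketch under that assumption, not the library's implementation; the names `softmax_with_temperature` and `sample_category` are hypothetical.

```python
import math
import random


def softmax_with_temperature(logits, temp=1.0):
    """Turn raw scores into probabilities, dividing the scores by `temp` first.

    Assumption: this is the conventional meaning of a sampling temperature;
    the documented models may apply it differently.
    """
    if temp <= 0:
        raise ValueError("temp must be positive")
    scaled = [x / temp for x in logits]
    top = max(scaled)  # subtract the maximum for numerical stability
    exps = [math.exp(s - top) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_category(logits, temp=1.0, rng=None):
    """Draw one category index according to the temperature-scaled probabilities."""
    rng = rng or random.Random()
    probs = softmax_with_temperature(logits, temp)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]
```

Under this convention, values of `temp` below 1.0 concentrate probability on the most likely category, while values above 1.0 flatten the distribution.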
predict_proba
Predict probabilities for each category of a single categorical column. In order to use this function, the context must contain all columns except for a single categorical column in the root table.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
ctx | `dict[str, DataFrame]` | The input context. Must contain all columns except for a single categorical column in the root table. | required |
batch_size | `int` | Batch size used during prediction. If 0, all predictions are performed in a single batch. | `0` |
Returns:
Type | Description |
---|---|
`PredProb` | A |
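Several methods on this page share the same `batch_size` convention: a value of 0 means that everything is processed in a single batch. The helper below is a minimal sketch of that convention only. The name `iter_batches` is hypothetical, and how the library actually splits its work is not described here.

```python
def iter_batches(n_rows, batch_size=0):
    """Yield (start, stop) index pairs that cover `n_rows` rows.

    A `batch_size` of 0 yields one pair spanning all rows, mirroring the
    "single batch" behaviour documented for the `batch_size` parameter.
    """
    if n_rows < 0 or batch_size < 0:
        raise ValueError("n_rows and batch_size must be non-negative")
    step = n_rows if batch_size == 0 else batch_size
    for start in range(0, n_rows, max(step, 1)):
        yield start, min(start + step, n_rows)
```

For example, `list(iter_batches(10, 4))` gives three index ranges, the last one shorter than the others, while `list(iter_batches(10))` gives a single range.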
predict_sample
predict_sample(
ctx: dict[str, DataFrame],
n: int,
batch_size: int = 0,
temp: float = 1.0,
rng: Generator | int | None = None,
) -> PredSample
Make `n` prediction samples from a given context.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
ctx | `dict[str, DataFrame]` | The context from which to compute the | required |
n | `int` | The number of prediction samples. | required |
batch_size | `int` | Batch size used during prediction. If 0, all predictions are performed in a single batch. | `0` |
temp | `float` | Temperature parameter for sampling. | `1.0` |
rng | `Generator \| int \| None` | A | `None` |
Returns:
Type | Description |
---|---|
`PredSample` | A |
`PredSample` | the predicted columns). The common context can be retrieved from the |
save
load
classmethod
load(path: Path | str) -> TabularModel
Load the `TabularModel` from the checkpoint at the given path.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path | `Path \| str` | The path to the checkpoint to load. | required |
Returns:
Type | Description |
---|---|
`TabularModel` | The loaded |
TextModel
build
classmethod
build(
preproc: TextPreproc,
size: str | Size | TextModelSize,
block_size: int,
dropout: float | None = 0.12,
) -> TextModel
Text model to generate synthetic text columns of a table which is part of a relational structure.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
preproc | `TextPreproc` | A | required |
size | `str \| Size \| TextModelSize` | The size configuration of the model. Could be either a | required |
block_size | `int` | Maximum text sequence length that the model can process. | required |
dropout | `float \| None` | The dropout probability. | `0.12` |
Returns:
Type | Description |
---|---|
`TextModel` | A |
build_from_pretrained
classmethod
build_from_pretrained(
preproc: TextPreproc,
path: Path | str,
block_size: int | None = None,
) -> TextModel
Build a text model from a pretrained model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
preproc | `TextPreproc` | A | required |
path | `Path \| str` | The path to the checkpoint of the pre-trained model. | required |
block_size | `int \| None` | Maximum text sequence length that the model can process during fine-tuning. | `None` |
Returns:
Type | Description |
---|---|
`TextModel` | A |
generate
generate(
data: RelationalData,
batch_size: int = 0,
max_text_len: int = 0,
temp: float = 1.0,
) -> RelationalData
Generate text columns in the current table.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data | `RelationalData` | A | required |
batch_size | `int` | Batch size used during generation. If 0, all data is generated in a single batch. | `0` |
max_text_len | `int` | Maximum length for the generated text. If 0, the maximum possible value is used, namely the value of the | `0` |
temp | `float` | Temperature parameter for sampling. | `1.0` |
Returns:
Type | Description |
---|---|
`RelationalData` | A |
save
Size
Enumeration class representing different model sizes. Supported sizes are: SMALL, MEDIUM and LARGE.
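The `size` arguments above accept either a string or a `Size` member. The sketch below shows one way such an enumeration and a string lookup could fit together. Only the member names SMALL, MEDIUM and LARGE come from this page; the member values, the case-insensitive lookup, and the names `SizeSketch` and `resolve_size` are assumptions made for illustration.

```python
from enum import Enum


class SizeSketch(Enum):
    """Hypothetical stand-in for `Size`; the member values are assumed."""

    SMALL = "small"
    MEDIUM = "medium"
    LARGE = "large"


def resolve_size(size):
    """Accept a member or its name as a string and return the member."""
    if isinstance(size, SizeSketch):
        return size
    # Assumption: strings are matched against member names, ignoring case.
    return SizeSketch[size.upper()]
```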
TabularModelSize
dataclass
TextModelSize
dataclass
PredProb
dataclass
PredSample
The predicted samples. A list of `RelationalData` objects containing the predicted columns.
It supports the list method `list.append` and the `+` operator.
Attributes:
Name | Type | Description |
---|---|---|
ctx | `RelationalData` | A |
n_samples | `int \| None` | The number of samples in each prediction. |
schema | `Schema \| None` | The |
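`PredSample` is described as a list of predictions that supports `list.append` and the `+` operator while carrying a shared context in `ctx`. The class below is a minimal sketch of that shape, assuming a plain `list` subclass. `PredSampleSketch` is a hypothetical stand-in, and the strings used in the usage note take the place of `RelationalData` objects.

```python
class PredSampleSketch(list):
    """Hypothetical list-like container with a shared context attribute."""

    def __init__(self, items=(), ctx=None):
        super().__init__(items)
        self.ctx = ctx  # context common to every prediction in the list

    def __add__(self, other):
        # Concatenation returns the same container type and keeps the
        # context of the left operand (an assumption of this sketch).
        return PredSampleSketch(list(self) + list(other), ctx=self.ctx)
```

With this sketch, `a = PredSampleSketch(["p0"], ctx="shared")` followed by `a.append("p1")` and `a + PredSampleSketch(["p2"], ctx="shared")` produces a three-element container that still exposes `ctx`.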
select
select(idx: Sequence[int]) -> PredSample
Select the predictions corresponding to the input indices.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
idx | `Sequence[int]` | A | required |
Returns:
Type | Description |
---|---|
`PredSample` | A |
XgbModel
__init__
__init__(
schema: Schema,
ctx_cols: Sequence[str] = (),
preprocessors: dict[
str, ColumnPreproc | ArColumn | None
]
| None = None,
n_estimators: int | None = 1000,
valid_frac: float | None = 0.0,
**kwargs: Any,
) -> None
A generative model based on autoregressive XGBoost models. Can be used only with single-table data.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
schema | `Schema` | The | required |
ctx_cols | `Sequence[str]` | The columns to be used as context. | `()` |
preprocessors | `dict[str, ColumnPreproc \| ArColumn \| None] \| None` | A dictionary containing preprocessing instructions for each column in the table. Preprocessing instructions can be instances of | `None` |
n_estimators | `int \| None` | Number of estimators for the XGBoost models. | `1000` |
valid_frac | `float \| None` | Fraction of the training data to be used for validation. | `0.0` |
**kwargs | `Any` | Keyword arguments to be passed to the XGBoost models. | `{}` |
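The `valid_frac` parameter reserves a fraction of the training data for validation. The sketch below shows one simple way to make such a split. It assumes a random shuffle, which this page does not confirm, and it does not cover the `None` value allowed by the signature. The name `split_train_valid` is hypothetical.

```python
import random


def split_train_valid(rows, valid_frac=0.0, seed=None):
    """Split `rows` into (train, valid), with about `valid_frac` of them in valid.

    Assumption: rows are assigned at random; the documented model may pick
    its validation rows differently.
    """
    if not 0.0 <= valid_frac < 1.0:
        raise ValueError("valid_frac must be in [0, 1)")
    rng = random.Random(seed)
    indices = list(range(len(rows)))
    rng.shuffle(indices)
    n_valid = int(len(rows) * valid_frac)
    valid = [rows[i] for i in indices[:n_valid]]
    train = [rows[i] for i in indices[n_valid:]]
    return train, valid
```

With the default `valid_frac=0.0`, the validation part is empty and all rows are used for training.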
fit
fit(data: RelationalData) -> XgbModel
Fit the `XgbModel` to the given `RelationalData`.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data | `RelationalData` | The | required |
Returns:
Type | Description |
---|---|
`XgbModel` | The fitted instance of the |
train
train(data: RelationalData) -> XgbModel
Train the `XgbModel` with the input `RelationalData`.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data | `RelationalData` | The training data, as a | required |
Returns:
Type | Description |
---|---|
`XgbModel` | The trained instance of the |
generate
generate(
n_samples: int | None = None,
ctx: DataFrame | None = None,
batch_size: int = 0,
temp: float = 1.0,
) -> RelationalData
Generate synthetic data.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
n_samples | `int \| None` | The desired number of samples. Must be given if and only if | `None` |
ctx | `DataFrame \| None` | The columns of the context from which to start a conditional generation. If provided, | `None` |
batch_size | `int` | Batch size used during generation. If 0, all data is generated in a single batch. | `0` |
temp | `float` | Temperature parameter for sampling. | `1.0` |
Returns:
Type | Description |
---|---|
`RelationalData` | A |