Tree-based model
On top of the TabularModel, the aindo.rdml.synth package provides another model that can be used to generate synthetic data in the case of a single table: XgbModel.
While TabularModel and TextModel are neural models, XgbModel is based on XGBoost.
XgbModel is extremely simple to use and can provide good results with a relatively short training time.
However, as already mentioned, it is limited to the case of a single table.
To initialize an XgbModel, it is necessary to provide a Schema.
The schema must contain at most one table which is not a lookup table.
The optional arguments are:

- overwrites: Overwrites to the default table preprocessor, in the form of a dictionary whose keys are the column names and whose values are the column preprocessors.
- ctx_cols: The columns to be used as context.
- n_estimators: The number of estimators for the XGBoost models.
- valid_frac: The fraction of the training data to be used for validation.
- **kwargs: Keyword arguments to be passed to the XGBoost models.
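As a quick sketch of how these options fit together (the column name "city" and the values below are illustrative assumptions, not defaults taken from the library), the keyword arguments might be collected as:

```python
# Illustrative only: "city" is a hypothetical column name and the values
# are arbitrary choices, not library defaults.
init_kwargs = {
    "ctx_cols": ["city"],   # columns to be used as generation context
    "n_estimators": 200,    # number of estimators for the XGBoost models
    "valid_frac": 0.1,      # fraction of training data held out for validation
}
# The model would then be built as: XgbModel(schema=schema, **init_kwargs)
```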
The model must first be fitted on the whole data with the XgbModel.fit() method, and later trained on the training data with the XgbModel.train() method.
Finally, it can be used to generate data with the XgbModel.generate() method.
As for TabularModel, the XgbModel.generate() method has two mutually exclusive arguments:

- n_samples: The number of samples to generate.
- ctx: A pandas.DataFrame containing the context columns from which to start the conditional data generation. The provided columns must match the ones declared in the ctx_cols argument of the XgbModel.__init__() method.
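To make the ctx argument concrete, here is a minimal pandas sketch; the table and its column names are hypothetical assumptions, and the generate() call is shown only as a comment since it requires a fitted model:

```python
import pandas as pd

# Hypothetical table: suppose the model was initialized with
# ctx_cols=["region"] (column names here are purely illustrative).
df_test = pd.DataFrame({
    "region": ["north", "south", "south"],
    "age": [34, 51, 27],
})

# The context frame must contain exactly the declared context columns.
ctx = df_test[["region"]].copy()

# Conditional generation would then look like:
# data_synth = model.generate(ctx=ctx)
```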
An XgbModel can be saved with the XgbModel.save() method and loaded with the XgbModel.load() class method, in the same fashion as a TabularModel.
In the following we show an example of a complete workflow:
import pandas as pd
from aindo.rdml.relational import RelationalData, Schema, Table
from aindo.rdml.synth import XgbModel
df = pd.DataFrame(...)
data = RelationalData(data={"table": df}, schema=Schema(table=Table(...)))
data_train, data_test = data.split(ratio=0.1)
model = XgbModel(
    schema=data.schema,
    n_estimators=200,
    valid_frac=0.1,
)
model.fit(data=data)
model.train(data=data_train)
model.save(path="path/to/ckpt")
model = XgbModel.load(path="path/to/ckpt")
data_synth = model.generate(n_samples=df.shape[0])