Tree-based model

On top of TabularModel, the aindo.rdml.synth package provides another model that can generate synthetic data in the single-table case: XgbModel. While TabularModel and TextModel are neural models, XgbModel is based on XGBoost.

XgbModel is extremely simple to use and can provide good results with a relatively short training time. However, as mentioned above, it is limited to the single-table case.

To initialize an XgbModel, it is necessary to provide a Schema, which must contain at most one table that is not a lookup table. The optional arguments are the following (see the initialization sketch after this list):

  • overwrites: Overwrites to the default table preprocessor, given as a dictionary mapping column names to column preprocessors.
  • ctx_cols: The columns to be used as context.
  • n_estimators: The number of estimators for the XGBoost models.
  • valid_frac: The fraction of the training data to be used for validation.
  • **kwargs: Keyword arguments to be passed to the XGBoost models.
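
For illustration, here is a minimal initialization sketch. The context column name "label", the list form assumed for ctx_cols, and the specific XGBoost keyword argument are our own assumptions, not part of the original example:

from aindo.rdml.synth import XgbModel

model = XgbModel(
    schema=data.schema,   # schema with at most one non-lookup table
    ctx_cols=["label"],   # hypothetical context column; list form is assumed
    n_estimators=200,     # number of estimators for the XGBoost models
    valid_frac=0.1,       # fraction of training data used for validation
    max_depth=6,          # extra keyword arguments are forwarded to XGBoost
)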

The model must first be fitted on the whole data with the XgbModel.fit() method, then trained on the training data with the XgbModel.train() method. Finally, it can be used to generate data with the XgbModel.generate() method. As for TabularModel, the latter takes two mutually exclusive arguments (see the generation sketch after this list):

  • n_samples: The number of samples to generate.
  • ctx: A pandas.DataFrame containing the context columns from which to start the conditional data generation. The provided columns must match the ones declared in the ctx_cols argument of the XgbModel.__init__() method.
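
As an example, here is a hedged sketch of both generation modes, assuming the model was initialized with ctx_cols naming a hypothetical "label" column:

import pandas as pd

# Unconditional generation: draw a fixed number of synthetic samples.
data_synth = model.generate(n_samples=1000)

# Conditional generation: the context DataFrame must contain exactly the
# columns declared in ctx_cols ("label" is a hypothetical example).
ctx = pd.DataFrame({"label": ["a", "b", "a"]})
data_synth_ctx = model.generate(ctx=ctx)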

An XgbModel can be saved with the XgbModel.save() method and loaded with the XgbModel.load() class method, in the same fashion as TabularModel.

In the following we show an example of a complete workflow:

import pandas as pd
from aindo.rdml.relational import RelationalData, Schema, Table
from aindo.rdml.synth import XgbModel

# Load the data and build the RelationalData object with its Schema
df = pd.DataFrame(...)

data = RelationalData(data={"table": df}, schema=Schema(table=Table(...)))
# Split the data into a training and a test set
data_train, data_test = data.split(ratio=0.1)

model = XgbModel(
    schema=data.schema,
    n_estimators=200,
    valid_frac=0.1,
)
model.fit(data=data)  # fit the preprocessor on the whole data
model.train(data=data_train)  # train on the training split
model.save(path="path/to/ckpt")

model = XgbModel.load(path="path/to/ckpt")
data_synth = model.generate(n_samples=df.shape[0])