
Predictive mode

Generating synthetic data from context is essentially the same as predicting specific columns of the relational data, based on the columns provided in the context. However, instead of producing a single prediction for a given input, we may sometimes want to understand how the predictions are statistically distributed.

For a given context, i.e., a set of columns provided as input to the generation process, we aim to predict the probability of each possible output, that is, the probability of each combination of values for all the generated columns. As the number of generated columns increases, the problem quickly becomes intractable: the number of possible output combinations is the product of the number of values each generated column can take, so it grows exponentially with the number of columns. The challenge is compounded further when the columns to predict include numerical ones with continuous values, whose possible values cannot be enumerated at all.
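
To make the growth concrete, here is a small illustrative calculation for a single generated row in which every generated column is categorical. The helper n_joint_outcomes is hypothetical and not part of aindo.rdml. It simply multiplies the column cardinalities, since any value of one column can be combined with any value of the others.

```python
import math
from typing import Sequence


def n_joint_outcomes(cardinalities: Sequence[int]) -> int:
    """Number of joint outputs for one row when every generated column is categorical.

    Any value of one column can be combined with any value of the others,
    so the count is the product of the column cardinalities.
    """
    return math.prod(cardinalities)


# Three categorical columns with 2, 5 and 10 categories
print(n_joint_outcomes([2, 5, 10]))      # 100
# Adding a fourth column with 20 categories multiplies the count by 20
print(n_joint_outcomes([2, 5, 10, 20]))  # 2000
```

In the relational case the count compounds further, because the joint output covers the generated values of every row, not just one.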

A tabular generative model can help address this issue; we refer to this use of the model as the predictive mode. A TabularModel can operate in predictive mode through two methods: TabularModel.predict_sample() and TabularModel.predict_proba().

Predict sample

The TabularModel.predict_sample() method provides an approximate solution to the prediction problem. It is very general and can be applied to virtually any scenario.

Since this is a predictive task, the TabularPreproc (and consequently the TabularModel) must be defined with a context as described in the section on generation from context. In the multi-table case, the context must include the relation structure (i.e., the primary and foreign keys).

Given the input context ctx and an integer n, this method returns a PredSample object, which is a list of n RelationalData objects, each containing the columns predicted from the context. The predicted samples generally differ from one another because of the intrinsic randomness of the generation process. The output is equivalent to running a for loop with n separate generations from ctx, but internally the operation is optimized so that the context is processed only once.
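
The following toy sketch illustrates why processing the context once matters. It assumes a model whose generation splits into an expensive context-encoding step and a cheap sampling step. ToyConditionalModel, its encode() and sample() methods, and the two predict_sample_* helpers are all hypothetical and are not the aindo.rdml implementation; they only show that, with the same random seed, encoding once yields the same samples as the naive loop while calling the encoder a single time.

```python
import random


class ToyConditionalModel:
    """Hypothetical stand-in for a conditional generator (not the aindo.rdml API)."""

    def __init__(self) -> None:
        self.encode_calls = 0  # counts how often the context is processed

    def encode(self, ctx: list[float]) -> float:
        # Plays the role of the expensive pass over the context
        self.encode_calls += 1
        return sum(ctx) / len(ctx)

    def sample(self, state: float, rng: random.Random) -> bool:
        # Toy rule: the larger the encoded context, the more likely a True output
        return rng.random() < min(max(state, 0.0), 1.0)


def predict_sample_naive(model, ctx, n, seed=0):
    rng = random.Random(seed)
    # One full generation per sample: the context is re-encoded every time
    return [model.sample(model.encode(ctx), rng) for _ in range(n)]


def predict_sample_shared(model, ctx, n, seed=0):
    rng = random.Random(seed)
    state = model.encode(ctx)  # the context is processed only once
    return [model.sample(state, rng) for _ in range(n)]


naive_model, shared_model = ToyConditionalModel(), ToyConditionalModel()
a = predict_sample_naive(naive_model, [0.2, 0.4], n=100)
b = predict_sample_shared(shared_model, [0.2, 0.4], n=100)
print(a == b, naive_model.encode_calls, shared_model.encode_calls)
```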

The n predictions in the PredSample object can be treated as a statistical sample, which can be used to estimate expectation values or any other statistical properties. The larger the number of samples n, the more accurate the estimate will be. However, computational time increases linearly with n, aside from a constant time spent processing the context.
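
The trade-off between accuracy and cost can be quantified for the simplest statistic, the frequency of a yes/no event across samples. Assuming the n samples are independent, the standard error of that frequency is sqrt(p(1 - p) / n), so quadrupling n halves the error, while the cost grows linearly in n on top of the fixed context cost. The helper names and the timing constants below are arbitrary illustrative values, not measurements of TabularModel.

```python
import math


def bernoulli_std_error(p: float, n: int) -> float:
    """Standard error of a frequency estimated from n independent samples.

    p is the true probability of the event (e.g. "transaction is fraudulent").
    """
    return math.sqrt(p * (1.0 - p) / n)


def generation_cost(n: int, t_ctx: float, t_sample: float) -> float:
    """Total time: a fixed cost for processing the context plus a cost per sample."""
    return t_ctx + n * t_sample


# Quadrupling n halves the statistical error but roughly quadruples the sampling time
for n in (100, 400, 1_600):
    err = bernoulli_std_error(p=0.1, n=n)
    cost = generation_cost(n, t_ctx=5.0, t_sample=0.01)  # illustrative units
    print(f"n={n:>5}  std error={err:.4f}  cost={cost:.1f}")
```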

For example, consider the case of two tables:

  • A parent table containing personal information about clients.
  • A child table containing transactions made by each client.

The training set contains a boolean target column in the child table, which indicates whether each transaction is fraudulent. Now, imagine we want to predict whether each transaction in the test set is fraudulent. This is a difficult task, because each transaction depends on the previous ones (a dependence that generative tabular models take into account). As a result, the prediction for whether a transaction is fraudulent is correlated with the predictions for all other transactions of the same client. The TabularModel.predict_sample() method handles this problem in a general way by directly sampling from the distribution of the target column. Once we have the PredSample object, we can estimate the probability that a specific transaction of a specific client is fraudulent as the fraction of samples in which that transaction was flagged as fraudulent.
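
A minimal sketch of that counting step, assuming each sample has already been reduced to a plain dict that maps a transaction id to its predicted fraudulent value (in practice this would come from the fraudulent column of each sample's transactions table). The helper flag_frequencies and the ids t1, t2, t3 are hypothetical.

```python
from collections import Counter


def flag_frequencies(samples):
    """Fraction of samples in which each transaction is flagged as fraudulent.

    `samples` holds one dict per generated sample, mapping a transaction id
    to its predicted boolean `fraudulent` value.
    """
    counts = Counter()
    for sample in samples:
        for trans_id, is_fraud in sample.items():
            counts[trans_id] += int(is_fraud)
    n = len(samples)
    return {trans_id: counts[trans_id] / n for trans_id in samples[0]}


samples = [
    {"t1": True, "t2": False, "t3": False},
    {"t1": True, "t2": True, "t3": False},
    {"t1": False, "t2": False, "t3": False},
    {"t1": True, "t2": False, "t3": True},
]
print(flag_frequencies(samples))  # {'t1': 0.75, 't2': 0.25, 't3': 0.25}
```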

More complex scenarios with more intricate relational structures are also possible. In these cases, there may be more than one target column, the target columns may belong to different tables, and the statistical properties computed on the output sample can be as complex as necessary. This flexibility is the key advantage of the TabularModel.predict_sample() method.
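
One example of a statistic that needs the joint samples, rather than per-transaction probabilities, is the probability that a client has at least one fraudulent transaction. The sketch below uses hypothetical helper names and toy data in which two transactions of the same client tend to be flagged together. It compares the estimate computed on the joint samples with the one obtained by (wrongly) treating the transactions as independent.

```python
def prob_any_flagged(samples, trans_ids):
    """Estimate P(at least one of `trans_ids` is fraudulent) from joint samples."""
    hits = sum(any(sample[t] for t in trans_ids) for sample in samples)
    return hits / len(samples)


def prob_any_flagged_independent(samples, trans_ids):
    """Same quantity, but wrongly assuming the transactions are independent."""
    n = len(samples)
    p_none = 1.0
    for t in trans_ids:
        p_t = sum(sample[t] for sample in samples) / n  # marginal frequency
        p_none *= 1.0 - p_t
    return 1.0 - p_none


# Two transactions of the same client that are flagged together
samples = [
    {"t1": True, "t2": True},
    {"t1": False, "t2": False},
    {"t1": False, "t2": False},
    {"t1": True, "t2": True},
]
print(prob_any_flagged(samples, ["t1", "t2"]))              # 0.5
print(prob_any_flagged_independent(samples, ["t1", "t2"]))  # 0.75
```

The gap between the two numbers is exactly the correlation between transactions that the joint samples preserve.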

As mentioned earlier, the TabularModel.predict_sample() method requires two parameters: the context ctx and the number of samples n. Optional parameters include:

  • batch_size: An integer representing the batch size used during generation. When set to 0 (the default), all data is generated in a single batch.
  • temp: A positive number that controls the amount of noise in the generation process. Its default value is 1.
  • rng: A torch.Generator or an integer seed to fix the randomness when sampling.
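
The description of temp above only says that it controls the amount of noise. A common way a sampling temperature works, and the assumption behind this sketch, is to divide the model's logits by temp before the softmax: values below 1 sharpen the distribution, values above 1 flatten it. The functions tempered_probs and sample_category are hypothetical, and the integer seed only mimics the role of the rng parameter; this is not the TabularModel implementation.

```python
import math
import random


def tempered_probs(logits, temp=1.0):
    """Softmax of logits / temp (assumed behavior of a sampling temperature)."""
    if temp <= 0:
        raise ValueError("temp must be a positive number")
    scaled = [x / temp for x in logits]
    m = max(scaled)  # subtract the maximum for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_category(logits, temp=1.0, seed=None):
    """Draw one category index; a fixed seed makes the draw reproducible."""
    rng = random.Random(seed)
    probs = tempered_probs(logits, temp)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


logits = [2.0, 1.0, 0.0]
for t in (0.5, 1.0, 2.0):
    # Lower temp concentrates mass on the top category; higher temp spreads it out
    print(t, [round(p, 3) for p in tempered_probs(logits, t)])
```
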

```python
import pandas as pd

from aindo.rdml.relational import Column, ForeignKey, PrimaryKey, RelationalData, Schema, Table
from aindo.rdml.synth import TabularDataset, TabularModel, TabularPreproc, TabularTrainer

# Define the data, and split train/test
data = {
    "clients": pd.DataFrame({"id": ..., "account_type": ...}),
    "transactions": pd.DataFrame({
        "trans_id": ...,
        "client_id": ...,
        "date": ...,
        "amount": ...,
        "trans_type": ...,
        "fraudulent": ...,
    })
}
schema = Schema(
    clients=Table(id=PrimaryKey(), account_type=Column.CATEGORICAL),
    transactions=Table(
        trans_id=PrimaryKey(),
        client_id=ForeignKey(parent="clients"),
        date=Column.DATETIME,
        amount=Column.NUMERIC,
        trans_type=Column.CATEGORICAL,
        fraudulent=Column.BOOLEAN,
    ),
)
data = RelationalData(data=data, schema=schema)
data_train, data_test = data.split(ratio=0.1)

# Define the context in the TabularPreproc excluding only the target column
preproc = TabularPreproc.from_schema(
    schema=data.schema,
    ctx_cols={
        "clients": ["account_type"],
        "transactions": ["date", "amount", "trans_type"],
    },
)
preproc.fit(data=data)
model = TabularModel.build(preproc=preproc, size=...)

# Train the tabular model on the train data
dataset = TabularDataset.from_data(data=data_train, preproc=preproc)
trainer = TabularTrainer(model=model)
trainer.train(dataset=dataset, n_epochs=..., memory=..., valid=...)

# Select the context from the test data and predict the target column
ctx = preproc.select_ctx(data=data_test)
pred_sample = model.predict_sample(
    ctx=ctx,
    n=1_000,
    batch_size=32,
)

# Example: average number of fraudulent transactions per client.
# In each sample, count the flagged transactions of every client (clients
# without transactions count as 0), then average the counts over the samples.
n_fraud = [
    p["transactions"]
    .groupby("client_id")["fraudulent"]
    .sum()
    .reindex(p["clients"].loc[:, "id"], fill_value=0)
    for p in pred_sample
]
n_fraud_ave = sum(n_fraud[1:], start=n_fraud[0]) / len(n_fraud)
```

Predict probabilities

The TabularModel.predict_proba() method complements the TabularModel.predict_sample() method. To remain general, TabularModel.predict_sample() relies on an approximate, sampling-based solution to the prediction problem. In contrast, TabularModel.predict_proba() offers an exact solution, but only for a specific task: predicting a single categorical target column in the root table. In this scenario a complete statistical solution to the predictive task is possible, since the probability of every category of the target can be computed directly.
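
To see what "exact" buys, the sketch below takes an illustrative exact distribution over three hypothetical categories and estimates it the sampling way, from n random draws. The sampled frequencies only approach the exact values as n grows, whereas an exact method returns them directly. The function empirical_probs and the category names are hypothetical.

```python
import random
from collections import Counter


def empirical_probs(probs, categories, n, seed=0):
    """Estimate category probabilities from n random draws (the sampling route)."""
    rng = random.Random(seed)
    draws = rng.choices(categories, weights=probs, k=n)
    counts = Counter(draws)
    return [counts[c] / n for c in categories]


categories = ["A", "B", "C"]
exact = [0.7, 0.2, 0.1]  # illustrative exact distribution for one row

for n in (10, 1_000, 100_000):
    approx = empirical_probs(exact, categories, n)
    max_err = max(abs(a - e) for a, e in zip(approx, exact))
    print(f"n={n:>7}  estimate={[round(a, 3) for a in approx]}  max error={max_err:.4f}")
```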

The input context ctx must include all columns except for a single categorical column in the root table. For each row in the root table, the TabularModel.predict_proba() method computes the probabilities of each category in the target column. The output is a PredProb object, which contains two attributes:

  • prob: A torch.Tensor of shape (n_samples, n_categories) containing, for each row (sample) of the root table, the probability of each category of the target column.
  • categories: A list of the target column's categories, in the same order as the columns of the prob tensor.
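
The two attributes are meant to be used together: categories maps a category name to a column index of prob. In the sketch below, nested Python lists stand in for the prob tensor so that it runs without torch, and the category names are hypothetical apart from "cancelled", which also appears in the example further down. The helpers category_column and most_likely are illustrative only.

```python
def category_column(prob, categories, name):
    """Probabilities of category `name` for every row (one column of `prob`)."""
    j = categories.index(name)
    return [row[j] for row in prob]


def most_likely(prob, categories):
    """Most probable category for every row (argmax over each row)."""
    return [categories[max(range(len(row)), key=row.__getitem__)] for row in prob]


# Stand-in for a PredProb: 3 root-table rows, 3 categories, each row sums to 1
prob = [
    [0.70, 0.20, 0.10],
    [0.15, 0.60, 0.25],
    [0.30, 0.30, 0.40],
]
categories = ["kept", "cancelled", "upgraded"]

print(category_column(prob, categories, "cancelled"))  # [0.2, 0.6, 0.3]
print(most_likely(prob, categories))  # ['kept', 'cancelled', 'upgraded']
```

With an actual torch.Tensor, prob.argmax(dim=1) returns the per-row index of the most likely category, which can then be looked up in categories.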

For example, consider a dataset with several tables:

  • A parent table containing personal information about users of a site.
  • A set of child tables containing different actions performed by each user (e.g., likes, purchases, ...).

Suppose the task is to predict whether a user will keep their subscription, cancel it, or upgrade to a premium plan. The training set contains a categorical target column indicating the final status of each user. The context in the TabularPreproc (and consequently in the TabularModel) should include all columns except the target column. With the TabularModel.predict_proba() method, we can directly predict the probabilities of each final status for each user.
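
Because the probabilities are exact, aggregate quantities follow directly from them. By linearity of expectation, the expected number of users who cancel is the sum of their cancellation probabilities, and users can be ranked by risk. In this sketch, nested lists stand in for the prob tensor, and the helpers and the labels "kept" and "upgraded" are hypothetical.

```python
def expected_count(prob, categories, name):
    """Expected number of rows ending in category `name`.

    By linearity of expectation this is the sum of the per-row probabilities.
    """
    j = categories.index(name)
    return sum(row[j] for row in prob)


def top_k_rows(prob, categories, name, k):
    """Indices of the k rows with the highest probability of category `name`."""
    j = categories.index(name)
    ranked = sorted(range(len(prob)), key=lambda i: prob[i][j], reverse=True)
    return ranked[:k]


prob = [
    [0.70, 0.20, 0.10],
    [0.15, 0.60, 0.25],
    [0.30, 0.30, 0.40],
    [0.50, 0.45, 0.05],
]
categories = ["kept", "cancelled", "upgraded"]

print(round(expected_count(prob, categories, "cancelled"), 2))  # 1.55
print(top_k_rows(prob, categories, "cancelled", k=2))  # [1, 3]
```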

In addition to the required ctx parameter, the TabularModel.predict_proba() method has one optional parameter, batch_size, which works the same way as in the other generative methods.
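
The sketch below shows how a batch_size setting of this kind typically partitions the work. It assumes that batches are formed over the rows of the root table and that 0 means a single batch, as described for TabularModel.predict_sample(). The helper batch_slices is hypothetical and not part of aindo.rdml.

```python
def batch_slices(n_rows, batch_size=0):
    """Split row indices into consecutive batches.

    batch_size=0 means a single batch containing all rows.
    """
    if batch_size < 0:
        raise ValueError("batch_size must be non-negative")
    if batch_size == 0 or batch_size >= n_rows:
        return [range(0, n_rows)]
    return [
        range(start, min(start + batch_size, n_rows))
        for start in range(0, n_rows, batch_size)
    ]


print([len(b) for b in batch_slices(100, batch_size=32)])  # [32, 32, 32, 4]
print(len(batch_slices(100)))  # 1
```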

```python
import pandas as pd

from aindo.rdml.relational import Column, ForeignKey, PrimaryKey, RelationalData, Schema, Table
from aindo.rdml.synth import TabularDataset, TabularModel, TabularPreproc, TabularTrainer

# Define the data, and split train/test
data = {
    "users": pd.DataFrame({"id": ..., "age": ..., "location": ..., "status": ...}),
    "likes": pd.DataFrame({"user_id": ..., "date": ..., "content_liked": ...}),
    "purchases": pd.DataFrame({"user_id": ..., "date": ..., "item": ..., "amount": ...}),
}
schema = Schema(
    users=Table(
        id=PrimaryKey(),
        age=Column.INTEGER,
        location=Column.CATEGORICAL,
        status=Column.CATEGORICAL,
    ),
    likes=Table(user_id=ForeignKey(parent="users"), date=Column.DATETIME, content_liked=Column.CATEGORICAL),
    purchases=Table(
        user_id=ForeignKey(parent="users"),
        date=Column.DATETIME,
        item=Column.CATEGORICAL,
        amount=Column.NUMERIC,
    ),
)
data = RelationalData(data=data, schema=schema)
data_train, data_test = data.split(ratio=0.1)

# Define the context in the TabularPreproc excluding only the target column
preproc = TabularPreproc.from_schema(
    schema=data.schema,
    ctx_cols={
        "users": ["age", "location"],
        "likes": ["date", "content_liked"],
        "purchases": ["date", "item", "amount"],
    },
)
preproc.fit(data=data)
model = TabularModel.build(preproc=preproc, size=...)

# Train the tabular model on the train data
dataset = TabularDataset.from_data(data=data_train, preproc=preproc)
trainer = TabularTrainer(model=model)
trainer.train(dataset=dataset, n_epochs=..., memory=..., valid=...)

# Select the context from the test data and predict the probabilities for the target column
ctx = preproc.select_ctx(data=data_test)
pred_proba = model.predict_proba(
    ctx=ctx,
    batch_size=32,
)

# Obtain the probability for the category "cancelled" for each user -> tensor of shape: (data_test["users"].shape[0],)
cancel_proba = pred_proba.prob[:, pred_proba.categories.index("cancelled")]
```