Skip to content

Event model - BasketballMen dataset

Below is an example script that uses the aindo.rdml.synth.event package to generate synthetic event data. It uses the BasketballMen dataset, the same dataset used in the multi-table example. After training the model, we show how to generate synthetic events from scratch and how to continue existing event series.

from pathlib import Path

import pandas as pd

from aindo.rdml.relational import Column, ForeignKey, PrimaryKey, RelationalData, Schema, Table
from aindo.rdml.synth import Validation
from aindo.rdml.synth.event import EventDataset, EventModel, EventPreproc, EventTrainer

# Data and output
DATA_DIR = Path("path/to/data/dir")
OUTPUT_DIR = Path("./output")
# Model settings
MODEL_SIZE = "small"
DEVICE = None  # Device to None means it will be set to CUDA if the latter is available, otherwise CPU
# Training settings
# Exactly one of N_EPOCHS and N_STEPS must be an integer; the other must be None.
N_EPOCHS = 1_000
N_STEPS = None
MEMORY = 4096  # forwarded as `memory=` to EventTrainer.train; see its docs for the meaning/units
VALID_EACH = 200  # validation interval, counted in units given by Validation(trigger=...)

# Fail fast on a misconfiguration instead of after loading data and building the model.
if (N_EPOCHS is None) == (N_STEPS is None):
    raise ValueError("Exactly one of N_EPOCHS and N_STEPS must be set; the other must be None.")

# Load data: one CSV per table. The keys mirror the table names declared in the schema.
csv_files = {
    "players": "players.csv",
    "season": "season.csv",
    "all_star": "all_star.csv",
}
raw_tables = {name: pd.read_csv(DATA_DIR / fname) for name, fname in csv_files.items()}

# Parent table, keyed by playerID.
players_table = Table(
    playerID=PrimaryKey(),
    pos=Column.CATEGORICAL,
    height=Column.NUMERIC,
    weight=Column.NUMERIC,
    college=Column.CATEGORICAL,
    race=Column.CATEGORICAL,
    birthCity=Column.CATEGORICAL,
    birthState=Column.CATEGORICAL,
    birthCountry=Column.CATEGORICAL,
)
# Child tables: each row references a player through the playerID foreign key.
season_table = Table(
    playerID=ForeignKey(parent="players"),
    year=Column.INTEGER,
    stint=Column.INTEGER,
    tmID=Column.CATEGORICAL,
    lgID=Column.CATEGORICAL,
    GP=Column.INTEGER,
    points=Column.INTEGER,
    GS=Column.INTEGER,
    assists=Column.INTEGER,
    steals=Column.INTEGER,
    minutes=Column.INTEGER,
)
all_star_table = Table(
    playerID=ForeignKey(parent="players"),
    season_id=Column.INTEGER,
    conference=Column.CATEGORICAL,
    league_id=Column.CATEGORICAL,
    points=Column.INTEGER,
    rebounds=Column.INTEGER,
    assists=Column.INTEGER,
    blocks=Column.INTEGER,
)

schema = Schema(players=players_table, season=season_table, all_star=all_star_table)
# From here on, `data` is the schema-aware relational container used by the rest of the script.
data = RelationalData(data=raw_tables, schema=schema)

# Define the event preprocessor and fit it on the full relational dataset.
# `ord_cols` maps each event table to the column passed as its ordering column
# (presumably the column that orders the events in time — confirm in the EventPreproc docs).
event_order_columns = {"season": "year", "all_star": "season_id"}
preproc = EventPreproc.from_schema(schema=schema, ord_cols=event_order_columns)
preproc = preproc.fit(data=data)

# Split data in two stages with the same ratio: first hold out a test set, then
# carve a validation set out of the remaining data.
# NOTE(review): assumes `ratio` is the fraction assigned to the second returned split.
split_ratio = 0.1
data_train_valid, data_test = data.split(ratio=split_ratio)
data_train, data_valid = data_train_valid.split(ratio=split_ratio)

# Build the model from the fitted preprocessor.
model = EventModel.build(preproc=preproc, size=MODEL_SIZE)
model.device = DEVICE

# The best checkpoint and the TensorBoard logs are written under OUTPUT_DIR, and
# later steps write CSVs there too; create it up front (idempotent) so saving does
# not fail on a fresh checkout.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Train the model.
# on_disk=True is requested only for the training set (presumably to keep it out of RAM).
dataset_train = EventDataset.from_data(data=data_train, preproc=preproc, on_disk=True)
dataset_valid = EventDataset.from_data(data=data_valid, preproc=preproc)
trainer = EventTrainer(model=model)
trainer.train(
    dataset=dataset_train,
    n_epochs=N_EPOCHS,
    n_steps=N_STEPS,
    memory=MEMORY,
    valid=Validation(
        dataset=dataset_valid,
        early_stop="normal",
        save_best=OUTPUT_DIR / "best.pt",
        tensorboard=OUTPUT_DIR / "tb",
        each=VALID_EACH,  # with trigger="step", presumably: validate every VALID_EACH steps
        trigger="step",
    ),
)
# NOTE(review): `save_best` writes the best-validation checkpoint to disk, but it is not
# visible here whether `train` restores those weights into `model`. If it does not, the
# generation that follows uses the last-step weights — confirm in the EventTrainer docs
# and reload best.pt before generating if needed.

# Generation settings, shared by both generation calls below.
BATCH_SIZE = 512
MAX_N_EVENTS = 10

# Generate synthetic events from scratch, requesting as many samples as there are
# rows in the original players table.
data_synth = model.generate(
    n_samples=data["players"].shape[0],
    batch_size=BATCH_SIZE,
)
data_synth.to_csv(OUTPUT_DIR / "synth")

# Continue the event series found in the test set, capped by max_n_events
# (presumably the maximum number of new events per series — confirm in the
# EventModel.generate docs). force_event=True presumably forces at least one
# generated event per series — TODO confirm.
data_synth_continue = model.generate(
    ctx=data_test,
    force_event=True,
    max_n_events=MAX_N_EVENTS,
    batch_size=BATCH_SIZE,
)
data_synth_continue.to_csv(OUTPUT_DIR / "synth-continue")