Skip to content

Event model - BasketballMen dataset

In the following we present an example script using the aindo.rdml.synth.event package to generate synthetic event data. We make use of the BasketballMen dataset, which is the same dataset we used in the multi-table example. After training the model, we show how to generate synthetic events from scratch, and how to continue existing event series.

import argparse
from pathlib import Path

import pandas as pd
import torch

from aindo.rdml.relational import Column, ForeignKey, PrimaryKey, RelationalData, Schema, Table
from aindo.rdml.synth import Size, TabularModelSize, Validation
from aindo.rdml.synth.event import EventDataset, EventModel, EventPreproc, EventTrainer


def example_basket_event(
    data_dir: Path,
    output_dir: Path,
    data_frac: float | None,
    model_size: Size | TabularModelSize | str,
    n_epochs: int | None,
    n_steps: int | None,
    valid_each: int,
    device: str | torch.device | None,
    memory: int,
) -> None:
    """Train an event model on the basket dataset and generate synthetic events.

    The function loads the three CSV tables, fits an event preprocessor, trains
    an `EventModel` with periodic validation, and finally writes two synthetic
    datasets: one generated from scratch and one that continues the event
    series of the held-out test split.

    Args:
        data_dir: Directory containing `players.csv`, `season.csv` and
            `all_star.csv`.
        output_dir: Directory receiving `best.pt`, the `tb` TensorBoard logs
            and the `synth` / `synth-continue` outputs. Created if missing.
        data_frac: If not None, passed as `ratio` to `RelationalData.split`;
            only the second returned split is kept for the rest of the run.
        model_size: Forwarded to `EventModel.build` as `size`.
        n_epochs: Forwarded to `EventTrainer.train` as `n_epochs`.
        n_steps: Forwarded to `EventTrainer.train` as `n_steps`.
        valid_each: Interval between validations; the validation trigger is
            set to `"step"`.
        device: Assigned to `model.device`. None means CUDA if available,
            otherwise CPU.
        memory: Forwarded to `EventTrainer.train` as `memory`.
    """
    # Create the output directory up front so that a missing directory is
    # handled before training starts, rather than surfacing only when the
    # first checkpoint / log / CSV is written.
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load data and define schema
    frames = {
        "players": pd.read_csv(data_dir / "players.csv"),
        "season": pd.read_csv(data_dir / "season.csv"),
        "all_star": pd.read_csv(data_dir / "all_star.csv"),
    }
    schema = Schema(
        players=Table(
            playerID=PrimaryKey(),
            pos=Column.CATEGORICAL,
            height=Column.NUMERIC,
            weight=Column.NUMERIC,
            college=Column.CATEGORICAL,
            race=Column.CATEGORICAL,
            birthCity=Column.CATEGORICAL,
            birthState=Column.CATEGORICAL,
            birthCountry=Column.CATEGORICAL,
        ),
        season=Table(
            playerID=ForeignKey(parent="players"),
            year=Column.INTEGER,
            stint=Column.INTEGER,
            tmID=Column.CATEGORICAL,
            lgID=Column.CATEGORICAL,
            GP=Column.INTEGER,
            points=Column.INTEGER,
            GS=Column.INTEGER,
            assists=Column.INTEGER,
            steals=Column.INTEGER,
            minutes=Column.INTEGER,
        ),
        all_star=Table(
            playerID=ForeignKey(parent="players"),
            season_id=Column.INTEGER,
            conference=Column.CATEGORICAL,
            league_id=Column.CATEGORICAL,
            points=Column.INTEGER,
            rebounds=Column.INTEGER,
            assists=Column.INTEGER,
            blocks=Column.INTEGER,
        ),
    )
    data = RelationalData(data=frames, schema=schema)
    if data_frac is not None:
        # Optionally work on a subsample: keep only the second split.
        _, data = data.split(ratio=data_frac)

    # Define preprocessor
    # NOTE(review): ord_cols presumably names, per event table, the column
    # used to order the events of each player -- confirm against the library.
    preproc = EventPreproc.from_schema(
        schema=schema,
        ord_cols={"season": "year", "all_star": "season_id"},
    ).fit(data=data)

    # Split data: first hold out the test set, then split the remainder into
    # train and validation using the same ratio.
    split_ratio = 0.1
    data_train_valid, data_test = data.split(ratio=split_ratio)
    data_train, data_valid = data_train_valid.split(ratio=split_ratio)

    # Build model
    model = EventModel.build(preproc=preproc, size=model_size)
    model.device = device  # Device to None means it will be set to CUDA if the latter is available, otherwise CPU

    # Train the model
    dataset_train = EventDataset.from_data(data=data_train, preproc=preproc, on_disk=True)
    dataset_valid = EventDataset.from_data(data=data_valid, preproc=preproc)
    trainer = EventTrainer(model=model)
    trainer.train(
        dataset=dataset_train,
        n_epochs=n_epochs,
        n_steps=n_steps,
        memory=memory,
        valid=Validation(
            dataset=dataset_valid,
            early_stop="normal",
            save_best=output_dir / "best.pt",
            tensorboard=output_dir / "tb",
            each=valid_each,
            trigger="step",
        ),
    )

    # Generate synthetic events from scratch, one sample per row of the
    # (possibly subsampled) players table.
    data_synth = model.generate(
        n_samples=data["players"].shape[0],
        batch_size=512,
    )
    data_synth.to_csv(output_dir / "synth")

    # Continue the time series in the test set, capped via max_n_events=10
    data_synth_continue = model.generate(
        ctx=data_test,
        force_event=True,
        max_n_events=10,
        batch_size=512,
    )
    data_synth_continue.to_csv(output_dir / "synth-continue")


if __name__ == "__main__":
    # Command-line entry point: parse the arguments and run the example.
    parser = argparse.ArgumentParser()
    parser.add_argument("data_dir", type=Path, help="The directory where to find the 'basket' dataset")
    parser.add_argument("output_dir", type=Path, help="The output directory")
    parser.add_argument("--data-frac", "-d", type=float, help="Fraction of data to use")
    parser.add_argument("--model-size", "-m", type=Size.from_str, default=Size.SMALL, help="Model size")
    # A single counter serves both training modes; --steps decides whether it
    # is interpreted as a number of epochs or a number of steps.
    parser.add_argument(
        "--n",
        "-n",
        type=int,
        default=1000,
        help="Training epochs (or steps if the --steps flag is used)",
    )
    parser.add_argument("--steps", "-s", action="store_true", help="Use steps instead of epochs")
    parser.add_argument("--valid-each", "-v", type=int, default=200, help="# steps between validations")
    parser.add_argument("--device", "-g", default=None, help="Training device")
    parser.add_argument("--memory", "-y", type=int, default=4096, help="Available memory (MB)")
    args = parser.parse_args()

    example_basket_event(
        data_dir=args.data_dir,
        output_dir=args.output_dir,
        data_frac=args.data_frac,
        model_size=args.model_size,
        # Exactly one of n_epochs / n_steps is set; the other is None.
        n_epochs=None if args.steps else args.n,
        n_steps=args.n if args.steps else None,
        valid_each=args.valid_each,
        device=args.device,
        memory=args.memory,
    )