Skip to content

LLM-based models - Multi-task

Here we present a more advanced example script that uses the aindo.rdml.synth.llm module. Once again, the example uses the BasketballMen dataset.

First, we fine-tune an LLM to perform several tasks on the training split of the original dataset:

  • Fully synthetic data generation
  • Semi-synthetic data generation, with two different levels of randomness
  • Prediction of target columns

Then, we use the fine-tuned model to generate fully synthetic data (evaluated in a PDF report) and to predict the target columns of the held-out test set.

import gc
from pathlib import Path

import pandas as pd
import torch
from trl import ModelConfig, SFTConfig

from aindo.rdml.eval import report
from aindo.rdml.relational import (
    Column,
    ForeignKey,
    PrimaryKey,
    RelationalData,
    Schema,
    Table,
)
from aindo.rdml.synth.llm import (
    Dataset,
    DatasetElem,
    FtConfig,
    GenConfig,
    RelGenPrompt,
    RelPredict,
    RelSemiSynth,
    RelSynth,
    RelSynthPreproc,
    Task,
    finetune,
)

# ---------------------------------------------------------------------------
# Data and output
# ---------------------------------------------------------------------------
# Folder holding players.csv, season.csv and all_star.csv.
DATA_DIR = Path("path/to/data/dir")
# Root folder for the training dataset, checkpoints, synthetic data and predictions.
OUTPUT_DIR = Path("./output")
# Both forwarded to `RelationalData.split` (as `ratio` and `rng`).
# NOTE(review): presumably the fraction of rows held out for testing — confirm.
SPLIT_RATIO = 0.1
SPLIT_SEED = None  # no fixed seed; presumably a different split on every run

# ---------------------------------------------------------------------------
# Model settings
# ---------------------------------------------------------------------------
MODEL = "Qwen/Qwen3-0.6B-Base"  # base checkpoint that gets fine-tuned
# One template per task; main() picks the right one and fills the placeholders.
PROMPT_TEMPLATE_SYNTH = "Data schema: {schema}.\nSynthetic data:\n"
PROMPT_TEMPLATE_SEMISYNTH = "Context: {ctx}, output schema: {out_schema}.\nSemisynthetic data:\n"
PROMPT_TEMPLATE_PRED = "Context: {ctx}, output schema: {out_schema}.\nPrediction:\n"

# ---------------------------------------------------------------------------
# Training settings
# ---------------------------------------------------------------------------
N_EPOCHS = 3
MAX_STEPS = -1  # any positive value takes precedence over N_EPOCHS
LEARNING_RATE = 2e-4
BATCH_SIZE = 16  # per device
MAX_LENGTH = 32_768  # context limit of the smaller Qwen3 models
ENFORCE_MAX_LEN = False
EVAL_STEPS = 100  # also used as the checkpoint-saving interval

# ---------------------------------------------------------------------------
# Generation settings
# ---------------------------------------------------------------------------
ENGINE = "vllm"
# Engine construction options (vLLM-specific).
ENGINE_KWARGS = dict(
    max_num_seqs=200,
    gpu_memory_utilization=0.75,
)
# Per-generation options (vLLM-specific).
GENERATE_KWARGS = dict(
    max_tokens=32_768,  # context limit of the smaller Qwen3 models
)

# ---------------------------------------------------------------------------
# Synthetic data settings
# ---------------------------------------------------------------------------
N_SAMPLES = 1000  # passed as `n_samples` to the synthesis task

# ---------------------------------------------------------------------------
# Prediction settings
# ---------------------------------------------------------------------------
# Target columns per table; main() strips them from the test tables to build the context.
COLS_TGT = dict(
    season=["points", "assists", "steals"],
    all_star=["points", "rebounds", "assists", "blocks"],
)
N_PRED = 100  # passed as `n_pred` to the prediction task


def main() -> None:
    # Load data and define schema
    data = {
        "players": pd.read_csv(DATA_DIR / "players.csv"),
        "season": pd.read_csv(DATA_DIR / "season.csv"),
        "all_star": pd.read_csv(DATA_DIR / "all_star.csv"),
    }
    schema = Schema(
        players=Table(
            playerID=PrimaryKey(),
            pos=Column.CATEGORICAL,
            height=Column.NUMERIC,
            weight=Column.NUMERIC,
            college=Column.CATEGORICAL,
            race=Column.CATEGORICAL,
            birthCity=Column.CATEGORICAL,
            birthState=Column.CATEGORICAL,
            birthCountry=Column.CATEGORICAL,
        ),
        season=Table(
            playerID=ForeignKey(parent="players"),
            year=Column.INTEGER,
            stint=Column.INTEGER,
            tmID=Column.CATEGORICAL,
            lgID=Column.CATEGORICAL,
            GP=Column.INTEGER,
            points=Column.INTEGER,
            GS=Column.INTEGER,
            assists=Column.INTEGER,
            steals=Column.INTEGER,
            minutes=Column.INTEGER,
        ),
        all_star=Table(
            playerID=ForeignKey(parent="players"),
            conference=Column.CATEGORICAL,
            league_id=Column.CATEGORICAL,
            points=Column.INTEGER,
            rebounds=Column.INTEGER,
            assists=Column.INTEGER,
            blocks=Column.INTEGER,
        ),
    )
    data = RelationalData(data=data, schema=schema)

    # Define preprocessor
    preproc = RelSynthPreproc.from_data(data=data)

    # Split data
    data_train, data_test = data.split(ratio=SPLIT_RATIO, rng=SPLIT_SEED)

    # Initialize the tasks
    task_synth = RelSynth(preproc=preproc)

    task_semi_synth_01 = RelSemiSynth(preproc=preproc, p_field=0.1, p_child=0.1)
    task_semi_synth_05 = RelSemiSynth(preproc=preproc, p_field=0.5, p_child=0.5)

    task_pred = RelPredict(preproc=preproc, cols_tgt=COLS_TGT)

    # Create the training dataset
    dataset = Dataset(path=OUTPUT_DIR / "train.jsonl")
    dataset.append(task=task_synth, data=data_train)
    dataset.append(task=task_semi_synth_01, data=data_train)
    dataset.append(task=task_semi_synth_05, data=data_train)
    dataset.append(task=task_pred, data=data_train)

    # Define the prompt: use a function to have different prompts for different tasks
    def prompt_template(x: DatasetElem | RelGenPrompt) -> str:
        match Task.from_str(x.task):
            case Task.SYNTH:
                prompt = PROMPT_TEMPLATE_SYNTH
            case Task.SEMI_SYNTH:
                prompt = PROMPT_TEMPLATE_SEMISYNTH
            case Task.PREDICT:
                prompt = PROMPT_TEMPLATE_PRED
            case _:
                raise ValueError(f"Unsupported task: {x.task}")
        return prompt.format(schema=x.schema, ctx=x.ctx, out_schema=x.out_schema)

    # Finetune the LLM
    train_dir = OUTPUT_DIR / "train"
    cfg_ft = FtConfig(
        dataset_path=dataset.path,
        prompt_template=prompt_template,
        enforce_max_len=ENFORCE_MAX_LEN,
    )
    finetune(
        cfg_ft=cfg_ft,
        cfg_model=ModelConfig(
            model_name_or_path=MODEL,
            dtype="bfloat16",
            use_peft=True,
            lora_r=32,
            lora_alpha=64,
            lora_dropout=0.05,
            lora_target_modules=["all-linear"],
        ),
        cfg_sft=SFTConfig(
            output_dir=str(train_dir),
            optim="adamw_torch_fused",
            num_train_epochs=N_EPOCHS,
            max_steps=MAX_STEPS,
            learning_rate=LEARNING_RATE,
            per_device_train_batch_size=BATCH_SIZE,
            max_length=MAX_LENGTH,
            logging_steps=10,
            eval_strategy="steps",
            eval_on_start=False,
            save_steps=EVAL_STEPS,
            eval_steps=EVAL_STEPS,
            load_best_model_at_end=True,
            save_total_limit=2,
            bf16=True,
            tf32=True,
            report_to="tensorboard",
        ),
    )
    # after finetuning, clear the GPU memory for generation
    torch.cuda.empty_cache()

    # Generate synthetic data
    cfg_gen = GenConfig(
        model=str(train_dir / cfg_ft.best_ckpt),
        prompt_template=prompt_template,
        engine=ENGINE,
        engine_kwargs=ENGINE_KWARGS,
        generate_kwargs=GENERATE_KWARGS,
    )
    synth_dir = OUTPUT_DIR / "synth"
    (data_synth,) = task_synth.generate(
        cfg_gen=cfg_gen,
        n_samples=N_SAMPLES,
        output_dir=synth_dir / "logs",
    )
    torch.cuda.empty_cache()
    data_synth.to_csv(synth_dir)

    # Compute and print the PDF report
    report(
        data_train=data_train,
        data_test=data_test,
        data_synth=data_synth,
        path=synth_dir / "report.pdf",
    )

    # Predict on the test set
    # extract context columns from the test set
    pred_dir = OUTPUT_DIR / "pred"
    ctx = {
        t: df_test.loc[:, [c for c in df_test.columns if c not in COLS_TGT.get(t, ())]]
        for t, df_test in data_test.items()
    }
    pred = task_pred.generate(
        cfg_gen=cfg_gen,
        ctx=ctx,
        n_pred=N_PRED,
        output_dir=pred_dir / "logs",
    )
    torch.cuda.empty_cache()
    for i, p in enumerate(pred):
        p.to_csv(pred_dir / str(i))


if __name__ == "__main__":
    main()