LLM-based models - Multi-task
We present here a more advanced example script using the aindo.rdml.synth.llm module.
In this example, we once again use the BasketballMen dataset.
First, we fine-tune an LLM to perform several tasks on the original dataset:
- Full synthetic data
- Semisynthetic data, with different levels of randomness
- Prediction task
Then, we use the fine-tuned model to perform two tasks: generating fully synthetic data, and predicting the target columns of the test set.
from pathlib import Path
import pandas as pd
import torch
from trl import ModelConfig, SFTConfig
from aindo.rdml.eval import report
from aindo.rdml.relational import (
Column,
ForeignKey,
PrimaryKey,
RelationalData,
Schema,
Table,
)
from aindo.rdml.synth.llm import (
Dataset,
DatasetElem,
FtConfig,
GenConfig,
RelGenPrompt,
RelPredict,
RelSemiSynth,
RelSynth,
RelSynthPreproc,
Task,
finetune,
)
# ---------------------------------------------------------------------------
# Data and output
# ---------------------------------------------------------------------------
# Placeholder: directory holding players.csv, season.csv and all_star.csv.
DATA_DIR = Path("path/to/data/dir")
# Root for everything the script writes (training set, checkpoints, outputs).
OUTPUT_DIR = Path("./output")
# Passed to RelationalData.split() as `ratio` and `rng` respectively.
SPLIT_RATIO = 0.1
SPLIT_SEED = None

# ---------------------------------------------------------------------------
# Model settings
# ---------------------------------------------------------------------------
MODEL = "Qwen/Qwen3-0.6B-Base"
# One prompt per task kind; the placeholders are filled in with str.format().
PROMPT_TEMPLATE_SYNTH = "Data schema: {schema}.\nSynthetic data:\n"
PROMPT_TEMPLATE_SEMISYNTH = (
    "Context: {ctx}, output schema: {out_schema}.\nSemisynthetic data:\n"
)
PROMPT_TEMPLATE_PRED = (
    "Context: {ctx}, output schema: {out_schema}.\nPrediction:\n"
)

# ---------------------------------------------------------------------------
# Training settings
# ---------------------------------------------------------------------------
N_EPOCHS = 3
MAX_STEPS = -1  # if > 0, overrides N_EPOCHS
LEARNING_RATE = 2e-4
BATCH_SIZE = 16
MAX_LENGTH = 32_768  # max for smaller Qwen3 models
ENFORCE_MAX_LEN = False
EVAL_STEPS = 100  # used for both the evaluation and the checkpoint interval

# ---------------------------------------------------------------------------
# Generation settings
# ---------------------------------------------------------------------------
ENGINE = "vllm"
# Both dictionaries below hold vLLM-specific options.
ENGINE_KWARGS = {"max_num_seqs": 200, "gpu_memory_utilization": 0.75}
GENERATE_KWARGS = {"max_tokens": 32_768}  # max for smaller Qwen3 models

# ---------------------------------------------------------------------------
# Synthetic data settings
# ---------------------------------------------------------------------------
N_SAMPLES = 1000

# ---------------------------------------------------------------------------
# Prediction settings
# ---------------------------------------------------------------------------
# Target columns per table; tables not listed here have no target columns.
COLS_TGT = {
    "season": ["points", "assists", "steals"],
    "all_star": ["points", "rebounds", "assists", "blocks"],
}
N_PRED = 100
def main() -> None:
    """Run the multi-task example end to end.

    Steps, in order:

    1. Read the three CSV tables from ``DATA_DIR`` and wrap them, together
       with their schema, in a ``RelationalData`` object.
    2. Build the preprocessor from the full data, then split the data into a
       train and a test portion.
    3. Append four tasks (full synthesis, two semi-synthetic settings and
       prediction) to one training dataset located under ``OUTPUT_DIR``.
    4. Fine-tune ``MODEL`` (PEFT/LoRA enabled) on that dataset.
    5. Generate synthetic data with ``n_samples=N_SAMPLES``, save it as CSV
       and produce the PDF report.
    6. Predict the ``COLS_TGT`` columns from the remaining test-set columns
       and save every returned prediction as CSV.
    """
    # --- Load data and define schema ------------------------------------
    frames = {
        table_name: pd.read_csv(DATA_DIR / f"{table_name}.csv")
        for table_name in ("players", "season", "all_star")
    }

    players_table = Table(
        playerID=PrimaryKey(),
        pos=Column.CATEGORICAL,
        height=Column.NUMERIC,
        weight=Column.NUMERIC,
        college=Column.CATEGORICAL,
        race=Column.CATEGORICAL,
        birthCity=Column.CATEGORICAL,
        birthState=Column.CATEGORICAL,
        birthCountry=Column.CATEGORICAL,
    )
    season_table = Table(
        playerID=ForeignKey(parent="players"),
        year=Column.INTEGER,
        stint=Column.INTEGER,
        tmID=Column.CATEGORICAL,
        lgID=Column.CATEGORICAL,
        GP=Column.INTEGER,
        points=Column.INTEGER,
        GS=Column.INTEGER,
        assists=Column.INTEGER,
        steals=Column.INTEGER,
        minutes=Column.INTEGER,
    )
    all_star_table = Table(
        playerID=ForeignKey(parent="players"),
        conference=Column.CATEGORICAL,
        league_id=Column.CATEGORICAL,
        points=Column.INTEGER,
        rebounds=Column.INTEGER,
        assists=Column.INTEGER,
        blocks=Column.INTEGER,
    )
    schema = Schema(
        players=players_table,
        season=season_table,
        all_star=all_star_table,
    )
    rel_data = RelationalData(data=frames, schema=schema)

    # --- Define preprocessor (built on the full data, before the split) --
    preproc = RelSynthPreproc.from_data(data=rel_data)

    # --- Split data ------------------------------------------------------
    data_train, data_test = rel_data.split(ratio=SPLIT_RATIO, rng=SPLIT_SEED)

    # --- Initialize the tasks --------------------------------------------
    task_synth = RelSynth(preproc=preproc)
    task_semi_synth_01 = RelSemiSynth(preproc=preproc, p_field=0.1, p_child=0.1)
    task_semi_synth_05 = RelSemiSynth(preproc=preproc, p_field=0.5, p_child=0.5)
    task_pred = RelPredict(preproc=preproc, cols_tgt=COLS_TGT)

    # --- Create the training dataset -------------------------------------
    # Every task contributes its own examples, all drawn from the train split.
    dataset = Dataset(path=OUTPUT_DIR / "train.jsonl")
    training_tasks = (
        task_synth,
        task_semi_synth_01,
        task_semi_synth_05,
        task_pred,
    )
    for training_task in training_tasks:
        dataset.append(task=training_task, data=data_train)

    # --- Define the prompt -----------------------------------------------
    # A function (rather than a single string) lets each task have its own
    # prompt. The same function is used for fine-tuning and for generation.
    def prompt_template(x: DatasetElem | RelGenPrompt) -> str:
        """Return the formatted prompt matching the task of ``x``.

        Raises:
            ValueError: If the task of ``x`` is not one of the three tasks
                handled in this script.
        """
        task_kind = Task.from_str(x.task)
        if task_kind == Task.SYNTH:
            template = PROMPT_TEMPLATE_SYNTH
        elif task_kind == Task.SEMI_SYNTH:
            template = PROMPT_TEMPLATE_SEMISYNTH
        elif task_kind == Task.PREDICT:
            template = PROMPT_TEMPLATE_PRED
        else:
            raise ValueError(f"Unsupported task: {x.task}")
        # All fields are always passed; each template uses only its own.
        return template.format(schema=x.schema, ctx=x.ctx, out_schema=x.out_schema)

    # --- Finetune the LLM ------------------------------------------------
    train_dir = OUTPUT_DIR / "train"
    cfg_ft = FtConfig(
        dataset_path=dataset.path,
        prompt_template=prompt_template,
        enforce_max_len=ENFORCE_MAX_LEN,
    )
    cfg_model = ModelConfig(
        model_name_or_path=MODEL,
        dtype="bfloat16",
        use_peft=True,
        lora_r=32,
        lora_alpha=64,
        lora_dropout=0.05,
        lora_target_modules=["all-linear"],
    )
    cfg_sft = SFTConfig(
        output_dir=str(train_dir),
        optim="adamw_torch_fused",
        num_train_epochs=N_EPOCHS,
        max_steps=MAX_STEPS,
        learning_rate=LEARNING_RATE,
        per_device_train_batch_size=BATCH_SIZE,
        max_length=MAX_LENGTH,
        logging_steps=10,
        eval_strategy="steps",
        eval_on_start=False,
        save_steps=EVAL_STEPS,
        eval_steps=EVAL_STEPS,
        load_best_model_at_end=True,
        save_total_limit=2,
        bf16=True,
        tf32=True,
        report_to="tensorboard",
    )
    finetune(cfg_ft=cfg_ft, cfg_model=cfg_model, cfg_sft=cfg_sft)
    # after finetuning, clear the GPU memory for generation
    torch.cuda.empty_cache()

    # --- Generate synthetic data -----------------------------------------
    # NOTE(review): `cfg_ft.best_ckpt` is read only after finetune() returns;
    # presumably finetune() sets it - confirm against the aindo documentation.
    best_ckpt_dir = train_dir / cfg_ft.best_ckpt
    cfg_gen = GenConfig(
        model=str(best_ckpt_dir),
        prompt_template=prompt_template,
        engine=ENGINE,
        engine_kwargs=ENGINE_KWARGS,
        generate_kwargs=GENERATE_KWARGS,
    )
    synth_dir = OUTPUT_DIR / "synth"
    # generate() returns a sequence; exactly one element is expected here.
    (data_synth,) = task_synth.generate(
        cfg_gen=cfg_gen,
        n_samples=N_SAMPLES,
        output_dir=synth_dir / "logs",
    )
    torch.cuda.empty_cache()
    data_synth.to_csv(synth_dir)

    # --- Compute and print the PDF report --------------------------------
    report(
        data_train=data_train,
        data_test=data_test,
        data_synth=data_synth,
        path=synth_dir / "report.pdf",
    )

    # --- Predict on the test set -----------------------------------------
    pred_dir = OUTPUT_DIR / "pred"
    # The context is the test set without the target columns; tables that do
    # not appear in COLS_TGT keep all of their columns.
    ctx = {}
    for table_name, df_test in data_test.items():
        target_cols = COLS_TGT.get(table_name, ())
        context_cols = [col for col in df_test.columns if col not in target_cols]
        ctx[table_name] = df_test.loc[:, context_cols]

    pred = task_pred.generate(
        cfg_gen=cfg_gen,
        ctx=ctx,
        n_pred=N_PRED,
        output_dir=pred_dir / "logs",
    )
    torch.cuda.empty_cache()
    # Each prediction is written to its own sub-directory: pred/0, pred/1, ...
    for pred_idx, pred_data in enumerate(pred):
        pred_data.to_csv(pred_dir / str(pred_idx))
# Run the example only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()