LLM-based models - Single table - Adult dataset

In the following we present an example script using the LLM-based model in the aindo.rdml.synth.llm module. The example uses the simple, single-table UCI Adult dataset.

We fine-tune an LLM on the task of generating synthetic data, and then use it to create a synthetic version of the original dataset.
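
The script expects DATA_DIR to point to a directory containing the raw adult.data file, which can be downloaded from the UCI Machine Learning Repository.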

from pathlib import Path

import pandas as pd
import torch
from trl import ModelConfig, SFTConfig

from aindo.rdml.eval import report
from aindo.rdml.relational import (
    Column,
    RelationalData,
    Schema,
    Table,
)
from aindo.rdml.synth.llm import (
    Dataset,
    FtConfig,
    GenConfig,
    RelSynth,
    RelSynthPreproc,
    finetune,
)

# Data and output
DATA_DIR = Path("path/to/data/dir")
OUTPUT_DIR = Path("./output")
SPLIT_RATIO = 0.1
SPLIT_SEED = None
# Model settings
MODEL = "Qwen/Qwen3-0.6B-Base"
PROMPT_TEMPLATE = "Data schema:\n{schema}.\nSynthetic data:\n"
# Training settings
N_EPOCHS = 3
MAX_STEPS = -1  # if > 0, overrides N_EPOCHS
LEARNING_RATE = 2.0e-4
BATCH_SIZE = 16
MAX_LENGTH = 32768  # max for smaller Qwen3 models
ENFORCE_MAX_LEN = False
EVAL_STEPS = 100
# Generation settings
ENGINE = "vllm"
ENGINE_KWARGS = {  # these are specific for vLLM
    "max_num_seqs": 200,
    "gpu_memory_utilization": 0.75,
}
GENERATE_KWARGS = {  # these are specific for vLLM
    "max_tokens": 32768,  # max for smaller Qwen3 models
}
# Synthetic data settings
N_SAMPLES = 1_000


def main() -> None:
    # Load data and define schema
    schema = Schema(
        adult=Table(
            columns={
                "age": Column.INTEGER,
                "workclass": Column.CATEGORICAL,
                "fnlwgt": Column.INTEGER,
                "education": Column.CATEGORICAL,
                "education-num": Column.CATEGORICAL,
                "marital-status": Column.CATEGORICAL,
                "occupation": Column.CATEGORICAL,
                "relationship": Column.CATEGORICAL,
                "race": Column.CATEGORICAL,
                "sex": Column.CATEGORICAL,
                "capital-gain": Column.INTEGER,
                "capital-loss": Column.INTEGER,
                "hours-per-week": Column.INTEGER,
                "native-country": Column.CATEGORICAL,
                "y": Column.CATEGORICAL,
            }
        ),
    )
    data = {
        "adult": pd.read_csv(
            DATA_DIR / "adult.data",
            names=list(schema.tables["adult"].columns),
        ),
    }
    data = RelationalData(data=data, schema=schema)

    # Define preprocessor
    preproc = RelSynthPreproc.from_data(data=data)

    # Split data
    data_train, data_test = data.split(ratio=SPLIT_RATIO, rng=SPLIT_SEED)

    # Initialize the task
    task = RelSynth(preproc=preproc)

    # Create the training dataset, serialized to a JSONL file
    dataset = Dataset(path=OUTPUT_DIR / "train.jsonl")
    dataset.append(task=task, data=data_train)

    # Fine-tune the LLM
    train_dir = OUTPUT_DIR / "train"
    cfg_ft = FtConfig(
        dataset_path=dataset.path,
        prompt_template=PROMPT_TEMPLATE,
        enforce_max_len=ENFORCE_MAX_LEN,
    )
    finetune(
        cfg_ft=cfg_ft,
        cfg_model=ModelConfig(
            model_name_or_path=MODEL,
            dtype="bfloat16",
            use_peft=True,
            lora_r=32,
            lora_alpha=64,
            lora_dropout=0.05,
            lora_target_modules=["all-linear"],
        ),
        cfg_sft=SFTConfig(
            output_dir=str(train_dir),
            optim="adamw_torch_fused",
            num_train_epochs=N_EPOCHS,
            max_steps=MAX_STEPS,
            learning_rate=LEARNING_RATE,
            per_device_train_batch_size=BATCH_SIZE,
            max_length=MAX_LENGTH,
            logging_steps=10,
            eval_strategy="steps",
            eval_on_start=False,
            save_steps=EVAL_STEPS,
            eval_steps=EVAL_STEPS,
            load_best_model_at_end=True,
            save_total_limit=2,
            bf16=True,
            tf32=True,
            report_to="tensorboard",
        ),
    )
    # After fine-tuning, clear the GPU memory for generation
    torch.cuda.empty_cache()

    # Generate synthetic data
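    # The fine-tuned weights are loaded from the best checkpoint saved during training (load_best_model_at_end=True)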
    cfg_gen = GenConfig(
        model=str(train_dir / cfg_ft.best_ckpt),
        prompt_template=PROMPT_TEMPLATE,
        engine=ENGINE,
        engine_kwargs=ENGINE_KWARGS,
        generate_kwargs=GENERATE_KWARGS,
    )
    synth_dir = OUTPUT_DIR / "synth"
    (data_synth,) = task.generate(
        cfg_gen=cfg_gen,
        n_samples=N_SAMPLES,
        output_dir=synth_dir / "logs",
    )
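    # Free the GPU memory used by the generation engine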
    torch.cuda.empty_cache()
    data_synth.to_csv(synth_dir)

    # Compute the evaluation report and save it as a PDF
    report(
        data_train=data_train,
        data_test=data_test,
        data_synth=data_synth,
        path=synth_dir / "report.pdf",
    )


if __name__ == "__main__":
    main()
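
Once the script completes, the output directory contains the fine-tuning checkpoints and TensorBoard logs under output/train, the generation logs and the synthetic tables in CSV format under output/synth, and the evaluation report at output/synth/report.pdf.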