Skip to content

Single table with LLM-based models

The following example script uses the LLM-based model from the aindo.rdml.synth.llm module to generate a single-table synthetic dataset from scratch, and then enriches it with some extra columns.

from pathlib import Path

from aindo.rdml.relational import Column
from aindo.rdml.synth.llm import (
    CategoricalColumnStructure,
    LlmColumnCfg,
    LlmTableCfg,
    LlmTabularModel,
    WordSequenceColumnStructure,
)

# --- Runtime target ---
# Value assigned to `model.device` once the model has been loaded.
DEVICE = "cuda"

# --- Locations ---
# Checkpoint provided by Aindo; passed to LlmTabularModel.load as `ckpt_path`.
CKPT_PATH = Path("path/to/ckpt")
# Optional path to the LLM model; left unset (None) in this example.
MODEL_PATH = None
# Base output directory; the synthetic data is saved in its "synth" subfolder.
OUTPUT_DIR = Path("./output")

# --- Generation settings shared by generate() and add_columns() ---
# Forwarded as `generation_mode` to both calls.
GEN_MODE = "rejection"
# Forwarded as `retry_on_fail` to both calls.
RETRY_ON_FAIL = 10

# Load the LLM-based tabular model from the Aindo checkpoint (and, when given,
# the optional LLM model path), then set its `device` attribute to DEVICE.
model = LlmTabularModel.load(ckpt_path=CKPT_PATH, model_path=MODEL_PATH)
model.device = DEVICE

# Describe the table to generate: one LlmColumnCfg per column, gathered in an
# LlmTableCfg together with the table name and a free-text description.

# "name": text column made of exactly one word, with word length in [0, 15].
# NOTE(review): min_word_len=0 looks like it permits an empty word -- confirm
# that this is intended.
name_column = LlmColumnCfg(
    type=Column.TEXT,
    description="Name of the US Citizen",
    structure=WordSequenceColumnStructure(
        min_n_words=1,
        max_n_words=1,
        min_word_len=0,
        max_word_len=15,
    ),
)

# "Gender": text column made of exactly one word; no word-length bounds given.
gender_column = LlmColumnCfg(
    type=Column.TEXT,
    description="The gender of the person",
    structure=WordSequenceColumnStructure(min_n_words=1, max_n_words=1),
)

# "age": integer column; no explicit structure is supplied.
age_column = LlmColumnCfg(
    type=Column.INTEGER,
    description="The age of the individual",
)

cfg = LlmTableCfg(
    name="Individuals",
    description="Dataset containing data of US adult citizens.",
    columns={
        "name": name_column,
        "Gender": gender_column,
        "age": age_column,
    },
)

# Create a brand-new synthetic table that follows the `cfg` table configuration.
data_synth = model.generate(
    cfg=cfg,
    # Generation strategy and retry setting, shared with the enrichment step.
    generation_mode=GEN_MODE,
    retry_on_fail=RETRY_ON_FAIL,
    # Size of the request: 50 samples, with a batch size of 40.
    n_samples=50,
    batch_size=40,
    # Token limit passed to the model for this call.
    max_tokens=200,
)

# Enrich the generated table with two additional columns. The original table
# configuration is passed as context for the new columns.
extra_columns = {
    # Categorical column restricted to the three listed categories.
    "formatted_sex": LlmColumnCfg(
        type=Column.CATEGORICAL,
        description="Cleaned version of Gender column, either 'M',  'F' or 'other'",
        structure=CategoricalColumnStructure(categories=["M", "F", "other"]),
    ),
    # Free-text column made of one to five words.
    "Job Title": LlmColumnCfg(
        type=Column.TEXT,
        description="The job of the person",
        structure=WordSequenceColumnStructure(min_n_words=1, max_n_words=5),
    ),
}

data_synth = model.add_columns(
    data=data_synth,
    context_cfg=cfg,
    new_columns=extra_columns,
    # Same generation strategy and retry setting as the from-scratch step.
    generation_mode=GEN_MODE,
    retry_on_fail=RETRY_ON_FAIL,
    batch_size=40,
    # Token limit passed to the model for this call.
    max_tokens=10,
)

# Write the enriched synthetic data to the "synth" subfolder of OUTPUT_DIR.
synth_dir = OUTPUT_DIR / "synth"
data_synth.to_csv(output_dir=synth_dir)