Skip to content

Single table with LLM-based models

The following example script uses the LLM-based model from the aindo.rdml.synth.llm module to generate a synthetic single-table dataset from scratch, and then enriches it with some extra columns.

import argparse
from pathlib import Path

import torch

from aindo.rdml.relational import Column
from aindo.rdml.synth.llm import (
    CategoricalColumnStructure,
    GenerationMode,
    LlmColumnCfg,
    LlmTableCfg,
    LlmTabularModel,
    WordSequenceColumnStructure,
)


def example_llm(
    output_dir: Path | str,
    ckpt_path: str | Path,
    model_path: str | Path | None,
    mode_gen: GenerationMode | str,
    mode_col: GenerationMode | str,
    retry_on_fail: int,
    device: str | torch.device,
) -> None:
    """Generate a synthetic single-table dataset with an LLM and enrich it.

    The function loads an ``LlmTabularModel``, generates 50 rows of an
    "Individuals" table (columns ``name``, ``Gender`` and ``age``), adds the
    columns ``formatted_sex`` and ``Job Title`` to the generated data, and
    writes the result under ``output_dir / "synth"``.

    Args:
        output_dir: Base output directory; the data is saved in its ``synth``
            sub-directory. Both ``str`` and ``Path`` are accepted.
        ckpt_path: Checkpoint path forwarded to ``LlmTabularModel.load``.
        model_path: Model path forwarded to ``LlmTabularModel.load``
            (may be ``None``).
        mode_gen: Generation mode passed to ``model.generate``.
        mode_col: Generation mode passed to ``model.add_columns``.
        retry_on_fail: Value forwarded as ``retry_on_fail`` to both the
            generation and the column-addition calls.
        device: Value assigned to ``model.device``.
    """
    # Define the model
    model = LlmTabularModel.load(
        ckpt_path=ckpt_path,
        model_path=model_path,
    )
    model.device = device

    # Define the data configuration
    cfg = LlmTableCfg(
        name="Individuals",
        description="Dataset containing data of US adult citizens.",
        columns={
            "name": LlmColumnCfg(
                type=Column.TEXT,
                description="Name of the US Citizen",
                structure=WordSequenceColumnStructure(
                    min_n_words=1,
                    max_n_words=1,
                    min_word_len=0,
                    max_word_len=15,
                ),
            ),
            "Gender": LlmColumnCfg(
                type=Column.TEXT,
                description="The gender of the person",
                structure=WordSequenceColumnStructure(
                    min_n_words=1,
                    max_n_words=1,
                ),
            ),
            "age": LlmColumnCfg(
                type=Column.INTEGER,
                description="The age of the individual",
            ),
        },
    )

    # Generate data from scratch
    data_synth = model.generate(
        cfg=cfg,
        n_samples=50,
        batch_size=40,
        max_tokens=200,
        generation_mode=mode_gen,
        retry_on_fail=retry_on_fail,
    )

    # Enrich the generated data
    data_synth = model.add_columns(
        data=data_synth,
        context_cfg=cfg,
        new_columns={
            "formatted_sex": LlmColumnCfg(
                type=Column.CATEGORICAL,
                description="Cleaned version of Gender column, either 'M',  'F' or 'other'",
                structure=CategoricalColumnStructure(categories=["M", "F", "other"]),
            ),
            "Job Title": LlmColumnCfg(
                type=Column.TEXT,
                description="The job of the person",
                structure=WordSequenceColumnStructure(
                    min_n_words=1,
                    max_n_words=5,
                ),
            ),
        },
        batch_size=40,
        generation_mode=mode_col,
        retry_on_fail=retry_on_fail,
        max_tokens=10,
    )

    # Save the generated data.
    # `output_dir` may be a plain string, so normalize it to a Path before
    # joining; `str / str` would raise a TypeError.
    data_synth.to_csv(output_dir=Path(output_dir) / "synth")


if __name__ == "__main__":
    # Command-line entry point: parse the arguments and run the example.
    parser = argparse.ArgumentParser()
    parser.add_argument("ckpt", type=Path, help="The path to the checkpoint provided by Aindo")
    parser.add_argument("output_dir", type=Path, help="The output directory")
    parser.add_argument("--model-path", "-m", type=Path, help="The path to the llm model")
    parser.add_argument(
        "--mode-gen",
        "-g",
        type=GenerationMode.from_str,
        default=GenerationMode.REJECTION,
        help="Generation mode during table generation",
    )
    parser.add_argument(
        "--mode-col",
        "-c",
        type=GenerationMode.from_str,
        default=GenerationMode.REJECTION,
        help="Generation mode during column addition",
    )
    parser.add_argument("--retry-on-fail", "-r", type=int, default=10, help="Number of tries for enforced generation")
    # The short flag must be unique: "-g" is already taken by --mode-gen, and
    # argparse raises an error on conflicting option strings.
    parser.add_argument("--device", "-d", default=None, help="Training device")
    args = parser.parse_args()

    example_llm(
        ckpt_path=args.ckpt,
        output_dir=args.output_dir,
        # argparse derives the attribute name from "--model-path" -> "model_path".
        model_path=args.model_path,
        mode_gen=args.mode_gen,
        mode_col=args.mode_col,
        retry_on_fail=args.retry_on_fail,
        device=args.device,
    )