Single table with LLM-based models
In the following, we present an example script that uses the LLM-based model in the `aindo.rdml.synth.llm` module
to generate, from scratch, a synthetic dataset consisting of a single table, and then enrich it with some extra columns.
from pathlib import Path
from aindo.rdml.relational import Column
from aindo.rdml.synth.llm import (
CategoricalColumnStructure,
LlmColumnCfg,
LlmTableCfg,
LlmTabularModel,
WordSequenceColumnStructure,
)
# --- Run configuration -------------------------------------------------------
# Filesystem locations.
CKPT_PATH: Path = Path("path/to/ckpt")  # Checkpoint provided by Aindo
MODEL_PATH = None  # Optional path to the LLM model (None = not provided)
OUTPUT_DIR: Path = Path("./output")  # Root directory for the saved synthetic data

# Sampling settings, shared by model.generate(...) and model.add_columns(...)
# via their `generation_mode` / `retry_on_fail` keyword arguments.
GEN_MODE: str = "rejection"
RETRY_ON_FAIL: int = 10

# Device assigned to the loaded model (model.device).
DEVICE: str = "cuda"
# --- Load the model ----------------------------------------------------------
# Restore the LLM-based tabular model from the Aindo checkpoint, optionally
# pointing it at a separate LLM model path, then assign the target device.
model = LlmTabularModel.load(
    ckpt_path=CKPT_PATH,
    model_path=MODEL_PATH,
)
model.device = DEVICE
# --- Table configuration -----------------------------------------------------
# Column specifications for the single "Individuals" table. Each entry maps a
# column name to its type, a natural-language description for the LLM, and
# (optionally) a structural constraint on the generated values.
individual_columns = {
    "name": LlmColumnCfg(
        type=Column.TEXT,
        description="Name of the US Citizen",
        # Exactly one word, with word length bounded by min/max_word_len.
        structure=WordSequenceColumnStructure(
            min_n_words=1,
            max_n_words=1,
            # NOTE(review): a minimum word length of 0 looks like it would
            # admit an empty name; confirm whether 1 was intended.
            min_word_len=0,
            max_word_len=15,
        ),
    ),
    "Gender": LlmColumnCfg(
        type=Column.TEXT,
        description="The gender of the person",
        # Exactly one word; no explicit word-length bounds.
        structure=WordSequenceColumnStructure(
            min_n_words=1,
            max_n_words=1,
        ),
    ),
    "age": LlmColumnCfg(
        type=Column.INTEGER,
        description="The age of the individual",
    ),
}

# Table-level description that wraps the columns above.
cfg = LlmTableCfg(
    name="Individuals",
    description="Dataset containing data of US adult citizens.",
    columns=individual_columns,
)
# --- Generate from scratch ---------------------------------------------------
# Ask the model for 50 samples that follow `cfg`, processed in batches of 40,
# with a 200-token budget and the shared generation mode / retry settings.
data_synth = model.generate(
    cfg=cfg,
    generation_mode=GEN_MODE,
    retry_on_fail=RETRY_ON_FAIL,
    n_samples=50,
    batch_size=40,
    max_tokens=200,
)
# --- Enrich with extra columns -----------------------------------------------
# New columns to append to the generated table. The existing table config is
# passed as context so the model can condition on the columns already present.
extra_columns = {
    # Categorical re-coding of the free-text "Gender" column.
    "formatted_sex": LlmColumnCfg(
        type=Column.CATEGORICAL,
        description="Cleaned version of Gender column, either 'M', 'F' or 'other'",
        structure=CategoricalColumnStructure(categories=["M", "F", "other"]),
    ),
    # Free-text job title of one to five words.
    "Job Title": LlmColumnCfg(
        type=Column.TEXT,
        description="The job of the person",
        structure=WordSequenceColumnStructure(
            min_n_words=1,
            max_n_words=5,
        ),
    ),
}

data_synth = model.add_columns(
    data=data_synth,
    context_cfg=cfg,
    new_columns=extra_columns,
    generation_mode=GEN_MODE,
    retry_on_fail=RETRY_ON_FAIL,
    batch_size=40,
    # NOTE(review): 10 tokens may be tight for a category plus a job title of
    # up to 5 words; confirm whether this budget is per column or per row.
    max_tokens=10,
)
# --- Persist -----------------------------------------------------------------
# Write the enriched synthetic data as CSV under OUTPUT_DIR/synth.
synth_output_dir = OUTPUT_DIR / "synth"
data_synth.to_csv(output_dir=synth_output_dir)