Single table with LLM-based models
The following example script uses the LLM-based model from the `aindo.rdml.synth.llm` module
to generate a synthetic dataset consisting of a single table from scratch,
and then to enrich it with some extra columns.
import argparse
from pathlib import Path
import torch
from aindo.rdml.relational import Column
from aindo.rdml.synth.llm import (
CategoricalColumnStructure,
GenerationMode,
LlmColumnCfg,
LlmTableCfg,
LlmTabularModel,
WordSequenceColumnStructure,
)
def example_llm(
    output_dir: Path | str,
    ckpt_path: str | Path,
    model_path: str | Path | None,
    mode_gen: GenerationMode | str,
    mode_col: GenerationMode | str,
    retry_on_fail: int,
    device: str | torch.device,
) -> None:
    """Generate a single synthetic table with an LLM-based model and enrich it.

    The function loads an `LlmTabularModel`, generates an "Individuals" table
    from a declarative column configuration, adds two further columns to the
    generated data, and writes the result as CSV.

    Args:
        output_dir: Base output directory; the data is written to its
            ``synth`` sub-directory. Both `str` and `Path` are accepted.
        ckpt_path: Checkpoint path forwarded to `LlmTabularModel.load`.
        model_path: LLM model path forwarded to `LlmTabularModel.load`.
        mode_gen: Generation mode used for the initial table generation.
        mode_col: Generation mode used when adding the extra columns.
        retry_on_fail: Value forwarded as `retry_on_fail` to both the
            generation and the column-addition calls.
        device: Device assigned to the model before generating.
    """
    # Define the model
    model = LlmTabularModel.load(
        ckpt_path=ckpt_path,
        model_path=model_path,
    )
    # NOTE(review): the command-line entry point passes `None` when no device
    # is given -- confirm that the model accepts it.
    model.device = device

    # Define the data configuration
    cfg = LlmTableCfg(
        name="Individuals",
        description="Dataset containing data of US adult citizens.",
        columns={
            "name": LlmColumnCfg(
                type=Column.TEXT,
                description="Name of the US Citizen",
                structure=WordSequenceColumnStructure(
                    min_n_words=1,
                    max_n_words=1,
                    min_word_len=0,
                    max_word_len=15,
                ),
            ),
            "Gender": LlmColumnCfg(
                type=Column.TEXT,
                description="The gender of the person",
                structure=WordSequenceColumnStructure(
                    min_n_words=1,
                    max_n_words=1,
                ),
            ),
            "age": LlmColumnCfg(
                type=Column.INTEGER,
                description="The age of the individual",
            ),
        },
    )

    # Generate data from scratch
    data_synth = model.generate(
        cfg=cfg,
        n_samples=50,
        batch_size=40,
        max_tokens=200,
        generation_mode=mode_gen,
        retry_on_fail=retry_on_fail,
    )

    # Enrich the generated data
    data_synth = model.add_columns(
        data=data_synth,
        context_cfg=cfg,
        new_columns={
            "formatted_sex": LlmColumnCfg(
                type=Column.CATEGORICAL,
                description="Cleaned version of Gender column, either 'M', 'F' or 'other'",
                structure=CategoricalColumnStructure(categories=["M", "F", "other"]),
            ),
            "Job Title": LlmColumnCfg(
                type=Column.TEXT,
                description="The job of the person",
                structure=WordSequenceColumnStructure(
                    min_n_words=1,
                    max_n_words=5,
                ),
            ),
        },
        batch_size=40,
        generation_mode=mode_col,
        retry_on_fail=retry_on_fail,
        max_tokens=10,
    )

    # Save the generated data.
    # `output_dir` may be a plain string, so convert it before joining with `/`
    # (the `/` operator between two `str` objects raises `TypeError`).
    data_synth.to_csv(output_dir=Path(output_dir) / "synth")
if __name__ == "__main__":
    # Command-line entry point: parse the arguments and run the example.
    parser = argparse.ArgumentParser()
    parser.add_argument("ckpt", type=Path, help="The path to the checkpoint provided by Aindo")
    parser.add_argument("output_dir", type=Path, help="The output directory")
    parser.add_argument("--model-path", "-m", type=Path, help="The path to the llm model")
    parser.add_argument(
        "--mode-gen",
        "-g",
        type=GenerationMode.from_str,
        default=GenerationMode.REJECTION,
        help="Generation mode during table generation",
    )
    parser.add_argument(
        "--mode-col",
        "-c",
        type=GenerationMode.from_str,
        default=GenerationMode.REJECTION,
        help="Generation mode during column addition",
    )
    parser.add_argument(
        "--retry-on-fail",
        "-r",
        type=int,
        default=10,
        help="Number of tries for enforced generation",
    )
    # The short flag must be unique: "-g" already belongs to --mode-gen, and
    # argparse rejects a duplicate option string when the parser is built.
    parser.add_argument("--device", "-d", default=None, help="Training device")
    args = parser.parse_args()

    example_llm(
        ckpt_path=args.ckpt,
        output_dir=args.output_dir,
        # argparse derives the attribute name from the long option:
        # "--model-path" is stored as `model_path`.
        model_path=args.model_path,
        mode_gen=args.mode_gen,
        mode_col=args.mode_col,
        retry_on_fail=args.retry_on_fail,
        device=args.device,
    )