
Fine-tuning an LLM to generate synthetic data

All the pre- and post-processing tools of aindo.rdml.synth.llm can be used with any LLM of choice, which in turn can be trained with the user's favourite framework.

However, we provide some extra utilities to perform the fine-tuning of Hugging Face (HF) models, optionally using Unsloth to accelerate the training process.

The aindo.rdml.synth.llm package contains two main training utilities:

  • The Dataset class to create the training dataset.
  • The finetune function to perform the fine-tuning of any HF model.

Build the training dataset

The first step towards a successful fine-tuning is to build the training dataset. To assist with that, the aindo.rdml.synth.llm module provides the Dataset class. It is initialized with an optional path to a JSONL file, which can refer to:

  • The location of a previously built dataset.
  • A new dataset, which will be saved at the specified location.

The user can then add training data with the Dataset.append() method, which needs:

  • data: The RelationalData to be used to build the training data.
  • task: A BaseTask object, which is used to process the data and extract the relevant information with the BaseTask.get_dataset() method.

When data is appended to the dataset, the BaseTask object converts each training example into JSON format; the JSON object is then serialized to a string and appended as a new line to the JSONL file. The JSON representation of a training example is a dictionary corresponding to the DatasetElem data structure. The fields DatasetElem.ctx and DatasetElem.out are themselves JSON representations of the context and the target, while DatasetElem.schema and DatasetElem.out_schema contain the JSON schemas of the complete data and of the target, respectively. Finally, DatasetElem.description contains an optional description that can be specified at task initialization, and DatasetElem.task contains the name of the task at hand.
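
To make the on-disk format concrete, here is a minimal sketch using only the Python standard library. It is not the library's implementation: DatasetElemSketch is a hypothetical stand-in for DatasetElem (field names from the description above, types assumed), append_elems and read_elems are hypothetical helpers, and all field values are made up. It shows one JSON object per line, and that appending to an existing file keeps the examples already in it.

```python
import json
import tempfile
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Iterable, Iterator, Optional


@dataclass
class DatasetElemSketch:
    """Hypothetical stand-in for DatasetElem: field names from the docs, types assumed."""

    ctx: Any  # JSON representation of the context
    out: Any  # JSON representation of the target
    schema: dict  # JSON schema of the complete data
    out_schema: dict  # JSON schema of the target
    description: Optional[str]  # optional description given at task initialization
    task: str  # name of the task


def append_elems(path: Path, elems: Iterable[DatasetElemSketch]) -> None:
    # Append mode: lines already in the file (a previously built dataset) are kept
    with path.open("a", encoding="utf-8") as f:
        for elem in elems:
            # One example -> one JSON object -> one serialized line
            f.write(json.dumps(asdict(elem)) + "\n")


def read_elems(path: Path) -> Iterator[DatasetElemSketch]:
    # Parse the file back, one JSON object per non-empty line
    with path.open(encoding="utf-8") as f:
        for line in f:
            if line.strip():
                yield DatasetElemSketch(**json.loads(line))


# Usage with made-up values: two appends to the same file accumulate examples
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "train.jsonl"
    elem = DatasetElemSketch(
        ctx={"players": []},
        out={"players": [{"pos": "G"}]},
        schema={"type": "object"},
        out_schema={"type": "object"},
        description=None,
        task="synth",
    )
    append_elems(path, [elem])
    append_elems(path, [elem])
    n_lines = len(path.read_text(encoding="utf-8").splitlines())
    loaded = list(read_elems(path))
```

Because each line is an independent JSON object, the file can be extended by later append() calls for other tasks without rewriting what is already there.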

The data in the JSONL file can be used to create the training examples to fine-tune any LLM. The Dataset class also offers two utility methods to load the dataset for use with HF:

  • Dataset.load(): Loads a HF datasets.Dataset. The training examples can then be processed as desired with the usual HF functions, to bring them into a form suitable for the desired fine-tuning.

  • Dataset.load_prompt_completion(): Loads a HF datasets.Dataset already in the prompt-completion format. For each training example, the prompt is built from the fields of DatasetElem using the required parameter prompt_template, which can be either a formattable string (whose keys are the fields of DatasetElem) or a callable that returns the prompt given a DatasetElem object. The completion is taken to be the DatasetElem.out field; this behaviour can be changed by specifying a different value for the optional completion_key parameter. Optionally (parameter tokenize), prompt and completion are tokenized, adding the input_ids and completion_mask columns; in this case, the tokenizer must be provided via the tokenizer parameter.
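
The sketch below mimics, in plain Python, how a prompt-completion pair could be derived from a single example: the template is either a formattable string or a callable, and the completion is read from a configurable key. It is an illustration under assumptions, not the library's code: examples are plain dicts, to_prompt_completion, toy_tokenize and tokenize_pair are hypothetical names, the whitespace "tokenizer" stands in for a real HF tokenizer, and completion_mask is assumed to follow the trl convention (1 on completion tokens, 0 on prompt tokens).

```python
import json
from typing import Any, Callable, Dict, List, Union

Example = Dict[str, Any]
Template = Union[str, Callable[[Example], str]]


def to_prompt_completion(
    elem: Example, prompt_template: Template, completion_key: str = "out"
) -> Dict[str, str]:
    # A string template is filled with the example's fields as keyword arguments;
    # a callable receives the whole example and returns the prompt
    if isinstance(prompt_template, str):
        prompt = prompt_template.format(**elem)
    else:
        prompt = prompt_template(elem)
    # The completion is read from a single field ("out" unless overridden)
    value = elem[completion_key]
    completion = value if isinstance(value, str) else json.dumps(value)
    return {"prompt": prompt, "completion": completion}


def toy_tokenize(text: str, vocab: Dict[str, int]) -> List[int]:
    # Toy whitespace tokenizer standing in for a HF tokenizer; new words get new ids
    return [vocab.setdefault(tok, len(vocab)) for tok in text.split()]


def tokenize_pair(pair: Dict[str, str], vocab: Dict[str, int]) -> Dict[str, List[int]]:
    prompt_ids = toy_tokenize(pair["prompt"], vocab)
    completion_ids = toy_tokenize(pair["completion"], vocab)
    # input_ids is the prompt followed by the completion;
    # completion_mask is 1 on completion tokens and 0 on prompt tokens
    return {
        "input_ids": prompt_ids + completion_ids,
        "completion_mask": [0] * len(prompt_ids) + [1] * len(completion_ids),
    }


# Usage with a made-up example whose ctx and out are already JSON strings
elem = {"ctx": '{"team": "A"}', "out": '{"pos": "G"}', "task": "predict", "description": None}
pair = to_prompt_completion(elem, "Context: {ctx}\nTarget: ")
pair_fn = to_prompt_completion(elem, lambda e: f"[{e['task']}] {e['ctx']}")
tokens = tokenize_pair(pair, vocab={})
```

Note that the template never references the completion field: the target is supplied separately as the completion, so putting it in the prompt would show the model the answer it is supposed to learn to produce.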

For example, to train an LLM to perform different tasks on a single dataset, one can call Dataset.append() several times, once for each BaseTask object.

from transformers import AutoTokenizer

from aindo.rdml.relational import RelationalData
from aindo.rdml.synth.llm import (
    Dataset,
    RelSynth,
    RelSemiSynth,
    RelPredict,
    RelSynthPreproc,
)

# Load the data and split in train-test
data: RelationalData = ...
data_train, data_test = data.split(ratio=0.1)

# Define the preprocessor and the tasks
preproc = RelSynthPreproc.from_data(data=data)

task_synth = RelSynth(preproc=preproc)

task_semi_synth_01 = RelSemiSynth(preproc=preproc, p_field=0.1, p_child=0.1)
task_semi_synth_05 = RelSemiSynth(preproc=preproc, p_field=0.5, p_child=0.5)

task_pred = RelPredict(preproc=preproc, cols_tgt={"players": ["pos"]})

# Initialize the dataset, and append the training data,
# processed for the different tasks
dataset = Dataset(path="path/to/train.jsonl")
dataset.append(task=task_synth, data=data_train)
dataset.append(task=task_semi_synth_01, data=data_train)
dataset.append(task=task_semi_synth_05, data=data_train)
dataset.append(task=task_pred, data=data_train)

# Load the HF dataset in prompt-completion form
dataset_hf = dataset.load_prompt_completion(
    # The completion is DatasetElem.out, so the prompt must not contain {out}.
    # A function of `DatasetElem` could be used instead, to have different prompts for different tasks
    prompt_template="Ctx: {ctx}\nSynthetic data: ",
    # Optional: tokenize prompt and completion (requires the tokenizer)
    tokenize=True,
    tokenizer=AutoTokenizer.from_pretrained("the-model-name"),
)

Fine-tune the LLM

The finetune function can be used to specialize an LLM on a specific dataset. It uses the HF trl.SFTTrainer to perform the fine-tuning on the provided Dataset. The function takes care of loading the HF model, loading the dataset in prompt-completion form, instantiating the trainer and calling the trl.SFTTrainer.train() method.

The finetune function takes the following parameters:

  • cfg_ft: A FtConfig with the fine-tuning configuration.
  • cfg_model: A trl.ModelConfig with the model configuration.
  • cfg_sft: The trl.SFTConfig with the training configuration of the trl.SFTTrainer.
  • unsloth: Whether to use Unsloth to speed up the training.

The user must provide the model to be fine-tuned with the model_name_or_path option of the trl.ModelConfig.

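Moreover, in the FtConfig it is necessary to specify:

  • dataset_path: The path to the JSONL file of the Dataset to be used for training.
  • prompt_template: The template used to build the prompts, as for Dataset.load_prompt_completion().

There are also a few optional parameters of FtConfig: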

  • enforce_max_len: Whether to raise an error if any tokenized training example exceeds the maximum number of tokens allowed by the model. If False, longer sequences are truncated during training.
  • valid_frac: The fraction of the training data to be used as the validation set.
  • early_stop: The patience to be used in the transformers.EarlyStoppingCallback.
  • dump_n_examples: If greater than zero, the number of training examples to save to disk for inspection.
  • best_ckpt: The name of a symlink in the training output folder linking to the best checkpoint, namely the one with the lowest value of the validation loss.
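
The valid_frac, early_stop and best_ckpt options all revolve around the validation loss. As a conceptual sketch only (it reproduces neither transformers.EarlyStoppingCallback nor the internals of finetune), the code below holds out a validation fraction and, given one validation loss per evaluation, reports where training would stop for a given patience and which evaluation has the lowest loss. The helper names, the seeded shuffle and the loss values are hypothetical; the sketch assumes one checkpoint per evaluation and ignores the improvement threshold that the real callback supports.

```python
import random
from typing import List, Optional, Sequence, Tuple, TypeVar

T = TypeVar("T")


def train_valid_split(
    examples: Sequence[T], valid_frac: float, seed: int = 0
) -> Tuple[List[T], List[T]]:
    # Shuffle a copy with a fixed seed and hold out a valid_frac share for validation
    items = list(examples)
    random.Random(seed).shuffle(items)
    n_valid = round(valid_frac * len(items))
    return items[n_valid:], items[:n_valid]


def early_stop_and_best(valid_losses: Sequence[float], patience: int) -> Tuple[int, int]:
    """Return (evaluation at which training stops, evaluation with the lowest loss)."""
    best_idx: Optional[int] = None
    bad_evals = 0  # consecutive evaluations without a new lowest loss
    for i, loss in enumerate(valid_losses):
        if best_idx is None or loss < valid_losses[best_idx]:
            best_idx, bad_evals = i, 0  # new best: reset the patience counter
        else:
            bad_evals += 1
            if bad_evals >= patience:
                return i, best_idx  # patience exhausted: stop early
    if best_idx is None:
        raise ValueError("valid_losses must not be empty")
    return len(valid_losses) - 1, best_idx  # ran to the end without stopping early


# Usage with made-up validation losses, one per evaluation
train, valid = train_valid_split(range(100), valid_frac=0.1)
losses = [1.20, 0.95, 0.90, 0.92, 0.91, 0.93, 0.89]
stop_at, best_at = early_stop_and_best(losses, patience=3)
```

With patience=3, training stops at index 5 after three evaluations without improvement, and the best checkpoint is the one at index 2. With patience=4, training reaches the late improvement at index 6, which becomes the best checkpoint.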

The training can be extensively customized by specifying the desired optional parameters in trl.ModelConfig and trl.SFTConfig. With the former, it is possible, for example, to set up LoRA training. With the latter, one can specify the optimizer, the number of epochs or steps, the learning rate, the batch size, and so on. Since these are objects of the trl library, we refer the user to the trl documentation for more information. An example of a possible training setup can be found in the examples section.

To fine-tune a model on the training dataset built in the previous example, it is enough to provide the path to the Dataset.

from trl import ModelConfig, SFTConfig

from aindo.rdml.relational import RelationalData
from aindo.rdml.synth.llm import (
    Dataset,
    FtConfig,
    RelSynth,
    RelSemiSynth,
    RelPredict,
    RelSynthPreproc,
    finetune,
)

# Load the data and split in train-test
data: RelationalData = ...
data_train, data_test = data.split(ratio=0.1)

# Define the preprocessor and the tasks
preproc = RelSynthPreproc.from_data(data=data)

task_synth = RelSynth(preproc=preproc)

task_semi_synth_01 = RelSemiSynth(preproc=preproc, p_field=0.1, p_child=0.1)
task_semi_synth_05 = RelSemiSynth(preproc=preproc, p_field=0.5, p_child=0.5)

task_pred = RelPredict(preproc=preproc, cols_tgt={"players": ["pos"]})

# Initialize the dataset, and append the training data,
# processed for the different tasks
dataset = Dataset(path="path/to/train.jsonl")
dataset.append(task=task_synth, data=data_train)
dataset.append(task=task_semi_synth_01, data=data_train)
dataset.append(task=task_semi_synth_05, data=data_train)
dataset.append(task=task_pred, data=data_train)

# Fine-tune
finetune(
    cfg_ft=FtConfig(
        dataset_path=dataset.path,
        # The completion is DatasetElem.out, so the prompt must not contain {out}
        prompt_template="Ctx: {ctx}\nSynthetic data: ",
    ),
    cfg_model=ModelConfig(
        model_name_or_path="your-favourite-model",
        dtype="bfloat16",
        use_peft=True,
        # further LoRA options go here, e.g. lora_r, lora_alpha, lora_dropout
    ),
    cfg_sft=SFTConfig(
        optim="adamw_torch_fused",
        num_train_epochs=10,
        learning_rate=2.0e-4,
        per_device_train_batch_size=16,
        # further training options go here
    ),
)