Fine-tuning an LLM to generate synthetic data
It is possible to use all the pre- and post-processing tools of aindo.rdml.synth.llm
with any LLM model of choice, which in turn can be trained with the user's favourite framework.
However, we provide some extra utilities to perform the fine-tuning of Hugging Face (HF) models, optionally using Unsloth to accelerate the training process.
The aindo.rdml.synth.llm package contains two main training utilities:
- The Dataset class, to create the training dataset.
- The finetune function, to perform the fine-tuning of any HF model.
Build the training dataset
The first step to perform a successful fine-tuning is to build the training dataset.
To assist with that, the aindo.rdml.synth.llm module provides the
Dataset class.
It is initialized with an optional path to a JSONL file, which can refer to:
- The location of a previously built dataset.
- A new dataset, which will be saved at the specified location.
The user can then add training data with the Dataset.append() method,
which needs:
- data: The RelationalData to be used to build the training data.
- task: A BaseTask object, which is used to process the data and extract the relevant information with the BaseTask.get_dataset() method.
When data is appended to the dataset, each training example is converted into a JSON object by the
BaseTask object, serialized to a string,
and appended to the JSONL file.
The JSON representation of a training example is a dictionary corresponding to the
DatasetElem data structure.
Fields such as DatasetElem.ctx and
DatasetElem.out are themselves JSON representations of the context and the target, respectively.
The DatasetElem.schema and
DatasetElem.out_schema fields contain the JSON schema of the complete data
and the target respectively.
Finally, DatasetElem.description contains an optional description
that can be specified at task initialization, and DatasetElem.task
contains the name of the Task at hand.
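As a minimal sketch, assuming a dataset was previously saved at path/to/train.jsonl and that each JSONL line is a dictionary whose keys match the DatasetElem field names (an assumption about the serialization), a single training example can be inspected as follows:
import json

# Inspect one serialized training example from the JSONL file
with open("path/to/train.jsonl", encoding="utf-8") as f:
    elem = json.loads(f.readline())

print(elem["task"])         # name of the task that produced the example
print(elem["ctx"])          # JSON representation of the context
print(elem["out"])          # JSON representation of the target
print(elem["schema"])       # JSON schema of the complete data
print(elem["out_schema"])   # JSON schema of the target
print(elem["description"])  # optional description given at task initialization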
The data in the JSONL file can be used to create the training examples to fine-tune any LLM.
The Dataset class also offers two utility methods
to load the dataset to be used with HF:
- Dataset.load(): Load a HF datasets.Dataset. The training examples can be processed as desired with the usual HF functions, to bring them into a form suitable for the desired fine-tuning (see the sketch below).
- Dataset.load_prompt_completion(): Load a HF datasets.Dataset already in the prompt-completion format. For each training example, the prompt is built from the fields in DatasetElem using the value of the required parameter prompt_template, which can be either a formattable string (with keys the fields of DatasetElem) or a callable that returns the prompt from a DatasetElem object. The completion is taken to be the DatasetElem.out field; this behaviour can be modified by specifying a different value for the optional completion_key parameter. Optionally (parameter tokenize), prompt and completion are tokenized, adding the input_ids and completion_mask columns. To do so, the tokenizer must be provided via the tokenizer parameter.
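For instance, a possible use of Dataset.load() is sketched below. It assumes the dataset has already been populated as described above, that Dataset.load() can be called without arguments, and that the loaded examples expose the DatasetElem fields as columns; adapt it to your setup.
from aindo.rdml.synth.llm import Dataset

dataset = Dataset(path="path/to/train.jsonl")
dataset_hf = dataset.load()

# Build a single "text" column for a standard causal-LM fine-tuning,
# using the usual HF datasets API
dataset_hf = dataset_hf.map(
    lambda elem: {"text": f"Ctx: {elem['ctx']}.\nSynthetic data: {elem['out']}"},
    remove_columns=dataset_hf.column_names,
)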
For example, to train an LLM to perform different tasks on the same data, one can call the
Dataset.append() method several times with different
BaseTask objects:
from transformers import AutoTokenizer

from aindo.rdml.relational import RelationalData
from aindo.rdml.synth.llm import (
    Dataset,
    RelSynth,
    RelSemiSynth,
    RelPredict,
    RelSynthPreproc,
)

# Load the data and split in train-test
data: RelationalData = ...
data_train, data_test = data.split(ratio=0.1)

# Define the preprocessor and the tasks
preproc = RelSynthPreproc.from_data(data=data)
task_synth = RelSynth(preproc=preproc)
task_semi_synth_01 = RelSemiSynth(preproc=preproc, p_field=0.1, p_child=0.1)
task_semi_synth_05 = RelSemiSynth(preproc=preproc, p_field=0.5, p_child=0.5)
task_pred = RelPredict(preproc=preproc, cols_tgt={"players": ["pos"]})

# Initialize the dataset and append the training data,
# processed for the different tasks
dataset = Dataset(path="path/to/train.jsonl")
dataset.append(task=task_synth, data=data_train)
dataset.append(task=task_semi_synth_01, data=data_train)
dataset.append(task=task_semi_synth_05, data=data_train)
dataset.append(task=task_pred, data=data_train)

# Load the HF dataset in prompt-completion form.
# The prompt_template could also be a function of DatasetElem,
# to build different prompts for different tasks.
# The tokenizer is optional: if provided, the dataset is tokenized.
dataset_hf = dataset.load_prompt_completion(
    prompt_template="Ctx: {ctx}.\nSynthetic data: {out}",
    tokenizer=AutoTokenizer.from_pretrained("the-model-name"),
)
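As hinted in the comment above, prompt_template can also be a callable. The sketch below is illustrative only: the import location of DatasetElem and the exact strings stored in DatasetElem.task are assumptions, not part of the documented API.
from transformers import AutoTokenizer

from aindo.rdml.synth.llm import Dataset, DatasetElem  # DatasetElem import path assumed


def build_prompt(elem: DatasetElem) -> str:
    # Mirror the formattable-string template above, but vary the wording by task
    if elem.task == "RelPredict":  # hypothetical task name
        return f"Ctx: {elem.ctx}.\nPredicted values: {elem.out}"
    return f"Ctx: {elem.ctx}.\nSynthetic data: {elem.out}"


dataset = Dataset(path="path/to/train.jsonl")  # the dataset built above
dataset_hf = dataset.load_prompt_completion(
    prompt_template=build_prompt,
    tokenizer=AutoTokenizer.from_pretrained("the-model-name"),
)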
Fine-tune the LLM
The function finetune can be used to specialize an LLM on a specific dataset.
It uses the HF trl.SFTTrainer to perform the fine-tuning with the provided
Dataset.
The function takes care of loading the HF model, loading the dataset in prompt-completion form,
instantiating the trainer, and calling the trl.SFTTrainer.train() method.
The finetune function takes the following parameters:
- cfg_ft: An FtConfig with the fine-tuning configuration.
- cfg_model: A trl.ModelConfig with the model configuration.
- cfg_sft: The trl.SFTConfig with the training configuration of the trl.SFTTrainer.
- unsloth: Whether to use Unsloth to speed up the training.
The user must provide the model to be fine-tuned with the model_name_or_path option of the trl.ModelConfig.
Moreover, in the FtConfig it is necessary to specify:
- dataset_path: The path to a Dataset JSONL file.
- prompt_template: The template to be used in Dataset.load_prompt_completion() to build the prompts for each training example.
There are also a few optional parameters of FtConfig:
- enforce_max_len: Whether to raise an error if any tokenized training example exceeds the maximum number of tokens allowed by the model. If False, longer sequences are truncated during training.
- valid_frac: The fraction of training data to be used as validation set.
- early_stop: The patience to be used in the transformers.EarlyStoppingCallback.
- dump_n_examples: If greater than zero, the number of training examples to save on disk for inspection.
- best_ckpt: The name of a symlink in the training output folder linking to the best checkpoint, namely the one with the lowest value of the validation loss.
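As a sketch, an FtConfig using these optional parameters could look as follows; the parameter names come from the list above, while the specific values are illustrative only.
from aindo.rdml.synth.llm import FtConfig

cfg_ft = FtConfig(
    dataset_path="path/to/train.jsonl",
    prompt_template="Ctx: {ctx}.\nSynthetic data: {out}",
    enforce_max_len=True,  # fail on over-long examples instead of truncating them
    valid_frac=0.1,        # keep 10% of the examples as validation set
    early_stop=5,          # patience of the transformers.EarlyStoppingCallback
    dump_n_examples=10,    # save 10 training examples on disk for inspection
    best_ckpt="best",      # symlink to the checkpoint with the lowest validation loss
)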
The training can be greatly customized by specifying the desired optional parameters in trl.ModelConfig and
trl.SFTConfig.
With the former, it is possible, for example, to set up a LoRA training (see the sketch below).
With the latter, one can specify the optimizer, the number of epochs or steps, the learning rate,
the batch size, and so on.
Since they are objects of the trl library, we refer the user
to their original documentation for more information.
An example of a possible training setup can be found in the example section.
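For instance, a LoRA setup could be sketched via trl.ModelConfig as follows; the rank, scaling factor, and target modules are illustrative and depend on the chosen model.
from trl import ModelConfig

cfg_model = ModelConfig(
    model_name_or_path="your-favourite-model",
    use_peft=True,    # train LoRA adapters instead of the full model
    lora_r=16,        # LoRA rank
    lora_alpha=32,    # LoRA scaling factor
    lora_dropout=0.05,
    lora_target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # model-dependent
)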
To fine-tune a model on the training dataset built in the previous example, it is enough to provide the path
to the Dataset.
from trl import ModelConfig, SFTConfig

from aindo.rdml.relational import RelationalData
from aindo.rdml.synth.llm import (
    Dataset,
    FtConfig,
    RelSynth,
    RelSemiSynth,
    RelPredict,
    RelSynthPreproc,
    finetune,
)

# Load the data and split in train-test
data: RelationalData = ...
data_train, data_test = data.split(ratio=0.1)

# Define the preprocessor and the tasks
preproc = RelSynthPreproc.from_data(data=data)
task_synth = RelSynth(preproc=preproc)
task_semi_synth_01 = RelSemiSynth(preproc=preproc, p_field=0.1, p_child=0.1)
task_semi_synth_05 = RelSemiSynth(preproc=preproc, p_field=0.5, p_child=0.5)
task_pred = RelPredict(preproc=preproc, cols_tgt={"players": ["pos"]})

# Initialize the dataset and append the training data,
# processed for the different tasks
dataset = Dataset(path="path/to/train.jsonl")
dataset.append(task=task_synth, data=data_train)
dataset.append(task=task_semi_synth_01, data=data_train)
dataset.append(task=task_semi_synth_05, data=data_train)
dataset.append(task=task_pred, data=data_train)

# Fine-tune
finetune(
    cfg_ft=FtConfig(
        dataset_path=dataset.path,
        prompt_template="Ctx: {ctx}.\nSynthetic data: {out}",
    ),
    cfg_model=ModelConfig(
        model_name_or_path="your-favourite-model",
        dtype="bfloat16",
        use_peft=True,
        # ... e.g. the LoRA configuration
    ),
    cfg_sft=SFTConfig(
        optim="adamw_torch_fused",
        num_train_epochs=10,
        learning_rate=2.0e-4,
        per_device_train_batch_size=16,
        # ...
    ),
)
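After training, the best checkpoint can be loaded back through the best_ckpt symlink. The sketch below assumes, purely for illustration, that best_ckpt="best" was set in the FtConfig and output_dir="output" in the SFTConfig; if LoRA was used, the checkpoint holds adapter weights and loading it this way requires peft to be installed.
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical paths: "output" is the SFTConfig output_dir and "best" the
# best_ckpt symlink name chosen in FtConfig
model = AutoModelForCausalLM.from_pretrained("output/best")
# If the tokenizer was not saved alongside the checkpoint, load it from the base model instead
tokenizer = AutoTokenizer.from_pretrained("output/best")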