Saving and loading

Trained models can be saved with the TabularModel.save() and TextModel.save() methods by providing the desired checkpoint path. They can later be loaded with the TabularModel.load() and TextModel.load() class methods, specifying the path of the saved checkpoint.

Similarly, trainers can be saved with the TabularTrainer.save() and TextTrainer.save() methods and loaded with the TabularTrainer.load() and TextTrainer.load() class methods. Saving a trainer automatically saves its associated model, so there is no need to save the model separately. After loading a trainer, the model is available through the TabularTrainer.model or TextTrainer.model attribute.
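
To make the relationship between the two kinds of checkpoint concrete, here is a toy sketch in plain Python. ToyModel and ToyTrainer are hypothetical stand-ins, not classes from aindo.rdml, and JSON is used only for brevity; nothing here describes the library's actual checkpoint format. The point is that the trainer checkpoint embeds the model's state, so loading the trainer also restores the model, which is then reachable through the model attribute.

```python
import json
import tempfile
from pathlib import Path


class ToyModel:
    """Hypothetical stand-in for a model: it only holds its weights."""

    def __init__(self, weights):
        self.weights = list(weights)

    def state(self):
        return {"weights": self.weights}

    def save(self, path):
        # Model-only checkpoint: enough for inference / generation.
        Path(path).write_text(json.dumps(self.state()))

    @classmethod
    def load(cls, path):
        return cls(**json.loads(Path(path).read_text()))


class ToyTrainer:
    """Hypothetical stand-in for a trainer: a model plus training progress."""

    def __init__(self, model, epoch=0):
        self.model = model
        self.epoch = epoch  # progress needed to resume training later

    def save(self, path):
        # The model's state is embedded in the trainer checkpoint,
        # so the model does not need to be saved separately.
        ckpt = {"epoch": self.epoch, "model": self.model.state()}
        Path(path).write_text(json.dumps(ckpt))

    @classmethod
    def load(cls, path):
        ckpt = json.loads(Path(path).read_text())
        return cls(model=ToyModel(**ckpt["model"]), epoch=ckpt["epoch"])


with tempfile.TemporaryDirectory() as tmp:
    trainer = ToyTrainer(ToyModel([0.1, 0.2]), epoch=5)
    trainer.save(Path(tmp) / "trainer.json")

    restored = ToyTrainer.load(Path(tmp) / "trainer.json")
    print(restored.epoch, restored.model.weights)  # 5 [0.1, 0.2]
```

Because the trainer checkpoint also carries training state (in this toy, just epoch), it is the one to keep for resuming training, while the model checkpoint alone is enough for generation.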

We recommend saving the trainer if you plan to resume training later. If the model is only intended for inference or generating synthetic data, it is sufficient to save just the model.

The same TabularTrainer.load() and TextTrainer.load() class methods can also be used to load a checkpoint automatically generated during training, provided the Validation.save_best option was enabled.
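
The idea behind an option like Validation.save_best can be illustrated with a generic "keep the best checkpoint" loop. This is a hypothetical sketch, not the library's implementation: it assumes that "best" means the lowest validation loss, and train_with_save_best, valid_losses, and save_fn are illustrative names only.

```python
from typing import Callable, Sequence, Tuple


def train_with_save_best(
    valid_losses: Sequence[float],
    save_fn: Callable[[int], None],
) -> Tuple[int, float]:
    """Simulate a training loop that keeps only the best checkpoint.

    valid_losses[i] stands in for the validation loss measured after epoch i;
    save_fn(epoch) stands in for writing the trainer checkpoint to disk.
    Returns the epoch and loss of the last (i.e. best) saved checkpoint.
    """
    best_epoch, best_loss = -1, float("inf")
    for epoch, loss in enumerate(valid_losses):
        # ... one epoch of training would happen here ...
        if loss < best_loss:  # improvement: overwrite the checkpoint
            best_epoch, best_loss = epoch, loss
            save_fn(epoch)
    return best_epoch, best_loss


saved_epochs = []
best = train_with_save_best([0.9, 0.7, 0.75, 0.6, 0.65], saved_epochs.append)
print(best)          # (3, 0.6)
print(saved_epochs)  # [0, 1, 3]
```

Only epochs that improve on the best loss so far trigger a save, so the checkpoint left on disk corresponds to epoch 3 rather than to the final epoch.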

Below is an example of how to save a (possibly trained) model:

from aindo.rdml.synth import TabularModel

model = TabularModel.build(preproc=..., size=...)

# Train the tabular model
# as shown above
...

model.save(path="path/to/ckpt")

and how to load it (possibly in a different session) to generate synthetic data:

from aindo.rdml.synth import TabularModel

model = TabularModel.load(path="path/to/ckpt")

data_synth = model.generate(
    n_samples=1_000,
    batch_size=256,
)