Saving and loading
Trained models can be saved with the TabularModel.save() and TextModel.save() methods by providing the desired path for the checkpoint. They can later be loaded with the TabularModel.load() and TextModel.load() class methods, specifying the path to that checkpoint.
Similarly, trainers can be saved with the TabularTrainer.save() and TextTrainer.save() methods, and loaded with the TabularTrainer.load() and TextTrainer.load() class methods.
When a trainer is saved, its associated model is saved automatically, so there is no need to save the model separately. After loading the trainer, the model can be accessed through the TabularTrainer.model or TextTrainer.model attribute.
We recommend saving the trainer if you plan to resume training later. If the model is only intended for inference or generating synthetic data, it is sufficient to save just the model.
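To show why the trainer checkpoint is the one to keep when training will resume, here is a minimal, library-independent sketch. ToyModel and ToyTrainer are hypothetical stand-ins, not classes from aindo.rdml.synth, and pickle is used only for illustration; the library's actual checkpoint format is not described here. The sketch assumes that a trainer checkpoint bundles the model weights with training state (here a learning rate and an epoch counter), while a model checkpoint holds only the weights.

```python
from __future__ import annotations

import pickle
import tempfile
from pathlib import Path


class ToyModel:
    """Hypothetical model: its whole state is a list of weights."""

    def __init__(self, weights: list[float]) -> None:
        self.weights = weights

    def save(self, path: Path) -> None:
        # A model checkpoint stores only what inference needs
        path.write_bytes(pickle.dumps({"weights": self.weights}))

    @classmethod
    def load(cls, path: Path) -> ToyModel:
        return cls(**pickle.loads(path.read_bytes()))


class ToyTrainer:
    """Hypothetical trainer: wraps a model and tracks training progress."""

    def __init__(self, model: ToyModel, lr: float = 0.1) -> None:
        self.model = model
        self.lr = lr
        self.epoch = 0

    def train_epoch(self) -> None:
        # Dummy update standing in for one epoch of optimization
        self.model.weights = [w - self.lr * w for w in self.model.weights]
        self.epoch += 1

    def save(self, path: Path) -> None:
        # The model weights are bundled into the trainer checkpoint,
        # so the model does not need to be saved separately
        state = {"weights": self.model.weights, "lr": self.lr, "epoch": self.epoch}
        path.write_bytes(pickle.dumps(state))

    @classmethod
    def load(cls, path: Path) -> ToyTrainer:
        state = pickle.loads(path.read_bytes())
        trainer = cls(ToyModel(state["weights"]), lr=state["lr"])
        trainer.epoch = state["epoch"]  # restore progress to resume training
        return trainer


with tempfile.TemporaryDirectory() as tmp:
    trainer = ToyTrainer(ToyModel([1.0, 2.0]))
    trainer.train_epoch()
    trainer.save(Path(tmp) / "trainer.pkl")

    # Resume training from the trainer checkpoint; the model comes with it
    resumed = ToyTrainer.load(Path(tmp) / "trainer.pkl")
    resumed.train_epoch()
    print(resumed.epoch)  # 2

    # For inference only, the model checkpoint is enough
    resumed.model.save(Path(tmp) / "model.pkl")
    model = ToyModel.load(Path(tmp) / "model.pkl")
```

Loading only the model checkpoint would lose the epoch counter and learning rate, which is why, in this sketch, the trainer checkpoint is the one to keep when training will continue.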
The same TabularTrainer.load() and TextTrainer.load() class methods can also be used to load a checkpoint automatically generated during training, provided the Validation.save_best option was enabled.
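How Validation.save_best decides when to write its checkpoint is not detailed here. The following is a minimal, library-independent sketch of the common "keep the best checkpoint" pattern, assuming that "best" means lowest validation loss. The function train_with_best_checkpoint and the JSON checkpoint contents are hypothetical and only record the epoch and loss instead of real model state.

```python
from __future__ import annotations

import json
import tempfile
from pathlib import Path


def train_with_best_checkpoint(val_losses: list[float], ckpt: Path) -> float:
    """Hypothetical loop: overwrite the checkpoint only when validation loss improves."""
    best = float("inf")
    for epoch, loss in enumerate(val_losses):
        # ... one epoch of training, then validation, would happen here ...
        if loss < best:
            best = loss
            # A real checkpoint would hold the trainer and model state
            ckpt.write_text(json.dumps({"epoch": epoch, "val_loss": loss}))
    return best


with tempfile.TemporaryDirectory() as tmp:
    ckpt = Path(tmp) / "best.json"
    best = train_with_best_checkpoint([0.9, 0.7, 0.8, 0.6, 0.65], ckpt)
    print(best, json.loads(ckpt.read_text()))  # 0.6 {'epoch': 3, 'val_loss': 0.6}
```

Because the file is only overwritten on improvement, the checkpoint left on disk at the end corresponds to the best epoch (epoch 3 here), not the last one.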
Below is an example of how to save a (possibly trained) model:
```python
from aindo.rdml.synth import TabularModel

model = TabularModel.build(preproc=..., size=...)
# Train the tabular model as shown above
...
model.save(path="path/to/ckpt")
```
and how to load it (possibly in a different session) to generate synthetic data:
```python
from aindo.rdml.synth import TabularModel

model = TabularModel.load(path="path/to/ckpt")
# Generate synthetic data with the loaded model
data_synth = model.generate(
    n_samples=1_000,
    batch_size=256,
)
```