updown.utils.checkpointing

class updown.utils.checkpointing.CheckpointManager(models: Union[torch.nn.modules.module.Module, Dict[str, torch.nn.modules.module.Module]], optimizer: Type[torch.optim.optimizer.Optimizer], serialization_dir: str, mode: str = 'max', filename_prefix: str = 'checkpoint')

Bases: object
A CheckpointManager periodically serializes models and the optimizer as .pth files during training, and keeps track of the best-performing checkpoint based on an observed metric.

It saves state dicts of the models and the optimizer as .pth files in a specified directory. This class closely follows the API of PyTorch optimizers and learning rate schedulers.

Parameters
- models: Dict[str, torch.nn.Module]
  Models which need to be serialized as a checkpoint (a sketch for passing more than one model follows this list).
- optimizer: torch.optim.Optimizer
  Optimizer which needs to be serialized as a checkpoint.
- serialization_dir: str
  Path to an empty or non-existent directory in which to save checkpoints.
- mode: str, optional (default="max")
  One of min, max. In min mode, the best checkpoint is recorded when the metric hits a lower value; in max mode, when it hits a higher value.
- filename_prefix: str, optional (default="checkpoint")
  Prefix of the to-be-saved checkpoint files.
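The models argument also accepts a dictionary of named modules, so several models can be tracked in a single checkpoint. A minimal sketch; the "encoder"/"decoder" names and layer sizes below are hypothetical, not part of the library:

>>> # Hypothetical modules; any torch.nn.Module instances work here.
>>> encoder = torch.nn.Linear(10, 5)
>>> decoder = torch.nn.Linear(5, 2)
>>> optimizer = torch.optim.SGD(
...     list(encoder.parameters()) + list(decoder.parameters()), lr=0.01
... )
>>> ckpt_manager = CheckpointManager(
...     {"encoder": encoder, "decoder": decoder}, optimizer, "/tmp/ckpt"
... )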
Notes

For DataParallel objects, .module.state_dict() is called instead of .state_dict().
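This matters because DataParallel registers the wrapped model as a submodule named module, so its parameter keys carry a "module." prefix; saving the inner module's state dict keeps the checkpoint loadable into an unwrapped model. A small illustration using plain PyTorch, independent of this class:

>>> model = torch.nn.Linear(10, 2)
>>> parallel_model = torch.nn.DataParallel(model)
>>> list(parallel_model.state_dict().keys())
['module.weight', 'module.bias']
>>> list(parallel_model.module.state_dict().keys())
['weight', 'bias']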
Examples

>>> model = torch.nn.Linear(10, 2)
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
>>> ckpt_manager = CheckpointManager({"model": model}, optimizer, "/tmp/ckpt", mode="min")
>>> num_epochs = 20
>>> for epoch in range(num_epochs):
...     train(model)
...     val_loss = validate(model)
...     ckpt_manager.step(val_loss, epoch)
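Saved checkpoints are ordinary .pth files holding state dicts, so they can be restored with plain torch.load. A minimal sketch; the file name and the dictionary keys below are assumptions rather than guarantees of this class, so inspect serialization_dir and the loaded dictionary to confirm them:

>>> # Hypothetical file name; actual names depend on filename_prefix and the epoch.
>>> checkpoint = torch.load("/tmp/ckpt/checkpoint_19.pth")
>>> checkpoint.keys()  # inspect which state dicts were saved, and under which names
>>> # Assuming the model's state dict was stored under the name given to the constructor:
>>> model.load_state_dict(checkpoint["model"])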