vathos.trainer
Trainers
class BaseTrainer(model, loss_fns, optimizer, config, train_subset, test_subset, state_dict=None)
BaseTrainer: an abstract base class for all trainers (GPU, CPU, TPU).
Parameters:
- model – the model to be trained (may be on CPU or GPU)
- loss_fns (Tuple) – the loss functions, as (seg_loss, depth_loss)
- optimizer – the optimizer (may be on CPU or GPU)
- config – the configuration, as a dict
- train_subset (torch.utils.data.Subset) – the training dataset, wrapped in a Subset that holds the selected indices
- test_subset (torch.utils.data.Subset) – the test dataset, wrapped in a Subset that holds the selected indices
- state_dict (Optional) – the saved state, as a dictionary
optimizer_to(optim, device)
Moves the optimizer to device.
Parameters:
- optim – the optimizer
- device – the device to move the optimizer to
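The source of optimizer_to is not shown here, so the following is only a minimal sketch of one common way to move an optimizer's per-parameter state to a device. It assumes the optimizer exposes a `state` mapping of parameter to state dict, as `torch.optim.Optimizer` does, and that tensor-like values provide `.to(device)`, as `torch.Tensor` does. The names `optimizer_to_sketch`, `FakeTensor` and `FakeOptimizer` are hypothetical; the two stand-in classes exist only so the example runs without PyTorch installed.

```python
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor (only .device and .to)."""

    def __init__(self, device="cpu"):
        self.device = device

    def to(self, device):
        # Like torch.Tensor.to, return a new object on the target device.
        return FakeTensor(device)


class FakeOptimizer:
    """Hypothetical stand-in mirroring the `state` layout of torch.optim.Optimizer."""

    def __init__(self):
        # One parameter with Adam-like state: a plain counter and two buffers.
        self.state = {
            "param0": {"step": 3, "exp_avg": FakeTensor(), "exp_avg_sq": FakeTensor()}
        }


def optimizer_to_sketch(optim, device):
    """Move every tensor-like value in optim.state to `device` (assumed behavior)."""
    for param_state in optim.state.values():
        for key, value in param_state.items():
            # Plain Python values (e.g. an int step counter) are left untouched.
            if hasattr(value, "to"):
                param_state[key] = value.to(device)
    return optim


moved = optimizer_to_sketch(FakeOptimizer(), "cuda:0")
print(moved.state["param0"]["exp_avg"].device)  # prints cuda:0
```

Moving the state matters mostly when a checkpoint is restored on a different device than the one it was saved from: the model parameters and the optimizer buffers then need to end up on the same device before the next step.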
scheduler_to(sched, device)
Moves the scheduler to device.
Parameters:
- sched – the scheduler
- device – the device to move the scheduler to
class GPUTrainer(*args, **kwargs)
GPUTrainer: trains the vathos model on GPU.
See BaseTrainer for the arguments.
Examples
>>> gpu_trainer = GPUTrainer(model, loss_fns, optimizer, cfg, train_subset, test_subset, state_dict=state_dict)
>>> gpu_trainer.start_train()
start_train()
Trains the model for self.epochs epochs.
The model and training state are saved at every epoch.
The summary is flushed to disk every epoch.
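The per-epoch behavior described for start_train can be pictured with a small, library-free sketch. This is not the vathos implementation: `SketchTrainer`, `ToyTrainer`, `train_epoch`, `save_checkpoint` and `flush_summary` are hypothetical names, and in-memory lists and counters stand in for the checkpoint files and summary writer a real trainer would use.

```python
from abc import ABC, abstractmethod


class SketchTrainer(ABC):
    """Hypothetical abstract trainer; subclasses supply the per-epoch work."""

    def __init__(self, epochs):
        self.epochs = epochs
        self.checkpoints = []  # stands in for checkpoint files on disk
        self.flushed = 0       # counts how often the summary was flushed

    @abstractmethod
    def train_epoch(self, epoch):
        """Run one epoch and return a dict of metrics."""

    def save_checkpoint(self, epoch, metrics):
        # A real trainer would serialize model and training state here.
        self.checkpoints.append({"epoch": epoch, "metrics": metrics})

    def flush_summary(self):
        # A real trainer would flush its summary writer to disk here.
        self.flushed += 1

    def start_train(self):
        # Train for self.epochs epochs; save and flush after each one.
        for epoch in range(self.epochs):
            metrics = self.train_epoch(epoch)
            self.save_checkpoint(epoch, metrics)
            self.flush_summary()


class ToyTrainer(SketchTrainer):
    """Concrete subclass with a made-up, steadily shrinking loss."""

    def train_epoch(self, epoch):
        return {"loss": 1.0 / (epoch + 1)}


trainer = ToyTrainer(epochs=3)
trainer.start_train()
print(len(trainer.checkpoints), trainer.flushed)  # prints 3 3
```

Saving and flushing inside the loop, rather than once at the end, means an interrupted run loses at most the epoch in progress.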