Source code for vathos.trainer.base_trainer

import abc
import os
from pathlib import Path
from datetime import datetime

from vathos.utils import get_instance_v2, setup_logger

import torch
from torch.utils.tensorboard import SummaryWriter
import torch.utils as utils
import torch.optim as optim

logger = setup_logger(__name__)


def optimizer_to(optim, device):
    r"""Moves the optimizer state to the given device.

    Args:
        optim: the optimizer
        device: the device to move the state to
    """
    for param in optim.state.values():
        # Not sure there are any global tensors in the state dict
        if isinstance(param, torch.Tensor):
            param.data = param.data.to(device)
            if param._grad is not None:
                param._grad.data = param._grad.data.to(device)
        elif isinstance(param, dict):
            for subparam in param.values():
                if isinstance(subparam, torch.Tensor):
                    subparam.data = subparam.data.to(device)
                    if subparam._grad is not None:
                        subparam._grad.data = subparam._grad.data.to(device)
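# A minimal sketch of how optimizer_to might be used when resuming training on
# a different device. The checkpoint keys ('model', 'optimizer') and the SGD
# hyperparameters are illustrative assumptions, not the project's checkpoint
# format.
def _example_resume_optimizer(model, checkpoint_path, device):
    state = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(state['model'])
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    optimizer.load_state_dict(state['optimizer'])
    # moving the model after the optimizer was built leaves the optimizer's
    # per-parameter buffers (e.g. momentum) on the CPU, so move them as well
    model.to(device)
    optimizer_to(optimizer, device)
    return optimizer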
def scheduler_to(sched, device):
    r"""Moves the scheduler state to the given device.

    Args:
        sched: the scheduler
        device: the device to move the state to
    """
    for param in sched.__dict__.values():
        if isinstance(param, torch.Tensor):
            param.data = param.data.to(device)
            if param._grad is not None:
                param._grad.data = param._grad.data.to(device)
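# A companion sketch for scheduler_to: when a scheduler is restored on the CPU
# and the optimizer is later moved, any tensors held in the scheduler's
# attributes can be moved alongside it. StepLR and the 'scheduler' key are
# illustrative assumptions.
def _example_resume_scheduler(optimizer, state, device):
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10)
    scheduler.load_state_dict(state['scheduler'])
    optimizer_to(optimizer, device)
    scheduler_to(scheduler, device)
    return scheduler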
class BaseTrainer(metaclass=abc.ABCMeta):
    r"""BaseTrainer: an abstract base class for all trainers (GPU, CPU, TPU).

    Args:
        model: the model to be trained (can be on cpu/gpu)
        loss_fns (Tuple): (seg_loss, depth_loss)
        optimizer: the optimizer (can be on cpu/gpu)
        config: config in dict format
        train_subset (torch.utils.data.Subset): train dataset wrapped in a
            subset containing the indices
        test_subset (torch.utils.data.Subset): test dataset wrapped in a
            subset containing the indices
        state_dict (Optional): the saved state in dict format
    """

    def __init__(self, model, loss_fns, optimizer, config, train_subset,
                 test_subset, state_dict=None):
        super(BaseTrainer, self).__init__()
        cfg = config

        self.model = model
        self.optimizer = optimizer
        self.seg_loss, self.depth_loss = loss_fns
        self.train_subset = train_subset
        self.test_subset = test_subset
        self.config = config
        self.start_epoch = 0

        current_time = datetime.now().strftime('%b%d_%H-%M-%S')
        self.writer = SummaryWriter(
            f"{config['log_dir']}/{config['name']}.{current_time}")

        # combined loss: segmentation loss plus twice the depth loss
        self.comb_loss = lambda l1, l2: l1 + 2*l2

        self.train_loader = get_instance_v2(
            utils.data, 'DataLoader', self.train_subset,
            **cfg['dataset']['loader_args'])
        self.test_loader = get_instance_v2(
            utils.data, 'DataLoader', self.test_subset,
            **cfg['dataset']['loader_args'])

        if state_dict is not None:
            self.start_epoch = state_dict['save_epoch'] + 1

        self.epochs = self.config['training']['epochs'] - self.start_epoch

        # best accuracy seen so far
        self.best_accuracy = {'mrmse': 1e5, 'miou': 0}

        # expected config layout:
        # lr_scheduler:
        #   type: OneCycleLR
        #   args:
        #     max_lr: 0.6
        if cfg['lr_scheduler']['type'] == 'OneCycleLR':
            self.lr_scheduler = get_instance_v2(
                optim.lr_scheduler, cfg['lr_scheduler']['type'],
                optimizer=self.optimizer,
                steps_per_epoch=len(self.train_loader),
                epochs=self.epochs,
                **cfg['lr_scheduler']['args'])
        else:
            self.lr_scheduler = get_instance_v2(
                optim.lr_scheduler, cfg['lr_scheduler']['type'],
                optimizer=self.optimizer,
                **cfg['lr_scheduler']['args'])

        if state_dict is not None:
            if 'scheduler' in state_dict:
                logger.info('Found lr_scheduler state')
                self.lr_scheduler.load_state_dict(state_dict['scheduler'])
            self.best_accuracy = state_dict['best_accuracy']

    @abc.abstractmethod
    def train_epoch(self, epoch):
        pass

    @abc.abstractmethod
    def test_epoch(self, epoch):
        pass
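# A minimal sketch of a concrete trainer built on BaseTrainer. The single-device
# loop, the (image, seg_target, depth_target) batch layout, the model returning
# (seg_pred, depth_pred), and the per-batch scheduler step are illustrative
# assumptions and do not reflect the project's actual GPU/TPU trainers.
class _ExampleTrainer(BaseTrainer):

    def __init__(self, *args, device='cpu', **kwargs):
        super(_ExampleTrainer, self).__init__(*args, **kwargs)
        self.device = device

    def train_epoch(self, epoch):
        self.model.train()
        for image, seg_target, depth_target in self.train_loader:
            image = image.to(self.device)
            seg_target = seg_target.to(self.device)
            depth_target = depth_target.to(self.device)
            self.optimizer.zero_grad()
            seg_pred, depth_pred = self.model(image)
            loss = self.comb_loss(self.seg_loss(seg_pred, seg_target),
                                  self.depth_loss(depth_pred, depth_target))
            loss.backward()
            self.optimizer.step()
            self.lr_scheduler.step()
        # log the loss of the last batch for this epoch
        self.writer.add_scalar('train/loss', loss.item(), epoch)

    def test_epoch(self, epoch):
        self.model.eval()
        total_loss = 0.0
        with torch.no_grad():
            for image, seg_target, depth_target in self.test_loader:
                seg_pred, depth_pred = self.model(image.to(self.device))
                total_loss += self.comb_loss(
                    self.seg_loss(seg_pred, seg_target.to(self.device)),
                    self.depth_loss(depth_pred, depth_target.to(self.device))).item()
        self.writer.add_scalar('test/loss',
                               total_loss / len(self.test_loader), epoch)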