Source code for vathos.runner.runner

from vathos.utils import setup_logger, get_instance_v2
import vathos.model as vmodel
import vathos.model.loss as vloss
import vathos.trainer as vtrainer
import vathos.data_loader as vdata_loader
from vathos.data_loader.utils import split_dataset
import vathos.utils as vutils

import torch
import torchvision.transforms as T
import pprint
from pathlib import Path

logger = setup_logger(__name__)
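
get_instance_v2 is imported from vathos.utils and its implementation is not part of this module. Judging from the call sites below, it looks up a constructor by name on the given module and forwards the remaining arguments to it. A minimal sketch of that assumed behaviour (the name _get_instance_sketch and its signature are illustrative, not the library's actual code):

    def _get_instance_sketch(module, ctor_name, *args, **kwargs):
        # Resolve the class/function named ``ctor_name`` inside ``module``
        # and construct it with the supplied arguments.
        ctor = getattr(module, ctor_name)
        return ctor(*args, **kwargs)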


class Runner:
    r"""Runner that encapsulates a Trainer.

    Args:
        config: a dict that contains the current experiment info
    """

    def __init__(self, config):
        self.config = config

        # print the super awesome logo
        print(vutils.logo)

        logger.info('Now simply setup_train and then start_train your model')

    def setup_train(self):
        r"""Sets up training for the provided config."""
        cfg = self.config

        logger.info('Config')
        # print the config
        for line in pprint.pformat(cfg).split('\n'):
            logger.info(line)

        # dataset:
        #   name: DenseDepth
        #   root: vathos_data
        #   zip_dir: "/content/gdrive/My Drive/DepthProject/depth_dataset_zipped/"
        #   loader_args:
        #     batch_size: 128
        #     num_workers: 4
        #     shuffle: True
        #     pin_memory: True
        dataset = get_instance_v2(
            vdata_loader,
            cfg['dataset']['name'],
            root=cfg['dataset']['root'],
            source_zipfolder=cfg['dataset']['zip_dir'],
            transform=T.Compose([T.ToTensor()]),
            target_transform=T.Compose([T.ToTensor()])
        )

        train_subset, test_subset = split_dataset(
            dataset, div_factor=cfg['dataset']['div_factor'])

        # check if the train_subset and test_subset indices are present on disk
        Path(cfg['chkpt_dir']).mkdir(parents=True, exist_ok=True)
        subset_file = Path(cfg['chkpt_dir']) / 'subset.pt'
        if subset_file.exists():
            # load the saved subset state
            logger.info('=> Found subset.pt, loading indices')
            subset_state = torch.load(subset_file)
            train_subset.indices = subset_state['train_indices']
            test_subset.indices = subset_state['test_indices']
        else:
            # save the subset indices for reproducible splits
            torch.save({'train_indices': train_subset.indices,
                        'test_indices': test_subset.indices}, subset_file)
            logger.info('=> Saved subset.pt (train, test indices)')

        # create the model
        model = get_instance_v2(vmodel, cfg['model'])

        # optimizer:
        #   type: AdamW
        #   args:
        #     lr: 0.01
        optimizer = get_instance_v2(
            torch.optim,
            ctor_name=cfg['optimizer']['type'],
            params=model.parameters(),
            lr=cfg['optimizer']['args']['lr'])

        # seg_loss: BCEDiceLoss
        # depth_loss: RMSELoss
        seg_loss = get_instance_v2(vloss, ctor_name=cfg['seg_loss'])
        depth_loss = get_instance_v2(vloss, ctor_name=cfg['depth_loss'])
        loss_fns = (seg_loss, depth_loss)

        # check if model init weights are specified
        # model_init: "models/model.pt"
        model_init = Path(cfg['model_init'])
        if model_init.exists():
            logger.info('=> Found model init weights')
            model_state_dict = torch.load(model_init)
            model.load_state_dict(model_state_dict)

        # load the last checkpoint
        # chkpt_dir: checkpoint
        model_checkpoint = Path(cfg['chkpt_dir']) / 'model_checkpoint.pt'
        train_checkpoint = Path(cfg['chkpt_dir']) / 'train_checkpoint.pt'

        if model_checkpoint.exists():
            logger.info('=> Found model checkpoint')
            model_state_dict = torch.load(model_checkpoint)
            model.load_state_dict(model_state_dict)

        state_dict = None
        if train_checkpoint.exists():
            logger.info('=> Found train checkpoint')
            checkpoint_state = torch.load(train_checkpoint)
            optimizer.load_state_dict(checkpoint_state['optimizer'])
            save_epoch = checkpoint_state['save_epoch']
            total_epochs = checkpoint_state['total_epochs']
            logger.info(f'Start epoch should be {save_epoch} + 1')
            state_dict = checkpoint_state
        else:
            logger.info('=> No saved checkpoints found')

        if cfg['device'] == 'GPU':
            self.trainer = get_instance_v2(
                vtrainer, 'GPUTrainer', model, loss_fns, optimizer, cfg,
                train_subset, test_subset, state_dict=state_dict)
        else:
            logger.error(f"Unsupported device: {cfg['device']}")

    def start_train(self):
        r"""A thin wrapper that calls ``self.trainer.start_train()``."""
        assert self.trainer is not None, 'call setup_train() before start_train()'
        self.trainer.start_train()
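
A minimal end-to-end usage sketch of Runner follows. The config keys mirror the ones read in setup_train (and the YAML fragments quoted in its comments); the concrete values, in particular the model class name and div_factor, are placeholder assumptions rather than a recommended experiment setup.

    from vathos.runner.runner import Runner

    config = {
        'dataset': {
            'name': 'DenseDepth',
            'root': 'vathos_data',
            'zip_dir': '/content/gdrive/My Drive/DepthProject/depth_dataset_zipped/',
            'div_factor': 10,                # assumed train/test split factor
            'loader_args': {'batch_size': 128, 'num_workers': 4,
                            'shuffle': True, 'pin_memory': True},
        },
        'model': 'DepthSegNet',              # hypothetical model name in vathos.model
        'optimizer': {'type': 'AdamW', 'args': {'lr': 0.01}},
        'seg_loss': 'BCEDiceLoss',
        'depth_loss': 'RMSELoss',
        'model_init': 'models/model.pt',
        'chkpt_dir': 'checkpoint',
        'device': 'GPU',
    }

    runner = Runner(config)
    runner.setup_train()    # builds the dataset, model, optimizer, losses and trainer
    runner.start_train()    # delegates to the underlying GPUTrainer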