Source code for vathos.trainer.gpu_trainer

from .base_trainer import BaseTrainer, optimizer_to, scheduler_to
import vathos.data_loader as vdata_loader
from vathos.utils import setup_logger
import vathos.model.loss as vloss

import gc
import torch
from pathlib import Path
# from tqdm.auto import tqdm
from tqdm.notebook import tqdm, trange
import torch.optim as optim
import torch.utils as utils

logger = setup_logger(__name__)


[docs]class GPUTrainer(BaseTrainer): r''' GPUTrainer: Trains the vathos model on GPU see :class:`~vathos.trainer.BaseTrainer` for args Examples: >>> gpu_trainer = GPUTrainer(model, loss_fns, optimizer, cfg, train_subset, test_subset, state_dict=state_dict) >>> gpu_trainer.start_train() ''' def __init__(self, *args, **kwargs): super(GPUTrainer, self).__init__(*args, **kwargs) cfg = self.config # set the device to GPU:0 // we don't support multiple GPUs for now self.device = torch.device("cuda:0") self.writer.add_graph(self.model, (torch.randn(1, 6, 96, 96))) self.writer.flush() self.model = self.model.to(self.device) optimizer_to(self.optimizer, self.device) scheduler_to(self.lr_scheduler, self.device)
[docs] def train_epoch(self, epoch): r"""trains the model for one epoch Args: epoch: the epoch number Returns: Dict: miou, mrmse, seg_loss, depth_loss """ logger.info(f'=> Training Epoch {epoch}') # clear the cache before training this epoch gc.collect() torch.cuda.empty_cache() # pbar = tqdm(self.train_loader, dynamic_ncols=True) pbar = self.train_loader # set the model to training mode self.model.train() miou = 0 mrmse = 0 seg_loss = 0 depth_loss = 0 for batch_idx, data in enumerate(pbar): # move the data of the specific dataset to our `device` data = getattr(vdata_loader, self.config['dataset']['name']).apply_on_batch( data, lambda x: x.to(self.device) ) # zero out the gradients, we don't want to accumulate them self.optimizer.zero_grad() x = torch.cat([data['bg'], data['fg_bg']], dim=1) d_out, s_out = self.model(x) # calculate the losses l1 = self.seg_loss(s_out, data['fg_bg_mask']) l2 = self.depth_loss(d_out, data['depth_fg_bg']) loss = self.comb_loss(l1, l2) with torch.no_grad(): miou += vloss.iou(s_out, data['fg_bg_mask']) mrmse += vloss.rmse(d_out, data['depth_fg_bg']) # update the gradients loss.backward() # step the optmizer self.optimizer.step() # step the scheduler if isinstance(self.lr_scheduler, optim.lr_scheduler.OneCycleLR): self.lr_scheduler.step() seg_loss += l1.item() depth_loss += l2.item() # pbar.set_description( # desc=f'loss={loss.item():.4f} seg_loss={l1.item():.4f} depth_loss={l2.item():.4f} batch_id={batch_idx}') self.writer.add_scalar( 'BatchLoss/Train/seg_loss', l1.item(), epoch*len(pbar) + batch_idx) self.writer.add_scalar( 'BatchLoss/Train/depth_loss', l2.item(), epoch*len(pbar) + batch_idx) seg_loss /= len(pbar) depth_loss /= len(pbar) miou /= len(pbar) mrmse /= len(pbar) logger.info( f'seg_loss: {seg_loss}, depth_loss: {depth_loss}, mIOU: {miou}, mRMSE: {mrmse}') self.writer.flush() return {'miou': miou, 'mrmse': mrmse, 'seg_loss': seg_loss, 'depth_loss': depth_loss}
[docs] def test_epoch(self, epoch): r"""tests the model for one epoch Args: epoch: the epoch number Returns: Dict: miou, mrmse, seg_loss, depth_loss """ logger.info(f'=> Testing Epoch {epoch}') # clear the cache before testing this epoch gc.collect() torch.cuda.empty_cache() # set the model in eval mode self.model.eval() # metrics and losses miou = 0 mrmse = 0 seg_loss = 0 depth_loss = 0 # tqdm writes a lot of data into a single cell in colab that caushes high local browser # ram uses, so chuck tqdm, find some alternative ? # pbar = tqdm(self.test_loader, dynamic_ncols=True) pbar = self.test_loader for batch_idx, data in enumerate(pbar): # move the data of the specific dataset to our `device` data = getattr(vdata_loader, self.config['dataset']['name']).apply_on_batch( data, lambda x: x.to(self.device) ) x = torch.cat([data['bg'], data['fg_bg']], dim=1) with torch.no_grad(): d_out, s_out = self.model(x) miou += vloss.iou(s_out, data['fg_bg_mask']) mrmse += vloss.rmse(d_out, data['depth_fg_bg']) l1 = self.seg_loss(s_out, data['fg_bg_mask']) l2 = self.depth_loss(d_out, data['depth_fg_bg']) seg_loss += l1.item() depth_loss += l2.item() # pbar.set_description(desc=f'testing batch_id={batch_idx}') miou /= len(pbar) mrmse /= len(pbar) seg_loss /= len(pbar) depth_loss /= len(pbar) logger.info(f'mIOU: {miou} mRMSE: {mrmse}') results = {**data, 'pred_depth': d_out, 'pred_mask': s_out} return {'miou': miou, 'mrmse': mrmse, 'seg_loss': seg_loss, 'depth_loss': depth_loss, 'results': results}
[docs] def start_train(self): r"""trains the model for self.epochs times the model and training state is saved at every epoch summary is flushed to disk every epoch """ logger.info('=> Training Started') logger.info(f'Training the model for {self.epochs} epochs') for epoch in range(self.start_epoch, self.epochs): if self.lr_scheduler: lr_value = [group['lr'] for group in self.optimizer.param_groups][0] logger.info(f'=> LR was set to {lr_value}') self.writer.add_scalar('LR/lr_value', lr_value, epoch) # train this epoch train_metric = self.train_epoch(epoch) # train metrics self.writer.add_scalar( 'EpochLoss/Train/seg_loss', train_metric['seg_loss'], epoch) self.writer.add_scalar( 'EpochLoss/Train/depth_loss', train_metric['depth_loss'], epoch) self.writer.add_scalar( 'EpochAccuracy/Train/mIOU', train_metric['miou'], epoch) self.writer.add_scalar( 'EpochAccuracy/Train/mRMSE', train_metric['mrmse'], epoch) # test this epoch test_metric = self.test_epoch(epoch) # test metrics self.writer.add_scalar( 'EpochLoss/Test/seg_loss', test_metric['seg_loss'], epoch) self.writer.add_scalar( 'EpochLoss/Test/depth_loss', test_metric['depth_loss'], epoch) self.writer.add_scalar( 'EpochAccuracy/Test/mIOU', test_metric['miou'], epoch) self.writer.add_scalar( 'EpochAccuracy/Test/mRMSE', test_metric['mrmse'], epoch) test_images = getattr(vdata_loader, self.config['dataset']['name']).plot_results( test_metric['results']) self.writer.add_figure( 'ModelImages/TestImages', test_images, epoch) # make sure to flush the data to the `SummaryWriter` file self.writer.flush() # check if we improved accuracy and save the model if (test_metric['mrmse'] <= self.best_accuracy['mrmse']) or (test_metric['miou'] >= self.best_accuracy['miou']): self.best_accuracy['mrmse'] = test_metric['mrmse'] self.best_accuracy['miou'] = test_metric['miou'] logger.info('=> Accuracy improved, saving best checkpoint ...') chkpt_path = Path(self.config['chkpt_dir']) chkpt_path.mkdir(parents=True, exist_ok=True) model_checkpoint = chkpt_path / 'model_checkpoint_best.pt' train_checkpoint = chkpt_path / 'train_checkpoint_best.pt' torch.save(self.model.state_dict(), model_checkpoint) torch.save({ 'optimizer': self.optimizer.state_dict(), 'scheduler': self.lr_scheduler.state_dict(), 'best_accuracy': self.best_accuracy, 'save_epoch': epoch, 'total_epochs': self.epochs }, train_checkpoint) logger.info('=> Saving checkpoint ...') chkpt_path = Path(self.config['chkpt_dir']) chkpt_path.mkdir(parents=True, exist_ok=True) model_checkpoint = chkpt_path / 'model_checkpoint.pt' train_checkpoint = chkpt_path / 'train_checkpoint.pt' torch.save(self.model.state_dict(), model_checkpoint) torch.save({ 'optimizer': self.optimizer.state_dict(), 'scheduler': self.lr_scheduler.state_dict(), 'best_accuracy': self.best_accuracy, 'save_epoch': epoch, 'total_epochs': self.epochs }, train_checkpoint)