# Source code for vathos.utils.config


import yaml

from typing import Any, List, Tuple, Dict
from types import ModuleType

import torch
import torch.nn as nn

from .logger import setup_logger

logger = setup_logger(__name__)


def load_config(filename: str) -> dict:
    r"""Load a YAML configuration file and return its contents.

    Args:
        filename (str): location of the YAML file

    Returns:
        (Dict) of the config; an empty file yields ``{}`` rather than ``None``

    Note:
        Parsing uses ``yaml.safe_load``, so arbitrary Python objects are
        never constructed from the file.
    """
    # Open in binary so PyYAML detects the encoding itself (UTF-8 by
    # default, UTF-16 via BOM) instead of relying on the platform locale.
    with open(filename, 'rb') as fh:
        config = yaml.safe_load(fh)
    # safe_load returns None for an empty document; honour the ``dict``
    # return type so callers can index/iterate without a None check.
    return config if config is not None else {}
def setup_device(model: nn.Module,
                 target_device: int) -> Tuple[nn.Module, torch.device]:
    r"""Move the model onto the CUDA device with the given index.

    Args:
        model (nn.Module): the model
        target_device (int): index of the target CUDA device

    Returns:
        Tuple[nn.Module, torch.device]: the model after ``.to(device)`` and
        the device it was moved to

    Raises:
        RuntimeError: if ``target_device`` is not among the indices
            ``range(torch.cuda.device_count())`` (this includes the case
            where no CUDA device is available)
    """
    available_devices: List[int] = list(range(torch.cuda.device_count()))
    logger.info(
        f'Using device {target_device} of available devices {available_devices}')
    # Fail early with an explicit message instead of an opaque CUDA error
    # raised from inside ``model.to``.
    if target_device not in available_devices:
        raise RuntimeError(
            f'CUDA device {target_device} is not available; '
            f'available devices: {available_devices}')
    device = torch.device(f'cuda:{target_device}')
    model = model.to(device)
    return model, device
def setup_param_groups(model: nn.Module, config: Dict) -> List:
    r"""Build a single optimizer parameter group covering the whole model.

    Args:
        model (nn.Module): the model whose parameters are optimized
        config (Dict): per-group optimizer options (e.g. ``lr``) merged into
            the group; ``None`` or ``{}`` adds no options

    Returns:
        List[Dict]: a one-element list suitable for ``torch.optim`` optimizers
    """
    # Materialize the generator so the group remains valid if it is iterated
    # more than once (e.g. inspected before the optimizer consumes it).
    params = list(model.parameters())
    # Keys from ``config`` are merged after ``'params'``, so a ``'params'``
    # entry in the config takes precedence, as before.
    return [{'params': params, **(config or {})}]
def get_instance(module: ModuleType, name: str, config: Dict,
                 *args: Any) -> Any:
    r"""Create an instance from a constructor name stored in the config.

    Args:
        module (ModuleType): the module which contains the class
        name (str): key of ``config`` whose entry holds ``'type'`` (the
            class name) and optionally ``'args'`` (keyword arguments)
        config (Dict): configuration of experiment
        args (Any): positional arguments passed to the class

    Returns:
        Any: the constructed instance

    Raises:
        KeyError: if ``config`` has no entry ``name`` or it lacks ``'type'``
        AttributeError: if ``module`` has no attribute with that class name
    """
    entry = config[name]
    ctor_name = entry['type']
    # ``args`` is optional: a missing key, or an empty YAML ``args:`` that
    # parses to None, means "no keyword arguments".
    ctor_kwargs = entry.get('args') or {}
    logger.info(f'Building: {module.__name__}.{ctor_name}')
    return getattr(module, ctor_name)(*args, **ctor_kwargs)
def get_instance_v2(module, ctor_name, *args, **kwargs):
    r"""Look up ``ctor_name`` on ``module`` and call it.

    Args:
        module: module object that defines the constructor
        ctor_name: attribute name of the constructor (class or factory)
        args(Optional): positional arguments forwarded to the constructor
        kwargs(Optional): keyword arguments forwarded to the constructor

    Returns:
        whatever the constructor returns
    """
    logger.info(f'Building {module.__name__}.{ctor_name}')
    factory = getattr(module, ctor_name)
    instance = factory(*args, **kwargs)
    return instance