import numpy as np
import torch


class BaseModule(torch.nn.Module):
    """Base ``torch.nn.Module`` with small parameter/device convenience helpers."""

    def __init__(self):
        super().__init__()

    @property
    def nparams(self):
        """
        Returns number of trainable parameters of the module.

        Counts elements of every parameter with ``requires_grad=True``
        (shared parameters are counted once, as ``parameters()`` de-duplicates
        them). Always returns a Python ``int``.
        """
        # numel() reads the size without copying the tensor to the CPU or
        # converting it to NumPy (which fails for dtypes such as bfloat16 and
        # yielded a float for 0-d parameters via np.prod(())).
        return sum(
            param.numel() for param in self.parameters() if param.requires_grad
        )

    def relocate_input(self, x: list):
        """
        Relocates provided tensors to the same device set for the module.

        The module's device is taken from its first parameter, or from its
        first buffer if it has no parameters. Tensor elements of ``x`` on a
        different device are replaced in place by copies on that device;
        non-tensor elements are left untouched.

        Args:
            x: mutable sequence whose tensor elements should be relocated.

        Returns:
            The same list object ``x`` (modified in place).
        """
        reference = next(self.parameters(), None)
        if reference is None:
            reference = next(self.buffers(), None)
        if reference is None:
            # No parameters or buffers: there is no device to relocate to.
            return x
        device = reference.device
        for i, item in enumerate(x):
            if isinstance(item, torch.Tensor) and item.device != device:
                x[i] = item.to(device)
        return x