"""ptmelt.nn_utils: name-based lookup helpers for activations, initializers, and loss functions."""

from typing import Optional

import torch.nn as nn

from ptmelt.losses import MixtureDensityLoss


def get_activation(act_name: Optional[str]) -> nn.Module:
    """
    Utility method to get an activation module based on its name.

    Args:
        act_name (str or None): Name of the activation function. Supported
            names are "relu", "leaky_relu", "elu", "selu", "swish" (SiLU),
            "gelu", "sigmoid", "tanh", "linear" and "softmax". ``None`` is
            treated the same as "linear".

    Returns:
        nn.Module: A freshly constructed activation module. A new instance is
        built on every call.

    Raises:
        ValueError: If ``act_name`` is not one of the supported names.
    """
    # None is an accepted input (it means "no activation"), so the parameter is
    # annotated Optional[str] rather than str.
    if act_name is None or act_name == "linear":
        return nn.Identity()
    if act_name == "relu":
        return nn.ReLU()
    elif act_name == "leaky_relu":
        return nn.LeakyReLU()
    elif act_name == "elu":
        return nn.ELU()
    elif act_name == "selu":
        return nn.SELU()
    elif act_name == "swish":
        # "swish" is the common name for SiLU (x * sigmoid(x)).
        return nn.SiLU()
    elif act_name == "gelu":
        return nn.GELU()
    elif act_name == "sigmoid":
        return nn.Sigmoid()
    elif act_name == "tanh":
        return nn.Tanh()
    elif act_name == "softmax":
        # Normalize over the last dimension.
        return nn.Softmax(dim=-1)
    raise ValueError(f"Unsupported activation function {act_name}")
def get_initializer(init_name: str):
    """
    Look up a ``torch.nn.init`` weight-initialization function by name.

    Args:
        init_name (str): One of "glorot_uniform", "glorot_normal",
            "he_uniform", "he_normal", "normal" or "uniform". The glorot_*
            names map to the xavier_* functions and the he_* names map to the
            kaiming_* functions.

    Returns:
        The in-place initializer function itself (not called).

    Raises:
        ValueError: If ``init_name`` is not a supported name.
    """
    # Ordered (name, function) pairs. The names are matched with ``==`` in
    # order, rather than through a dict, so the comparison semantics match a
    # plain equality check.
    registry = (
        ("glorot_uniform", nn.init.xavier_uniform_),
        ("glorot_normal", nn.init.xavier_normal_),
        ("he_uniform", nn.init.kaiming_uniform_),
        ("he_normal", nn.init.kaiming_normal_),
        ("normal", nn.init.normal_),
        ("uniform", nn.init.uniform_),
    )
    for known_name, init_fn in registry:
        if init_name == known_name:
            return init_fn
    raise ValueError(f"Unsupported initializer {init_name}")
def get_loss_fn(loss_name: str) -> nn.Module:
    """
    Utility method to get a loss-function module based on its name.

    Args:
        loss_name (str): One of "mse", "mae", "huber", "nll", "poisson",
            "kl_div" or "mixture_density".

    Returns:
        nn.Module: A freshly constructed loss module. "mixture_density" returns
        the project's ``MixtureDensityLoss``.

    Raises:
        ValueError: If ``loss_name`` is not one of the supported names.
    """
    if loss_name == "mse":
        return nn.MSELoss()
    elif loss_name == "mae":
        return nn.L1Loss()
    elif loss_name == "huber":
        # SmoothL1Loss with its default beta=1 coincides with Huber loss (delta=1).
        return nn.SmoothL1Loss()
    elif loss_name == "nll":
        return nn.NLLLoss()
    elif loss_name == "poisson":
        return nn.PoissonNLLLoss()
    elif loss_name == "kl_div":
        # The default reduction='mean' divides by batch size * support size and
        # does not give the true KL divergence (PyTorch also warns about it).
        # 'batchmean' divides by batch size only, matching the math definition.
        return nn.KLDivLoss(reduction="batchmean")
    elif loss_name == "mixture_density":
        return MixtureDensityLoss()
    raise ValueError(f"Unsupported loss function {loss_name}")