Source code for ptmelt.blocks

import warnings
from typing import Any, List, Optional

import torch
import torch.nn as nn

from ptmelt.layers import MELTBatchNorm
from ptmelt.nn_utils import get_activation, get_initializer


class MELTBlock(nn.Module):
    """
    Base class for a MELT block.

    Provides the building blocks for the MELT architecture. Defines the common
    parameters for the MELT blocks with optional activation, dropout, batch
    normalization, and batch renormalization layers.

    Args:
        input_features (int): Number of input features.
        node_list (List[int]): List of number of nodes in each layer.
        activation (str, optional): Activation function. Defaults to "relu".
        dropout (float, optional): Dropout rate. Defaults to 0.0.
        batch_norm (bool, optional): Whether to use batch normalization. Defaults to False.
        batch_norm_type (str, optional): Type of batch normalization. Defaults to "ema".
        use_batch_renorm (bool, optional): Whether to use batch renormalization. Defaults to False.
        initializer (str, optional): Weight initializer. Defaults to "glorot_uniform".
        **kwargs: Additional keyword arguments.
    """

    def __init__(
        self,
        input_features: int,
        node_list: List[int],
        activation: Optional[str] = "relu",
        dropout: Optional[float] = 0.0,
        batch_norm: Optional[bool] = False,
        batch_norm_type: Optional[str] = "ema",
        use_batch_renorm: Optional[bool] = False,
        initializer: Optional[str] = "glorot_uniform",
        **kwargs: Any,
    ):
        super(MELTBlock, self).__init__(**kwargs)

        self.input_features = input_features
        self.node_list = node_list
        self.activation = activation
        self.dropout = dropout
        self.batch_norm = batch_norm
        self.batch_norm_type = batch_norm_type
        self.use_batch_renorm = use_batch_renorm
        self.initializer = initializer

        # Get the initializer function
        self.initializer_fn = get_initializer(self.initializer)

        # Create layer dictionary
        self.layer_dict = nn.ModuleDict()

        # Number of layers in the block
        self.num_layers = len(self.node_list)

        # Validate dropout value
        if self.dropout is not None:
            assert 0.0 <= self.dropout < 1.0, "Dropout must be in the range [0, 1)."

        # Get the activation layers
        if self.activation:
            self.layer_dict.update(
                {
                    f"activation_{i}": get_activation(self.activation)
                    for i in range(self.num_layers)
                }
            )

        # Optional dropout layers
        if self.dropout > 0:
            self.layer_dict.update(
                {
                    f"dropout_{i}": nn.Dropout(p=self.dropout)
                    for i in range(self.num_layers)
                }
            )

        # Optional batch normalization layers
        if self.batch_norm:
            if self.batch_norm_type == "pytorch":
                self.layer_dict.update(
                    {
                        f"batch_norm_{i}": nn.BatchNorm1d(
                            num_features=self.node_list[i],
                            affine=True,
                            track_running_stats=True,
                            momentum=1e-2,
                            eps=1e-3,
                        )
                        for i in range(self.num_layers)
                    }
                )
            else:
                self.layer_dict.update(
                    {
                        f"batch_norm_{i}": MELTBatchNorm(
                            num_features=self.node_list[i],
                            affine=True,
                            track_running_stats=True,
                            average_type=self.batch_norm_type,
                            momentum=1e-2,
                            eps=1e-3,
                        )
                        for i in range(self.num_layers)
                    }
                )

class DenseBlock(MELTBlock):
    """
    Dense block for the MELT architecture.

    The dense block consists of dense layers with optional activation, dropout,
    and batch normalization layers.

    Args:
        **kwargs: Additional keyword arguments.
    """

    def __init__(
        self,
        **kwargs: Any,
    ):
        super(DenseBlock, self).__init__(**kwargs)

        # Initialize dense layers
        self.layer_dict.update(
            {
                f"dense_{i}": nn.Linear(
                    in_features=(
                        self.input_features if i == 0 else self.node_list[i - 1]
                    ),
                    out_features=self.node_list[i],
                )
                for i in range(self.num_layers)
            }
        )

        # Initialize the weights
        [
            self.initializer_fn(self.layer_dict[f"dense_{i}"].weight)
            for i in range(self.num_layers)
        ]

    def forward(self, inputs: torch.Tensor):
        """Perform the forward pass of the dense block."""
        x = inputs
        for i in range(self.num_layers):
            # dense -> batch norm -> activation -> dropout
            x = self.layer_dict[f"dense_{i}"](x)
            x = self.layer_dict[f"batch_norm_{i}"](x) if self.batch_norm else x
            x = self.layer_dict[f"activation_{i}"](x) if self.activation else x
            x = self.layer_dict[f"dropout_{i}"](x) if self.dropout > 0 else x

        return x
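
A minimal usage sketch (illustrative, not part of the module source): it assumes ptmelt is importable and builds a two-layer dense block from the constructor arguments documented above.

    import torch
    from ptmelt.blocks import DenseBlock

    # Two hidden layers of 32 units, ReLU activation, 10% dropout
    block = DenseBlock(
        input_features=8, node_list=[32, 32], activation="relu", dropout=0.1
    )
    x = torch.randn(16, 8)  # batch of 16 samples with 8 features
    y = block(x)            # per layer: dense -> activation -> dropout (no batch norm here)
    print(y.shape)          # torch.Size([16, 32]) -- width of the last node_list entry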

class ResidualBlock(MELTBlock):
    """
    Residual block for the MELT architecture.

    The residual block consists of residual connections between dense layers with
    optional activation, dropout, and batch normalization layers. Residual
    connections are added after every `layers_per_block` layers.

    Args:
        layers_per_block (int, optional): Number of layers per residual block. Defaults to 2.
        pre_activation (bool, optional): Whether to use pre-activation residual blocks. Defaults to False.
        post_add_activation (bool, optional): Whether to use post-addition activation. Defaults to False.
        **kwargs: Additional keyword arguments.
    """

    def __init__(
        self,
        layers_per_block: Optional[int] = 2,
        pre_activation: Optional[bool] = False,
        post_add_activation: Optional[bool] = False,
        **kwargs: Any,
    ):
        super(ResidualBlock, self).__init__(**kwargs)

        self.layers_per_block = layers_per_block
        self.pre_activation = pre_activation
        self.post_add_activation = post_add_activation

        # Warn if the number of layers is not divisible by layers_per_block
        if self.num_layers % self.layers_per_block != 0:
            warnings.warn(
                f"Warning: Number of layers {self.num_layers} is not divisible by "
                f"layers_per_block ({self.layers_per_block}), so the last block will "
                f"have {self.num_layers % self.layers_per_block} layers."
            )

        # Initialize dense layers
        self.layer_dict.update(
            {
                f"dense_{i}": nn.Linear(
                    in_features=(
                        self.input_features if i == 0 else self.node_list[i - 1]
                    ),
                    out_features=self.node_list[i],
                )
                for i in range(self.num_layers)
            }
        )

        # Initialize the weights
        [
            self.initializer_fn(self.layer_dict[f"dense_{i}"].weight)
            for i in range(self.num_layers)
        ]

        # Optional activation layer after each residual addition
        # (one per residual block, counting a possible partial final block)
        if self.post_add_activation:
            num_blocks = (
                self.num_layers + self.layers_per_block - 1
            ) // self.layers_per_block
            self.layer_dict.update(
                {
                    f"post_add_act_{i}": get_activation(self.activation)
                    for i in range(num_blocks)
                }
            )

    def forward(self, inputs: torch.Tensor):
        """Perform the forward pass of the residual block."""
        x = inputs
        for i in range(self.num_layers):
            y = x
            # dense -> (pre-activation) -> batch norm -> dropout -> (post-activation)
            x = self.layer_dict[f"dense_{i}"](x)
            x = self.layer_dict[f"activation_{i}"](x) if self.pre_activation else x
            x = self.layer_dict[f"batch_norm_{i}"](x) if self.batch_norm else x
            x = self.layer_dict[f"dropout_{i}"](x) if self.dropout > 0 else x
            x = (
                self.layer_dict[f"activation_{i}"](x)
                if not self.pre_activation
                else x
            )

            # Add the residual connection when reaching the end of a residual block
            if (i + 1) % self.layers_per_block == 0 or i == self.num_layers - 1:
                x = x + y
                x = (
                    self.layer_dict[f"post_add_act_{i // self.layers_per_block}"](x)
                    if self.post_add_activation
                    else x
                )

        return x
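
An illustrative sketch of the residual block (not part of the module source). Because the skip addition `x = x + y` requires matching widths, the example keeps `input_features` equal to every entry of `node_list`.

    import torch
    from ptmelt.blocks import ResidualBlock

    # Four layers of width 32, grouped into residual blocks of two layers each
    block = ResidualBlock(
        input_features=32,
        node_list=[32, 32, 32, 32],
        layers_per_block=2,
        activation="relu",
        post_add_activation=True,
    )
    x = torch.randn(16, 32)
    y = block(x)            # skip connections added after every second layer
    print(y.shape)          # torch.Size([16, 32])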

class DefaultOutput(nn.Module):
    """
    Default output layer with a single dense layer and optional activation function.

    Args:
        input_features (int): Number of input features.
        output_features (int): Number of output features.
        activation (str, optional): Activation function. Defaults to "linear".
        initializer (str, optional): Weight initializer. Defaults to "glorot_uniform".
        **kwargs: Additional keyword arguments.
    """

    def __init__(
        self,
        input_features: int,
        output_features: int,
        activation: Optional[str] = "linear",
        initializer: Optional[str] = "glorot_uniform",
        **kwargs: Any,
    ):
        super(DefaultOutput, self).__init__(**kwargs)

        self.input_features = input_features
        self.output_features = output_features
        self.activation = activation
        self.initializer = initializer

        # Get the initializer function
        self.initializer_fn = get_initializer(self.initializer)

        # Initialize output layer
        self.output_layer = nn.Linear(
            in_features=self.input_features, out_features=self.output_features
        )

        # Initialize the weights
        self.initializer_fn(self.output_layer.weight)

        # Initialize activation layer
        self.activation_layer = get_activation(self.activation)

    def forward(self, inputs: torch.Tensor):
        """Perform the forward pass of the default output layer."""
        x = self.output_layer(inputs)
        x = self.activation_layer(x)
        return x
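
A short sketch of the default output head (illustrative only): a single linear layer followed by the configured activation, here left at its "linear" default.

    import torch
    from ptmelt.blocks import DefaultOutput

    head = DefaultOutput(input_features=32, output_features=1)
    x = torch.randn(16, 32)
    print(head(x).shape)  # torch.Size([16, 1])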

class MixtureDensityOutput(nn.Module):
    """
    Output layer for mixture density networks.

    The output layer consists of three dense layers for the mixture coefficients,
    mean, and log variance of the output distribution.

    Args:
        input_features (int): Number of input features.
        num_mixtures (int): Number of mixture components.
        num_outputs (int): Number of output dimensions.
        activation (str, optional): Activation function. Defaults to "linear".
        initializer (str, optional): Weight initializer. Defaults to "glorot_uniform".
        **kwargs: Additional keyword arguments.
    """

    def __init__(
        self,
        input_features: int,
        num_mixtures: int,
        num_outputs: int,
        activation: Optional[str] = "linear",
        initializer: Optional[str] = "glorot_uniform",
        **kwargs: Any,
    ):
        super(MixtureDensityOutput, self).__init__(**kwargs)

        self.input_features = input_features
        self.num_mixtures = num_mixtures
        self.num_outputs = num_outputs
        self.activation = activation
        self.initializer = initializer

        # Get the initializer function
        self.initializer_fn = get_initializer(self.initializer)

        # Initialize output layers
        self.mix_coeffs_layer = nn.Linear(
            in_features=self.input_features, out_features=self.num_mixtures
        )
        self.mean_layer = nn.Linear(
            in_features=self.input_features,
            out_features=self.num_mixtures * self.num_outputs,
        )
        self.log_var_layer = nn.Linear(
            in_features=self.input_features,
            out_features=self.num_mixtures * self.num_outputs,
        )

        # Initialize the weights
        self.initializer_fn(self.mix_coeffs_layer.weight)
        self.initializer_fn(self.mean_layer.weight)
        self.initializer_fn(self.log_var_layer.weight)

        # Initialize activation layer
        self.activation_layer = get_activation(self.activation)
        self.softmax_layer = get_activation("softmax")

    def forward(self, inputs: torch.Tensor):
        """Perform the forward pass of the mixture density output layer."""
        mix_coeffs = self.mix_coeffs_layer(inputs)
        mix_coeffs = self.softmax_layer(mix_coeffs)

        mean = self.mean_layer(inputs)
        mean = self.activation_layer(mean)

        log_var = self.log_var_layer(inputs)
        log_var = self.activation_layer(log_var)

        # Return the concatenated output
        return torch.cat([mix_coeffs, mean, log_var], dim=-1)
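
An illustrative sketch of the mixture density head (not part of the module source). The concatenated output has width num_mixtures + 2 * num_mixtures * num_outputs and can be split back into the softmaxed mixture coefficients, the means, and the log variances.

    import torch
    from ptmelt.blocks import MixtureDensityOutput

    head = MixtureDensityOutput(input_features=32, num_mixtures=3, num_outputs=2)
    x = torch.randn(16, 32)
    out = head(x)  # shape (16, 3 + 3*2 + 3*2) = (16, 15)

    m = head.num_mixtures * head.num_outputs
    mix_coeffs = out[..., : head.num_mixtures]                    # (16, 3) softmaxed mixture weights
    mean = out[..., head.num_mixtures : head.num_mixtures + m]    # (16, 6) component means
    log_var = out[..., head.num_mixtures + m :]                   # (16, 6) component log variances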