Source code for ptmelt.models

import warnings
from contextlib import nullcontext
from itertools import groupby
from typing import List, Optional

import torch
import torch.nn as nn
import torch.optim.lr_scheduler as lr_scheduler
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from tqdm import tqdm

from ptmelt.blocks import (
    BayesianBlock,
    DefaultOutput,
    DenseBlock,
    MixtureDensityOutput,
    ResidualBlock,
)
from ptmelt.layers import AttentionPool
from ptmelt.losses import MixtureDensityLoss


class MELTModel(nn.Module):
    """
    PT-MELT Base model.

    Args:
        num_features (int): The number of input features.
        num_outputs (int): The number of output units.
        width (int, optional): The width of the hidden layers. Defaults to 32.
        depth (int, optional): The number of hidden layers. Defaults to 2.
        act_fun (str, optional): The activation function to use. Defaults to 'relu'.
        dropout (float, optional): The dropout rate. Defaults to 0.0.
        input_dropout (float, optional): The input dropout rate. Defaults to 0.0.
        batch_norm (bool, optional): Whether to use batch normalization. Defaults to
            False.
        batch_norm_type (str, optional): The type of batch normalization to use.
            Defaults to 'ema'.
        use_batch_renorm (bool, optional): Whether to use batch renormalization.
            Defaults to False.
        output_activation (str, optional): The activation function for the output
            layer. Defaults to None.
        initializer (str, optional): The weight initializer to use. Defaults to
            'glorot_uniform'.
        l1_reg (float, optional): The L1 regularization strength. Defaults to 0.0.
        l2_reg (float, optional): The L2 regularization strength. Defaults to 0.0.
        num_mixtures (int, optional): The number of mixture components for MDN.
            Defaults to 0.
        node_list (list, optional): The list of nodes per layer to alternately define
            layers. Defaults to None.
        seed (int, optional): The random seed. Defaults to None.
        **kwargs: Additional keyword arguments.
    """

    def __init__(
        self,
        num_features: int,
        num_outputs: int,
        width: Optional[int] = 32,
        depth: Optional[int] = 2,
        act_fun: Optional[str] = "relu",
        dropout: Optional[float] = 0.0,
        input_dropout: Optional[float] = 0.0,
        batch_norm: Optional[bool] = False,
        batch_norm_type: Optional[str] = "ema",
        use_batch_renorm: Optional[bool] = False,
        output_activation: Optional[str] = None,
        initializer: Optional[str] = "glorot_uniform",
        l1_reg: Optional[float] = 0.0,
        l2_reg: Optional[float] = 0.0,
        num_mixtures: Optional[int] = 0,
        node_list: Optional[list] = None,
        seed: Optional[int] = None,
        **kwargs,
    ):
        super(MELTModel, self).__init__(**kwargs)
        self.num_features = num_features
        self.num_outputs = num_outputs
        self.width = width
        self.depth = depth
        self.act_fun = act_fun
        self.dropout = dropout
        self.input_dropout = input_dropout
        self.batch_norm = batch_norm
        self.batch_norm_type = batch_norm_type
        self.use_batch_renorm = use_batch_renorm
        self.output_activation = output_activation
        self.initializer = initializer
        self.l1_reg = l1_reg
        self.l2_reg = l2_reg
        self.num_mixtures = num_mixtures
        self.node_list = node_list
        self.seed = seed

        # Determine if network should be defined based on depth/width or node_list
        if self.node_list:
            self.num_layers = len(self.node_list)
            self.layer_width = self.node_list
        elif self.depth is None:
            self.num_layers = 0
            self.layer_width = []
        else:
            self.num_layers = self.depth
            self.layer_width = [self.width for _ in range(self.depth)]

        # Create list for storing names of sub-layers
        self.sub_layer_names = []

        # Create layer dictionary
        self.layer_dict = nn.ModuleDict()

    def build(self):
        """Build the model."""
        self.initialize_layers()

    def initialize_layers(self):
        """Initialize the layers of the model."""
        self.create_dropout_layers()
        self.create_output_layer()

    def create_dropout_layers(self):
        """Create the dropout layers."""
        if self.input_dropout > 0:
            self.layer_dict.update({"input_dropout": nn.Dropout(p=self.input_dropout)})

    def create_output_layer(self):
        """Create the output layer."""
        if self.num_mixtures > 0:
            self.layer_dict.update(
                {
                    "output": MixtureDensityOutput(
                        input_features=(
                            self.layer_width[-1]
                            if self.num_layers > 0
                            else self.num_features
                        ),
                        num_mixtures=self.num_mixtures,
                        num_outputs=self.num_outputs,
                        activation=self.output_activation,
                        initializer=self.initializer,
                        seed=self.seed,
                    )
                }
            )
        else:
            self.layer_dict.update(
                {
                    "output": DefaultOutput(
                        input_features=(
                            self.layer_width[-1]
                            if self.num_layers > 0
                            else self.num_features
                        ),
                        output_features=self.num_outputs,
                        activation=self.output_activation,
                        initializer=self.initializer,
                        seed=self.seed,
                    )
                }
            )
        self.sub_layer_names.append("output")

    def compute_jacobian(self, x):
        """Compute the Jacobian of the model with respect to the input."""
        pass

    def l1_regularization(self, lambda_l1: float):
        """
        Compute the L1 regularization term for use in the loss function.

        Args:
            lambda_l1 (float): The L1 regularization strength.
        """
        l1_norm = sum(
            p.abs().sum()
            for name, p in self.named_parameters()
            if p.requires_grad and "weight" in name
        )
        return lambda_l1 * l1_norm

    def l2_regularization(self, lambda_l2: float):
        """
        Compute the L2 regularization term for use in the loss function.

        Args:
            lambda_l2 (float): The L2 regularization strength.
        """
        l2_norm = sum(
            p.pow(2.0).sum()
            for name, p in self.named_parameters()
            if p.requires_grad and "weight" in name
        )
        return 0.5 * lambda_l2 * l2_norm

    def get_loss_fn(
        self,
        loss: Optional[str] = "mse",
        reduction: Optional[str] = "mean",
        mse_weight: Optional[float] = None,
    ):
        """
        Get the loss function for the model. Used in the training loop.

        Args:
            loss (str, optional): The loss function to use. Defaults to 'mse'.
            reduction (str, optional): The reduction method for the loss. Defaults to
                'mean'.
            mse_weight (float, optional): MSE weight passed to the mixture density
                loss. Defaults to None.
        """
        if self.num_mixtures > 0:
            warnings.warn(
                "Mixture Density Networks require the use of the MixtureDensityLoss "
                "class. The loss function will be set automatically."
            )
            return MixtureDensityLoss(
                num_mixtures=self.num_mixtures,
                num_outputs=self.num_outputs,
                mse_weight=mse_weight if mse_weight else 0.0,
            )
        else:
            # Mappings for common loss functions
            common_mappings = {
                "mse": "MSELoss",
                "mae": "L1Loss",
                "huber": "SmoothL1Loss",
                "nll": "NLLLoss",
                "poisson": "PoissonNLLLoss",
                "kl_div": "KLDivLoss",
            }
            loss = common_mappings.get(loss.lower(), loss)
            return getattr(nn, loss)(reduction=reduction)

    def get_optimizer(self, optimizer_name: str, **kwargs):
        """
        Get the optimizer for the model. Used in the training loop.

        Args:
            optimizer_name (str): The name of the optimizer to use.
        """
        name = optimizer_name.lower()
        mapping = {
            "sgd": torch.optim.SGD,
            "adam": torch.optim.Adam,
            "adamw": torch.optim.AdamW,
            "rmsprop": torch.optim.RMSprop,
            "adadelta": torch.optim.Adadelta,
            "adagrad": torch.optim.Adagrad,
            "adamax": torch.optim.Adamax,
            "nadam": torch.optim.NAdam,
            "radam": torch.optim.RAdam,
        }
        if name not in mapping:
            raise ValueError(f"Unknown optimizer '{optimizer_name}'.")
        return mapping[name](self.parameters(), **kwargs)

    def get_scheduler(self, scheduler_name: str, optimizer, **kwargs):
        """
        Get the learning rate scheduler for the model. Used in the training loop.

        Args:
            scheduler_name (str): The name of the scheduler to use.
            optimizer: The optimizer to attach the scheduler to.
        """
        # Store min_lr on the model; it is used by fit() for early stopping
        if "min_lr" in kwargs:
            self.min_lr = kwargs.pop("min_lr")
        return getattr(lr_scheduler, scheduler_name)(optimizer, **kwargs)

    def step(self, dataloader, optimizer, criterion, device="cpu", training=True):
        """
        Perform a single step either in training or validation mode.
        """
        self.train() if training else self.eval()

        # Use torch.no_grad() only if not training
        context_manager = torch.no_grad() if not training else nullcontext()

        running_loss = 0.0
        with context_manager:
            for x_in, y_in in dataloader:
                # Move data to device
                x_in, y_in = x_in.to(device), y_in.to(device)

                # Forward pass
                pred = self(x_in)
                loss = criterion(pred, y_in)

                if training:
                    # Add L1 and L2 regularization if present
                    if self.l1_reg > 0:
                        loss += self.l1_regularization(lambda_l1=self.l1_reg)
                    if self.l2_reg > 0:
                        loss += self.l2_regularization(lambda_l2=self.l2_reg)

                    # Zero the parameter gradients
                    optimizer.zero_grad()
                    # Backward pass
                    loss.backward()
                    # Optimize
                    optimizer.step()

                # Accumulate running loss
                running_loss += loss.item()

        # Normalize loss
        running_loss /= len(dataloader)
        return running_loss

    def fit(
        self,
        train_dl,
        val_dl,
        optimizer,
        criterion,
        num_epochs: Optional[int] = 100,
        device: Optional[str] = "cpu",
        scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
        stopping: Optional[bool] = True,
        verbose=False,
    ):
        """
        Perform the model training loop.

        Args:
            train_dl (DataLoader): The training data loader.
            val_dl (DataLoader): The validation data loader.
            optimizer (Optimizer): The optimizer to use.
            criterion (Loss): The loss function to use.
            num_epochs (int): The number of epochs to train the model.
            device (str, optional): The device to use for training. Defaults to 'cpu'.
            scheduler (LRScheduler, optional): The learning rate scheduler to step each
                epoch. Defaults to None.
            stopping (bool, optional): Whether to stop training once the learning rate
                reaches the minimum set via get_scheduler(). Defaults to True.
            verbose (bool, optional): Whether to print training statistics. Defaults to
                False.
        """
        # Move model to device
        self.to(device)

        # Create history dictionary
        if not hasattr(self, "history"):
            self.history = {"loss": [], "val_loss": [], "lr": [], "epoch": []}

        for epoch in tqdm(range(num_epochs), disable=not verbose):
            # Perform a training and validation step
            train_loss = self.step(
                train_dl, optimizer, criterion, device=device, training=True
            )
            val_loss = self.step(
                val_dl, optimizer, criterion, device=device, training=False
            )

            # Step the scheduler if provided
            if scheduler:
                scheduler.step(
                    val_loss
                    if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)
                    else None
                )

            # Print statistics
            if (epoch + 1) % 10 == 0 and verbose:
                print(
                    f"Epoch {epoch + 1}, Loss: {train_loss:.4f}, "
                    f"Val Loss: {val_loss:.4f}"
                )

            # Save history
            self.history["loss"].append(train_loss)
            self.history["val_loss"].append(val_loss)
            self.history["lr"].append(
                scheduler.get_last_lr()[0]
                if scheduler and hasattr(scheduler, "get_last_lr")
                else (
                    optimizer.param_groups[0]["lr"]
                    if isinstance(optimizer, torch.optim.Optimizer)
                    else optimizer.defaults["lr"]
                )
            )
            self.history["epoch"].append(epoch + 1)

            # Stop early once the learning rate has decayed to the minimum, if set
            if getattr(self, "min_lr", None) and stopping:
                if scheduler and hasattr(scheduler, "get_last_lr"):
                    if scheduler.get_last_lr()[0] <= self.min_lr:
                        if verbose:
                            print(
                                f"Stopping training at epoch {epoch + 1} due to "
                                f"learning rate reaching minimum {self.min_lr}."
                            )
                        break

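Example (a minimal sketch, not part of the module): a typical training workflow built from the MELTModel utilities, using the ArtificialNeuralNetwork subclass defined below. The toy regression data and all hyperparameter values are illustrative assumptions.

import torch
from torch.utils.data import DataLoader, TensorDataset

from ptmelt.models import ArtificialNeuralNetwork

# Toy regression data: 8 input features, 1 target (assumed for illustration).
x = torch.randn(256, 8)
y = x.sum(dim=1, keepdim=True) + 0.1 * torch.randn(256, 1)
train_dl = DataLoader(TensorDataset(x[:200], y[:200]), batch_size=32, shuffle=True)
val_dl = DataLoader(TensorDataset(x[200:], y[200:]), batch_size=32)

model = ArtificialNeuralNetwork(num_features=8, num_outputs=1, width=32, depth=2)
model.build()  # layers are registered in build(), not in __init__()

criterion = model.get_loss_fn(loss="mse")
optimizer = model.get_optimizer("adam", lr=1e-3)
# min_lr is stored on the model and drives the early-stopping check in fit()
scheduler = model.get_scheduler(
    "ReduceLROnPlateau", optimizer, factor=0.5, patience=5, min_lr=1e-5
)

model.fit(
    train_dl, val_dl, optimizer, criterion,
    num_epochs=50, scheduler=scheduler, verbose=True,
)
print(model.history["val_loss"][-1])
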

class ArtificialNeuralNetwork(MELTModel):
    """
    Artificial Neural Network (ANN) model.

    Args:
        **kwargs: Additional keyword arguments.
    """

    def __init__(
        self,
        **kwargs,
    ):
        super(ArtificialNeuralNetwork, self).__init__(**kwargs)

    def initialize_layers(self):
        """Initialize the layers of the ANN."""
        super(ArtificialNeuralNetwork, self).initialize_layers()

        # Bulk layers
        self.layer_dict.update(
            {
                "dense_block": DenseBlock(
                    input_features=self.num_features,
                    node_list=self.layer_width,
                    activation=self.act_fun,
                    dropout=self.dropout,
                    batch_norm=self.batch_norm,
                    batch_norm_type=self.batch_norm_type,
                    use_batch_renorm=self.use_batch_renorm,
                    initializer=self.initializer,
                    seed=self.seed,
                )
            }
        )
        self.sub_layer_names.append("dense_block")

    def forward(self, inputs: torch.Tensor):
        """
        Perform the forward pass of the ANN.

        Args:
            inputs (torch.Tensor): The input data.
        """
        # Apply input dropout
        x = (
            self.layer_dict["input_dropout"](inputs)
            if self.input_dropout > 0
            else inputs
        )

        # Apply dense block
        x = self.layer_dict["dense_block"](x)

        # Apply the output layer(s) and return
        return self.layer_dict["output"](x)

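Example (a minimal sketch, not part of the module): the mixture density configuration. Setting num_mixtures > 0 swaps the output head for a MixtureDensityOutput, and get_loss_fn() returns a MixtureDensityLoss that consumes the mixture parameters directly; the data shapes below are assumptions.

import torch

from ptmelt.models import ArtificialNeuralNetwork

mdn = ArtificialNeuralNetwork(
    num_features=4, num_outputs=2, width=16, depth=2, num_mixtures=3
)
mdn.build()
criterion = mdn.get_loss_fn()  # warns and returns a MixtureDensityLoss

x = torch.randn(8, 4)
y = torch.randn(8, 2)
loss = criterion(mdn(x), y)  # mixture parameters are handled inside the loss
loss.backward()
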

class ResidualNeuralNetwork(MELTModel):
    """
    Residual Neural Network (ResNet) model.

    Args:
        layers_per_block (int, optional): The number of layers per residual block.
            Defaults to 2.
        pre_activation (bool, optional): Whether to use pre-activation. Defaults to
            True.
        post_add_activation (bool, optional): Whether to use activation after
            addition. Defaults to False.
        **kwargs: Additional keyword arguments.
    """

    def __init__(
        self,
        layers_per_block: Optional[int] = 2,
        pre_activation: Optional[bool] = True,
        post_add_activation: Optional[bool] = False,
        **kwargs,
    ):
        super(ResidualNeuralNetwork, self).__init__(**kwargs)
        self.layers_per_block = layers_per_block
        self.pre_activation = pre_activation
        self.post_add_activation = post_add_activation

    def build(self):
        """Build the model."""
        if self.depth % self.layers_per_block != 0:
            warnings.warn(
                f"Depth ({self.depth}) is not divisible by layers_per_block "
                f"({self.layers_per_block}), so the last block will have "
                f"{self.depth % self.layers_per_block} layers."
            )
        super(ResidualNeuralNetwork, self).build()

    def initialize_layers(self):
        """Initialize the layers of the ResNet."""
        super(ResidualNeuralNetwork, self).initialize_layers()

        # Create the Residual Block
        self.layer_dict.update(
            {
                "residual_block": ResidualBlock(
                    layers_per_block=self.layers_per_block,
                    pre_activation=self.pre_activation,
                    post_add_activation=self.post_add_activation,
                    input_features=self.num_features,
                    node_list=self.layer_width,
                    activation=self.act_fun,
                    dropout=self.dropout,
                    batch_norm=self.batch_norm,
                    batch_norm_type=self.batch_norm_type,
                    use_batch_renorm=self.use_batch_renorm,
                    initializer=self.initializer,
                    seed=self.seed,
                )
            }
        )
        self.sub_layer_names.append("residual_block")

    def forward(self, inputs: torch.Tensor):
        """
        Perform the forward pass of the ResNet.

        Args:
            inputs (torch.Tensor): The input data.
        """
        # Apply input dropout
        x = (
            self.layer_dict["input_dropout"](inputs)
            if self.input_dropout > 0
            else inputs
        )

        # Apply residual block
        x = self.layer_dict["residual_block"](x)

        # Apply the output layer(s) and return
        return self.layer_dict["output"](x)

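Example (a minimal sketch, not part of the module): building a ResidualNeuralNetwork whose depth is not divisible by layers_per_block, which triggers the warning emitted in build(). Shapes and hyperparameters are assumptions.

import torch

from ptmelt.models import ResidualNeuralNetwork

resnet = ResidualNeuralNetwork(
    num_features=10, num_outputs=1, width=64, depth=5,
    layers_per_block=2, pre_activation=True,
)
resnet.build()  # warns: the last block will have 1 layer

out = resnet(torch.randn(16, 10))
print(out.shape)  # expected: torch.Size([16, 1])
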

class BayesianNeuralNetwork(MELTModel):
    """
    Bayesian Neural Network (BNN) model.

    Args:
        num_points (int, optional): Number of Monte Carlo samples. Defaults to 1.
        do_aleatoric (bool, optional): Flag to perform aleatoric output. Defaults to
            False.
        do_bayesian_output (bool, optional): Flag to perform Bayesian output. Defaults
            to True.
        aleatoric_scale_factor (float, optional): Scale factor for aleatoric
            uncertainty. Defaults to 5e-2.
        scale_epsilon (float, optional): Epsilon value for the scale of the aleatoric
            uncertainty. Defaults to 1e-3.
        bayesian_mask (list, optional): List of booleans to determine which layers are
            Bayesian and which are Dense. Defaults to None.
        **kwargs: Additional keyword arguments.
    """

    def __init__(
        self,
        num_points: Optional[int] = 1,
        do_aleatoric: Optional[bool] = False,
        do_bayesian_output: Optional[bool] = True,
        aleatoric_scale_factor: Optional[float] = 5e-2,
        scale_epsilon: Optional[float] = 1e-3,
        bayesian_mask: Optional[List[bool]] = None,
        **kwargs,
    ):
        super(BayesianNeuralNetwork, self).__init__(**kwargs)
        self.num_points = num_points
        self.do_aleatoric = do_aleatoric
        self.do_bayesian_output = do_bayesian_output
        self.aleatoric_scale_factor = aleatoric_scale_factor
        self.scale_epsilon = scale_epsilon
        self.bayesian_mask = bayesian_mask

    def create_output_layer(self):
        """Create output layer for the Bayesian Neural Network."""
        if self.num_mixtures > 0:
            self.layer_dict.update(
                {
                    "output": MixtureDensityOutput(
                        input_features=(
                            self.layer_width[-1]
                            if self.num_layers > 0
                            else self.num_features
                        ),
                        num_mixtures=self.num_mixtures,
                        num_outputs=self.num_outputs,
                        activation=self.output_activation,
                        initializer=self.initializer,
                        seed=self.seed,
                    )
                }
            )
        else:
            self.layer_dict.update(
                {
                    "output": DefaultOutput(
                        input_features=(
                            self.layer_width[-1]
                            if self.num_layers > 0
                            else self.num_features
                        ),
                        output_features=self.num_outputs,
                        activation=self.output_activation,
                        initializer=self.initializer,
                        do_bayesian=self.do_bayesian_output,
                        seed=self.seed,
                    )
                }
            )
        self.sub_layer_names.append("output")

    def build(self):
        """Build the BNN."""
        super(BayesianNeuralNetwork, self).build()

    def initialize_layers(self):
        """Initialize the layers of the BNN."""
        super(BayesianNeuralNetwork, self).initialize_layers()

        # Create the Bayesian and Dense blocks based on the mask
        if self.bayesian_mask is None:
            self.num_dense_layers = 0
            self.dense_block = None
            self.bayesian_block = BayesianBlock(
                num_points=self.num_points,
                input_features=self.num_features,
                node_list=self.layer_width,
                activation=self.act_fun,
                dropout=self.dropout,
                batch_norm=self.batch_norm,
                batch_norm_type=self.batch_norm_type,
                use_batch_renorm=self.use_batch_renorm,
                initializer=self.initializer,
                seed=self.seed,
            )
            self.layer_dict.update({"full_bayesian_block": self.bayesian_block})
            self.sub_layer_names.append("full_bayesian_block")
        else:
            self.dense_block = []
            self.bayesian_block = []
            bayes_block_idx = 0
            dense_block_idx = 0

            # Loop through the Bayesian mask and create one block per contiguous group
            idx = 0
            for is_bayesian, group in groupby(self.bayesian_mask):
                # Get the group length and the widths it covers
                group_len = len(list(group))
                layer_width = self.layer_width[idx : idx + group_len]

                # The block's input dimension is the number of input features for the
                # first group and the width of the preceding layer otherwise.
                input_features = (
                    self.num_features if idx == 0 else self.layer_width[idx - 1]
                )
                idx += group_len

                # Create a Bayesian block or Dense block
                if is_bayesian:
                    bayesian_block = BayesianBlock(
                        num_points=self.num_points,
                        input_features=input_features,
                        node_list=layer_width,
                        activation=self.act_fun,
                        dropout=self.dropout,
                        batch_norm=self.batch_norm,
                        batch_norm_type=self.batch_norm_type,
                        use_batch_renorm=self.use_batch_renorm,
                        initializer=self.initializer,
                        seed=self.seed,
                    )
                    self.bayesian_block.append(bayesian_block)
                    self.layer_dict.update(
                        {f"bayesian_block_{bayes_block_idx}": bayesian_block}
                    )
                    self.sub_layer_names.append(f"bayesian_block_{bayes_block_idx}")
                    bayes_block_idx += 1
                else:
                    dense_block = DenseBlock(
                        input_features=input_features,
                        node_list=layer_width,
                        activation=self.act_fun,
                        dropout=self.dropout,
                        batch_norm=self.batch_norm,
                        batch_norm_type=self.batch_norm_type,
                        use_batch_renorm=self.use_batch_renorm,
                        initializer=self.initializer,
                        seed=self.seed,
                    )
                    self.dense_block.append(dense_block)
                    self.layer_dict.update(
                        {f"dense_block_{dense_block_idx}": dense_block}
                    )
                    self.sub_layer_names.append(f"dense_block_{dense_block_idx}")
                    dense_block_idx += 1

    def forward(self, inputs: torch.Tensor):
        """
        Perform the forward pass of the BNN.

        Args:
            inputs (torch.Tensor): The input data.
        """
        # Apply input dropout
        x = (
            self.layer_dict["input_dropout"](inputs)
            if self.input_dropout > 0
            else inputs
        )

        # Apply the full Bayesian block if it exists
        if "full_bayesian_block" in self.layer_dict:
            x = self.layer_dict["full_bayesian_block"](x)
        else:
            # Apply each Bayesian and Dense block in sequence; blocks were created one
            # per contiguous group of the mask, so iterate over the groups here.
            bayesian_index = 0
            dense_index = 0
            for is_bayesian, _ in groupby(self.bayesian_mask):
                if is_bayesian:
                    x = self.layer_dict[f"bayesian_block_{bayesian_index}"](x)
                    bayesian_index += 1
                else:
                    x = self.layer_dict[f"dense_block_{dense_index}"](x)
                    dense_index += 1

        # Apply the output layer(s) and return
        return self.layer_dict["output"](x)

    def step(self, dataloader, optimizer, criterion, device="cpu", training=True):
        """
        Perform a single step either in training or validation mode.
        """
        self.train() if training else self.eval()
        dataset_size = len(dataloader.dataset)

        # Use torch.no_grad() only if not training
        context_manager = torch.no_grad() if not training else nullcontext()

        running_loss = 0.0
        with context_manager:
            for x_in, y_in in dataloader:
                # Move data to device
                x_in, y_in = x_in.to(device), y_in.to(device)

                # Forward pass
                pred = self(x_in)
                loss = criterion(pred, y_in)

                # Add in kl divergence for the Bayesian block
                if "full_bayesian_block" in self.layer_dict:
                    loss += (
                        self.layer_dict["full_bayesian_block"].kl_loss() / dataset_size
                    )
                else:
                    # Add in kl divergence for each Bayesian block (one per contiguous
                    # group of the mask)
                    bayesian_index = 0
                    for is_bayesian, _ in groupby(self.bayesian_mask):
                        if is_bayesian:
                            loss += (
                                self.layer_dict[
                                    f"bayesian_block_{bayesian_index}"
                                ].kl_loss()
                                / dataset_size
                            )
                            bayesian_index += 1

                if training:
                    # Add L1 and L2 regularization if present
                    if self.l1_reg > 0:
                        loss += self.l1_regularization(lambda_l1=self.l1_reg)
                    if self.l2_reg > 0:
                        loss += self.l2_regularization(lambda_l2=self.l2_reg)

                    # Zero the parameter gradients
                    optimizer.zero_grad()
                    # Backward pass
                    loss.backward()
                    # Optimize
                    optimizer.step()

                # Accumulate running loss
                running_loss += loss.item()

        # Normalize loss
        running_loss /= len(dataloader)
        return running_loss

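Example (a minimal sketch, not part of the module): a BayesianNeuralNetwork with a mixed bayesian_mask. Each contiguous run of equal values in the mask becomes one Bayesian or Dense block, and step() adds the KL term of every Bayesian block scaled by the dataset size. Data and hyperparameters are assumptions.

import torch
from torch.utils.data import DataLoader, TensorDataset

from ptmelt.models import BayesianNeuralNetwork

# Dense -> Bayesian -> Dense hidden stack (depth=3, one layer per mask entry here).
bnn = BayesianNeuralNetwork(
    num_features=6, num_outputs=1, width=32, depth=3,
    bayesian_mask=[False, True, False],
)
bnn.build()

x, y = torch.randn(64, 6), torch.randn(64, 1)
dl = DataLoader(TensorDataset(x, y), batch_size=16, shuffle=True)

optimizer = bnn.get_optimizer("adam", lr=1e-3)
criterion = bnn.get_loss_fn("mse")
train_loss = bnn.step(dl, optimizer, criterion, training=True)  # KL terms included
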

class RecurrentNeuralNetwork(MELTModel):
    """
    Recurrent Neural Network (RNN) model.

    Bidirectional is not supported in this implementation as it is intended for
    forecasting tasks.

    Args:
        rnn_type (str, optional): The type of RNN to use ('rnn', 'lstm', 'gru').
            Defaults to 'lstm'.
        return_sequences (bool, optional): Whether to return the full sequence or just
            the last output. Defaults to False.
        head_type (str, optional): The pooling head applied to the RNN outputs
            ('last', 'attn', 'mean', 'max'). Defaults to 'last'.
        **kwargs: Additional keyword arguments.
    """

    def __init__(
        self,
        rnn_type: Optional[str] = "lstm",
        return_sequences: Optional[bool] = False,
        head_type: str = "last",
        **kwargs,
    ):
        super(RecurrentNeuralNetwork, self).__init__(**kwargs)
        self.rnn_type = rnn_type.lower()
        if self.rnn_type not in ["rnn", "lstm", "gru"]:
            raise ValueError("RNN type must be 'rnn', 'lstm', or 'gru'.")

        self.return_sequences = return_sequences
        if self.return_sequences:
            # TODO: Implement return sequences as an option for many-to-many tasks
            raise NotImplementedError(
                "Returning sequences is not implemented for RNNs. "
                "Please set return_sequences=False."
            )

        self.head_type = head_type.lower()
        if self.head_type not in ["last", "attn", "mean", "max"]:
            raise ValueError("head_type must be one of: 'last', 'attn', 'mean', 'max'.")

        if self.node_list is not None:
            warnings.warn(
                "node_list for RNN must be uniform per layer; "
                "using width and depth to define layers."
            )
        self.hidden_size = self.width
        self.num_layers = self.depth
        self.recurrent_out_dim = self.hidden_size

    def create_output_layer(self):
        """
        Override to use recurrent_out_dim as input_features instead of layer_width.
        """
        if self.num_mixtures > 0:
            head = MixtureDensityOutput(
                input_features=self.recurrent_out_dim,
                num_mixtures=self.num_mixtures,
                num_outputs=self.num_outputs,
                activation=self.output_activation,
                initializer=self.initializer,
                seed=self.seed,
            )
        else:
            head = DefaultOutput(
                input_features=self.recurrent_out_dim,
                output_features=self.num_outputs,
                activation=self.output_activation,
                initializer=self.initializer,
                seed=self.seed,
            )

        # TODO: Implement return sequences as an option for many-to-many tasks
        if self.return_sequences:
            raise NotImplementedError(
                "Returning sequences is not implemented for RNNs. "
                "Please set return_sequences=False."
            )
        else:
            self.layer_dict.update({"output": head})
            self.sub_layer_names.append("output")

    def initialize_layers(self):
        """
        Initialize dropout, rnn block, and output layers.
        """
        super(RecurrentNeuralNetwork, self).initialize_layers()

        # Create the RNN layer (use PyTorch built-in)
        rnn_class = {
            "rnn": nn.RNN,
            "lstm": nn.LSTM,
            "gru": nn.GRU,
        }[self.rnn_type]
        self.layer_dict.update(
            {
                "rnn_block": rnn_class(
                    input_size=self.num_features,
                    hidden_size=self.hidden_size,
                    num_layers=self.num_layers,
                    batch_first=True,
                    dropout=self.dropout if self.num_layers > 1 else 0.0,
                    bidirectional=False,  # Needs to be False for forecasting
                )
            }
        )
        self.sub_layer_names.append("rnn_block")

        if not self.return_sequences:
            if self.head_type == "attn":
                self.layer_dict["pool_head"] = AttentionPool(self.hidden_size)
            else:
                self.layer_dict["pool_head"] = None

    def _select_last_timestep(self, rnn_out, lengths):
        """
        Select the output from the last time step for each sequence in the batch.

        Used when attention is not selected. Classic many-to-one.

        Args:
            rnn_out (torch.Tensor): The output from the RNN layer.
            lengths (torch.Tensor): The lengths of the sequences in the batch.
        """
        batch_size, time_steps, features = rnn_out.size()
        # Build an index that selects the last valid time step for each sequence
        idx = (lengths - 1).view(-1, 1).expand(batch_size, features).unsqueeze(1)
        return rnn_out.gather(1, idx).squeeze(1)

    def _compute_mean_timestep(self, rnn_out, lengths):
        """
        Compute the mean over valid time steps for each sequence in the batch.

        Used when head_type is 'mean'. Useful for some tasks.

        Args:
            rnn_out (torch.Tensor): The output from the RNN layer.
            lengths (torch.Tensor): The lengths of the sequences in the batch.
        """
        batch_size, time_steps, features = rnn_out.size()
        mask = (
            torch.arange(time_steps, device=rnn_out.device)
            .unsqueeze(0)
            .expand(batch_size, time_steps)
            < lengths.unsqueeze(1)
        ).float()
        summed = (rnn_out * mask.unsqueeze(-1)).sum(dim=1)
        counts = mask.sum(dim=1).clamp_min(1.0).unsqueeze(-1)
        return summed / counts

    def _compute_max_timestep(self, rnn_out, lengths):
        """
        Compute the max over valid time steps for each sequence in the batch.

        Used when head_type is 'max'. Useful for some tasks.

        Args:
            rnn_out (torch.Tensor): The output from the RNN layer.
            lengths (torch.Tensor): The lengths of the sequences in the batch.
        """
        batch_size, time_steps, features = rnn_out.size()
        mask = torch.arange(time_steps, device=rnn_out.device).unsqueeze(0).expand(
            batch_size, time_steps
        ) < lengths.unsqueeze(1)
        masked_rnn_out = rnn_out.masked_fill(~mask.unsqueeze(-1), float("-inf"))
        return masked_rnn_out.max(dim=1).values

    def _random_suffix_crop(self, x, y, lengths, min_length=32):
        """
        Randomly crop sequences from the end with varying lengths.

        This function is used for data augmentation during training, which can help
        improve model robustness.

        Args:
            x (torch.Tensor): Input tensor of shape [batch_size, seq_length,
                feature_dim].
            y (torch.Tensor): Corresponding labels tensor of shape [batch_size, ...].
            lengths (torch.Tensor): Tensor of shape [batch_size] with the original
                lengths of each sequence in the batch.
            min_length (int, optional): Minimum length for the random crop. Defaults
                to 32.
        """
        batch_size, seq_length, feature_dim = x.shape
        new_lengths = torch.randint(
            low=min_length, high=seq_length + 1, size=(batch_size,)
        )

        # Build per-sample crops ending at seq_length
        x_out = torch.zeros_like(x)
        for i, crop_length in enumerate(new_lengths):
            x_out[i, :crop_length] = x[i, seq_length - crop_length : seq_length]
        return x_out, y, new_lengths

    def forward(self, inputs: torch.Tensor, lengths: Optional[torch.Tensor] = None):
        """
        Perform the forward pass of the RNN.

        If lengths are provided, the input sequences will be packed and unpacked to
        handle variable-length sequences. Performs optional pooling based on the
        head_type setting.

        Args:
            inputs (torch.Tensor): The input data.
            lengths (torch.Tensor, optional): The lengths of the sequences in the
                batch.
        """
        # Apply input dropout
        x = (
            self.layer_dict["input_dropout"](inputs)
            if self.input_dropout > 0
            else inputs
        )

        # Pack the sequences if lengths are provided
        if lengths is not None:
            x = pack_padded_sequence(
                x, lengths.cpu(), batch_first=True, enforce_sorted=False
            )

        # Apply RNN block
        rnn_out, _ = self.layer_dict["rnn_block"](x)

        # Unpack the sequences if they were packed
        if lengths is not None:
            rnn_out, _ = pad_packed_sequence(
                rnn_out, batch_first=True, total_length=inputs.size(1)
            )

        # TODO: Implement return sequences as an option for many-to-many tasks
        if self.return_sequences:
            raise NotImplementedError(
                "Returning sequences is not implemented for RNNs. "
                "Please set return_sequences=False."
            )

        # Pool the RNN outputs into a single feature vector per sequence
        if self.head_type == "attn":
            feat = self.layer_dict["pool_head"](rnn_out, lengths=lengths)
        elif self.head_type == "mean":
            if lengths is None:
                feat = rnn_out.mean(dim=1)
            else:
                feat = self._compute_mean_timestep(rnn_out, lengths)
        elif self.head_type == "max":
            if lengths is None:
                feat = rnn_out.max(dim=1).values
            else:
                feat = self._compute_max_timestep(rnn_out, lengths)
        else:
            # Sequence-to-one style where we take the last valid time step
            feat = (
                self._select_last_timestep(rnn_out, lengths)
                if lengths is not None
                else rnn_out[:, -1, :]
            )

        # Apply the output layer(s) and return
        return self.layer_dict["output"](feat)

    def step(self, dataloader, optimizer, criterion, device="cpu", training=True):
        """
        Perform a single step either in training or validation mode.
        """
        self.train() if training else self.eval()

        # Use torch.no_grad() only if not training
        context_manager = torch.no_grad() if not training else nullcontext()

        running_loss = 0.0
        with context_manager:
            for batch in dataloader:
                if len(batch) == 3:
                    x_in, y_in, lengths = batch
                else:
                    x_in, y_in = batch
                    lengths = None

                # If training, perform random suffix cropping for data augmentation
                if training and lengths is not None and x_in.size(1) >= 32:
                    x_in, y_in, lengths = self._random_suffix_crop(
                        x_in, y_in, lengths, min_length=32
                    )

                # Move data (and the possibly cropped lengths) to device
                x_in, y_in = x_in.to(device), y_in.to(device)
                if lengths is not None:
                    lengths = lengths.to(device)

                # Forward pass
                pred = self(x_in, lengths=lengths)

                if not self.return_sequences:
                    # Standard loss calculation for sequence-to-one
                    loss = criterion(pred, y_in)
                else:
                    # TODO: Implement loss calculation for sequence-to-sequence
                    raise NotImplementedError(
                        "Loss calculation for sequence-to-sequence is not implemented."
                    )

                if training:
                    # Add L1 and L2 regularization if present
                    if self.l1_reg > 0:
                        loss += self.l1_regularization(lambda_l1=self.l1_reg)
                    if self.l2_reg > 0:
                        loss += self.l2_regularization(lambda_l2=self.l2_reg)

                    # Zero the parameter gradients
                    optimizer.zero_grad()
                    # Backward pass
                    loss.backward()
                    # Clip gradients
                    torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)
                    # Optimize
                    optimizer.step()

                # Accumulate running loss
                running_loss += loss.item()

        # Normalize loss
        running_loss /= len(dataloader)
        return running_loss

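Example (a minimal sketch, not part of the module): training the RecurrentNeuralNetwork on padded, variable-length sequences. Batches of (inputs, targets, lengths) trigger packing in forward() and random suffix cropping in step(); all shapes and hyperparameters are assumptions.

import torch
from torch.utils.data import DataLoader, TensorDataset

from ptmelt.models import RecurrentNeuralNetwork

# Padded sequences: 128 samples, 48 time steps, 4 features (assumed for illustration).
n_samples, seq_len, n_features = 128, 48, 4
x = torch.randn(n_samples, seq_len, n_features)
y = torch.randn(n_samples, 1)
lengths = torch.randint(low=32, high=seq_len + 1, size=(n_samples,))
dl = DataLoader(TensorDataset(x, y, lengths), batch_size=32, shuffle=True)

rnn = RecurrentNeuralNetwork(
    num_features=n_features, num_outputs=1, width=64, depth=2,
    rnn_type="gru", head_type="mean",
)
rnn.build()

optimizer = rnn.get_optimizer("adamw", lr=1e-3)
criterion = rnn.get_loss_fn("mse")
train_loss = rnn.step(dl, optimizer, criterion, device="cpu", training=True)
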