Source code for ptmelt.losses

import torch
import torch.nn.functional as F


def safe_exp(x):
    """Exponential clamped to prevent overflow.

    Clips the input to the range [-10, 10] before applying ``torch.exp``.
    """
    x = torch.clamp(x, min=-10, max=10)
    return torch.exp(x)
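
# Illustrative sanity check, assuming float32 inputs: torch.exp overflows to
# inf for inputs above ~88.7, while safe_exp caps the result at exp(10).
#
#     torch.exp(torch.tensor(100.0))   # tensor(inf)
#     safe_exp(torch.tensor(100.0))    # ~= exp(10) = 22026.47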
class MixtureDensityLoss(torch.nn.Module):
    """Custom negative log-likelihood loss for a Mixture Density Network (MDN).

    Args:
        num_mixtures (int): Number of mixture components.
        num_outputs (int): Number of output dimensions.
    """

    def __init__(self, num_mixtures, num_outputs):
        super().__init__()
        self.num_mixtures = num_mixtures
        self.num_outputs = num_outputs
    def forward(self, y_pred, y_true):
        # Slice boundaries for the mixture coefficients, means, and
        # log-variances packed into y_pred.
        end_mixture = self.num_mixtures
        end_mean = end_mixture + self.num_mixtures * self.num_outputs
        end_log_var = end_mean + self.num_mixtures * self.num_outputs

        # Mixture coefficients -> (batch_size, num_mixtures)
        m_coeffs = y_pred[:, :end_mixture]
        # Means -> (batch_size, num_mixtures * num_outputs)
        mean_preds = y_pred[:, end_mixture:end_mean]
        # Log-variances -> (batch_size, num_mixtures * num_outputs)
        log_var_preds = y_pred[:, end_mean:end_log_var]

        # Reshape means and log-variances -> (batch_size, num_mixtures, num_outputs)
        mean_preds = mean_preds.view(-1, self.num_mixtures, self.num_outputs)
        log_var_preds = log_var_preds.view(-1, self.num_mixtures, self.num_outputs)

        # Ensure mixture coefficients sum to 1
        m_coeffs = F.softmax(m_coeffs, dim=1)

        # Convert log-variance to variance (clamped to avoid overflow)
        var_preds = safe_exp(log_var_preds)

        # Difference term -> (batch_size, num_mixtures, num_outputs)
        diff = y_true.unsqueeze(1) - mean_preds

        # Per-dimension Gaussian log-density terms. The normalization constant
        # is -0.5 * log(2 * pi) for each output dimension; it is summed over
        # the output dimensions below, so it is not multiplied by num_outputs here.
        const_term = -0.5 * torch.log(torch.tensor(2 * torch.pi))
        var_log_term = -0.5 * log_var_preds
        exp_term = -0.5 * torch.square(diff) / var_preds
        log_probs = const_term + var_log_term + exp_term

        # Sum over output dimensions to get per-mixture log probabilities
        # -> (batch_size, num_mixtures)
        log_probs = log_probs.sum(dim=2)

        # Mixture-weighted log probabilities; eps prevents log(0)
        weighted_log_probs = log_probs + torch.log(m_coeffs + 1e-8)

        # Log-sum-exp over mixtures for numerical stability -> (batch_size,)
        log_sum_exp = torch.logsumexp(weighted_log_probs, dim=1)

        # Negative log-likelihood averaged over the batch -> scalar
        loss = -torch.mean(log_sum_exp)
        return loss
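
# Minimal usage sketch, assuming an MDN head whose final layer emits a flat
# vector of num_mixtures + 2 * num_mixtures * num_outputs values per sample,
# laid out as [mixture logits | means | log-variances] to match the slicing
# in forward(). The random tensors stand in for real model outputs and targets.
if __name__ == "__main__":
    num_mixtures, num_outputs, batch_size = 3, 2, 8
    criterion = MixtureDensityLoss(num_mixtures, num_outputs)

    pred_dim = num_mixtures + 2 * num_mixtures * num_outputs  # 3 + 12 = 15
    y_pred = torch.randn(batch_size, pred_dim, requires_grad=True)
    y_true = torch.randn(batch_size, num_outputs)

    loss = criterion(y_pred, y_true)  # scalar negative log-likelihood
    loss.backward()                   # gradients flow back to y_pred
    print(f"MDN NLL: {loss.item():.4f}")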