Source code for ptmelt.losses

import torch
import torch.nn.functional as F


def safe_exp(x):
    """Exponential with the input clamped to [-10, 10] to prevent overflow."""
    x = torch.clamp(x, min=-10, max=10)
    return torch.exp(x)
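
# Illustrative example (not part of ptmelt): extreme inputs are clamped before
# exponentiation, so the result stays finite. Values shown are approximate.
#
#   >>> safe_exp(torch.tensor([-100.0, 0.0, 100.0]))
#   tensor([4.5400e-05, 1.0000e+00, 2.2026e+04])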


class MixtureDensityLoss(torch.nn.Module):
    """Custom loss function for a Mixture Density Network (MDN).

    Args:
        num_mixtures (int): Number of mixture components.
        num_outputs (int): Number of output dimensions.
        mse_weight (float): Weight of the auxiliary MSE term between the
            mixture mean and the target. Set to 0.0 to disable.
        reduction (str): Reduction applied to the per-sample loss; one of
            'mean', 'sum', or 'none'.
    """

    def __init__(self, num_mixtures, num_outputs, mse_weight=1.0, reduction="mean"):
        super().__init__()
        self.num_mixtures = num_mixtures
        self.num_outputs = num_outputs
        self.mse_weight = mse_weight
        assert reduction in (
            "mean",
            "sum",
            "none",
        ), "Reduction must be 'mean', 'sum', or 'none'"
        self.reduction = reduction

    def forward(self, y_pred, y_true):
        # Slice boundaries for mixture coefficients, means, and log-variances
        end_mixture = self.num_mixtures
        end_mean = end_mixture + self.num_mixtures * self.num_outputs
        end_log_var = end_mean + self.num_mixtures * self.num_outputs

        # Mixture coefficients -> (batch_size, num_mixtures)
        m_coeffs = y_pred[:, :end_mixture]
        # Means -> (batch_size, num_mixtures * num_outputs)
        mean_preds = y_pred[:, end_mixture:end_mean]
        # Log-variances -> (batch_size, num_mixtures * num_outputs)
        log_var_preds = y_pred[:, end_mean:end_log_var]

        # Reshape means and log-variances -> (batch_size, num_mixtures, num_outputs)
        mean_preds = mean_preds.view(-1, self.num_mixtures, self.num_outputs)
        log_var_preds = log_var_preds.view(-1, self.num_mixtures, self.num_outputs)
        log_var_preds = torch.clamp(log_var_preds, min=-10.0, max=10.0)

        # Ensure mixture coefficients sum to 1
        m_coeffs = F.softmax(m_coeffs, dim=1)

        # Convert log-variance to variance
        var_preds = safe_exp(log_var_preds)

        # Difference term -> (batch_size, num_mixtures, num_outputs)
        diff = y_true.unsqueeze(1) - mean_preds

        # Per-dimension Gaussian log-probability terms; the constant is applied
        # once per output dimension and accumulated by the sum over dim=2 below
        const_term = -0.5 * torch.log(torch.tensor(2 * torch.pi))
        var_log_term = -0.5 * log_var_preds
        exp_term = -0.5 * torch.square(diff) / torch.clamp(var_preds, min=1e-10)
        log_probs = const_term + var_log_term + exp_term

        # Sum over output dimensions to get log probabilities for each mixture
        # -> (batch_size, num_mixtures)
        log_probs = log_probs.sum(dim=2)

        # Mixture-weighted log probabilities; the clamp prevents log(0)
        weighted_log_probs = log_probs + torch.log(torch.clamp(m_coeffs, min=1e-8))

        # Log-Sum-Exp trick for numerical stability -> (batch_size,)
        log_sum_exp = torch.logsumexp(weighted_log_probs, dim=1)

        # Per-sample negative log-likelihood -> (batch_size,)
        loss = -log_sum_exp

        # Optionally add an MSE penalty between the mixture mean and the target
        if self.mse_weight > 0.0:
            mix_mean = (m_coeffs.unsqueeze(-1) * mean_preds).sum(dim=1)
            mse_loss = F.mse_loss(mix_mean, y_true, reduction="none").mean(dim=-1)
            loss = loss + self.mse_weight * mse_loss

        # Apply the requested reduction
        if self.reduction == "mean":
            loss = torch.mean(loss)
        elif self.reduction == "sum":
            loss = torch.sum(loss)
        # else: 'none' returns the per-sample loss tensor

        return loss
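

# Minimal usage sketch (illustrative only, not part of the ptmelt API): the
# network head feeding this loss must emit num_mixtures + 2 * num_mixtures *
# num_outputs values per sample, laid out as [coefficients | means |
# log-variances]. Shapes and values below are made up for demonstration.
if __name__ == "__main__":
    batch_size, num_mixtures, num_outputs = 4, 3, 2
    criterion = MixtureDensityLoss(num_mixtures, num_outputs, mse_weight=1.0)

    # Fake predictions with the expected layout and a matching target
    y_pred = torch.randn(batch_size, num_mixtures + 2 * num_mixtures * num_outputs)
    y_true = torch.randn(batch_size, num_outputs)

    loss = criterion(y_pred, y_true)
    print(loss)  # scalar tensor (reduction='mean' by default)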