import math

import torch
def safe_exp(x):
    """Exponentiate ``x`` after clamping it to [-20, 20] to prevent overflow."""
    x = torch.clamp(x, min=-20, max=20)
    return torch.exp(x)
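

# Illustrative check (ours, not from the original module): plain torch.exp
# overflows to inf once its float32 input exceeds ~88, whereas safe_exp
# saturates at exp(20) and stays finite:
#
#     >>> torch.exp(torch.tensor(500.0))
#     tensor(inf)
#     >>> safe_exp(torch.tensor(500.0))
#     tensor(4.8517e+08)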
class MixtureDensityLoss(torch.nn.Module):
    """
    Negative log-likelihood loss for a mixture density network with
    diagonal-covariance Gaussian components.

    ``y_pred`` packs, along dim 1: ``num_mixtures`` mixture coefficients
    (assumed already normalized, e.g. by a softmax), followed by
    ``num_mixtures * num_outputs`` means and then the same number of
    log-variances.

    Args:
        num_mixtures (int): Number of mixture components.
        num_outputs (int): Number of output dimensions.
    """
    def __init__(self, num_mixtures, num_outputs):
        super().__init__()
        self.num_mixtures = num_mixtures
        self.num_outputs = num_outputs
    def forward(self, y_pred, y_true):
        # NOTE: the argument order is reversed compared to Keras and TensorFlow
        # Split y_pred into mixture coefficients, means, and log-variances
        end_mixture = self.num_mixtures
        end_mean = end_mixture + self.num_mixtures * self.num_outputs
        end_log_var = end_mean + self.num_mixtures * self.num_outputs
        m_coeffs = y_pred[:, :end_mixture]
        mean_preds = y_pred[:, end_mixture:end_mean]
        log_var_preds = y_pred[:, end_mean:end_log_var]
        # Reshape to (batch, num_mixtures, num_outputs) so that y_true can be
        # broadcast across the mixture axis
        mean_preds = mean_preds.view(-1, self.num_mixtures, self.num_outputs)
        log_var_preds = log_var_preds.view(-1, self.num_mixtures, self.num_outputs)
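        # Worked shape example (illustrative numbers, not from the original):
        # with batch=32, num_mixtures=5, num_outputs=2, y_pred has width
        # 5 + 2 * 5 * 2 = 25; mean_preds and log_var_preds become (32, 5, 2),
        # and y_true.unsqueeze(1) below is (32, 1, 2).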
        # Per-dimension terms of each component's log Gaussian density
        const_term = -0.5 * self.num_outputs * math.log(2 * math.pi)
        inv_sigma_log = -0.5 * log_var_preds
        exp_term = (
            -0.5
            * torch.square(y_true.unsqueeze(1) - mean_preds)
            / safe_exp(log_var_preds)
        )
        # Sum over the output axis to form the joint log density of each
        # diagonal-covariance Gaussian component: shape (batch, num_mixtures)
        log_probs = const_term + torch.sum(inv_sigma_log + exp_term, dim=-1)
        # Weight each component by its mixture coefficient in log space;
        # m_coeffs must be strictly positive (e.g. a softmax output)
        weighted_log_probs = log_probs + torch.log(m_coeffs)
        # Log-likelihood of each sample under the full mixture
        log_sum_exp = torch.logsumexp(weighted_log_probs, dim=1)
        # Average the per-sample log-likelihood over the batch
        log_likelihood = torch.mean(log_sum_exp)
        # Return the negative log-likelihood
        return -log_likelihood
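

# Usage sketch (a minimal example of ours, assuming a plain linear head; the
# layer sizes and names below are illustrative, not part of the original
# module). The head emits K + 2 * K * D values per sample, and only the first
# K (the mixture coefficients) are softmax-normalized, since the loss takes
# torch.log(m_coeffs) directly.
if __name__ == "__main__":
    batch, in_dim, K, D = 32, 16, 5, 2
    head = torch.nn.Linear(in_dim, K + 2 * K * D)
    x = torch.randn(batch, in_dim)
    raw = head(x)
    coeffs = torch.softmax(raw[:, :K], dim=1)        # normalize coefficients only
    y_pred = torch.cat([coeffs, raw[:, K:]], dim=1)  # (batch, K + 2*K*D)
    y_true = torch.randn(batch, D)
    loss_fn = MixtureDensityLoss(num_mixtures=K, num_outputs=D)
    loss = loss_fn(y_pred, y_true)
    loss.backward()
    print(loss.item())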