from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal, kl_divergence
class MELTBayesianDenseFlipOut(nn.Module):
    """
    Custom Bayesian dense (linear) layer with Flipout perturbations for PT-MELT.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        prior_mean: float = 0.0,
        prior_std: float = 10.0,
        perturbation_type: str = "additive",
        seed: Optional[int] = None,
    ):
        """
        Initialize the Bayesian layer using a dense Flipout-type approach.

        Implementation based on the paper:
        "Flipout: Efficient Pseudo-Independent Weight Perturbations on Mini-Batches"
        by Wen et al. (2018). https://doi.org/10.48550/arXiv.1803.04386

        Perturbations are of the form W = W_mu + delta_W, where delta_W depends on
        the perturbation type:

        - Additive: W = W_mu + W_sigma * epsilon
        - Multiplicative: W = W_mu * (1 + W_sigma * epsilon)

        Args:
            in_features (int): Number of input features.
            out_features (int): Number of output features.
            prior_mean (float, optional): Mean of the Gaussian prior. Defaults to 0.0.
            prior_std (float, optional): Standard deviation of the Gaussian prior.
                Defaults to 10.0.
            perturbation_type (str, optional): Either "additive" or "multiplicative".
                Defaults to "additive".
            seed (int, optional): Random seed for parameter initialization.
                Defaults to None.
        """
        super(MELTBayesianDenseFlipOut, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.prior_mean = prior_mean
        self.prior_std = prior_std
        self.perturbation_type = perturbation_type
        self.seed = seed

        # Learnable posterior parameters (means and pre-softplus standard deviations);
        # their values are set by the initializers below.
        self.weight_mu = nn.Parameter(torch.zeros(out_features, in_features))
        self.weight_rho = nn.Parameter(torch.zeros(out_features, in_features))
        self.bias_mu = nn.Parameter(torch.zeros(out_features))
        self.bias_rho = nn.Parameter(torch.zeros(out_features))

        # Initialize the parameters.
        if self.seed is not None:
            torch.manual_seed(self.seed)
        nn.init.xavier_uniform_(self.weight_mu)
        nn.init.zeros_(self.bias_mu)
        nn.init.constant_(self.weight_rho, -3.0)
        nn.init.constant_(self.bias_rho, -3.0)

        # Define the prior distribution.
        self.prior = Normal(self.prior_mean, self.prior_std)
        self.posterior_weight = None
        self.posterior_bias = None
    def forward(self, input: torch.Tensor):
        """
        Perform the forward pass of the Bayesian linear layer.

        Flipout weights are perturbed as W = W_mu + delta_W, where delta_W has a
        component shared across the entire mini-batch and a sign-flip component
        that is unique to each input sample.
        """
        batch_size = input.size(0)
        # Convert rho to sigma via softplus to keep the standard deviations positive.
        weight_sigma = F.softplus(self.weight_rho)
        bias_sigma = F.softplus(self.bias_rho)
        # Sample the base noise for weights (shared) and biases (per sample).
        weight_epsilon = torch.randn_like(self.weight_mu)
        bias_epsilon = torch.randn(batch_size, self.out_features, device=input.device)
        # delta_W is shared across the mini-batch; delta_b is unique to each sample.
        delta_W = weight_sigma * weight_epsilon
        delta_b = bias_sigma * bias_epsilon
        # Random sign vectors make delta_W pseudo-independent per sample (Flipout).
        row_sign = torch.sign(
            torch.randn(batch_size, self.out_features, device=input.device)
        )
        col_sign = torch.sign(
            torch.randn(batch_size, self.in_features, device=input.device)
        )
        # Outer product of the sign vectors gives a per-sample rank-one sign flip.
        pert_matrix = row_sign.unsqueeze(2) * col_sign.unsqueeze(1)
        # Compute the perturbed weights.
        if self.perturbation_type == "additive":
            perturbed_weights = self.weight_mu + delta_W * pert_matrix
        elif self.perturbation_type == "multiplicative":
            perturbed_weights = self.weight_mu * (1 + delta_W * pert_matrix)
        else:
            raise ValueError(
                f"Unknown perturbation_type: {self.perturbation_type!r}"
            )
        perturbed_bias = self.bias_mu + delta_b
        # Batched matrix-vector product between the per-sample weights and inputs.
        output = torch.einsum("bij,bj->bi", perturbed_weights, input) + perturbed_bias
        return output
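
    # Note: the sketch below is illustrative and not part of the original module. For
    # the additive case, the same output could be computed without materializing the
    # (batch, out, in) tensor of perturbed weights, using the matmul-style Flipout
    # formulation of Wen et al. (2018):
    #     output = input @ self.weight_mu.t() \
    #         + ((input * col_sign) @ delta_W.t()) * row_sign \
    #         + perturbed_bias
    # This trades the per-sample weight tensor for two additional matrix products.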
    def _kl_divergence(self):
        """Compute the KL divergence between the posterior and prior distributions."""
        posterior_weight = Normal(self.weight_mu, F.softplus(self.weight_rho))
        posterior_bias = Normal(self.bias_mu, F.softplus(self.bias_rho))
        # Compute KL divergence for weights and biases against the shared prior.
        kl_weights = kl_divergence(posterior_weight, self.prior).sum()
        kl_biases = kl_divergence(posterior_bias, self.prior).sum()
        return kl_weights + kl_biases
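

# Illustrative usage sketch, not part of the original module: one training step that
# combines a data loss with the layer's KL divergence, as is typical for variational
# Bayesian layers. The in/out sizes, batch size, and `kl_weight` scaling below are
# arbitrary placeholder choices.
def _example_flipout_step(kl_weight: float = 1e-3) -> torch.Tensor:
    """Run one illustrative training step with MELTBayesianDenseFlipOut."""
    layer = MELTBayesianDenseFlipOut(in_features=4, out_features=2, seed=0)
    x = torch.randn(8, 4)
    y = torch.randn(8, 2)
    pred = layer(x)  # Flipout forward pass with per-sample sign-flipped perturbations.
    loss = F.mse_loss(pred, y) + kl_weight * layer._kl_divergence()
    loss.backward()  # Gradients flow to weight_mu/weight_rho and bias_mu/bias_rho.
    return loss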
class MELTBatchNorm(nn.Module):
    """
    Custom Batch Normalization Layer for PT-MELT.

    Supports different types of moving averages for the batch norm statistics.

    Args:
        num_features (int): Number of features in the input tensor.
        eps (float, optional): Small value to avoid division by zero.
            Defaults to 1e-5.
        momentum (float, optional): Momentum for moving average. Defaults to 0.1.
        affine (bool, optional): Apply affine transformation. Defaults to True.
        track_running_stats (bool, optional): Track running statistics. Defaults to
            True.
        average_type (str, optional): Type of moving average. Defaults to "ema".
    """

    def __init__(
        self,
        num_features: int,
        eps: Optional[float] = 1e-5,
        momentum: Optional[float] = 0.1,
        affine: Optional[bool] = True,
        track_running_stats: Optional[bool] = True,
        average_type: Optional[str] = "ema",
    ):
        # TODO: Check that all features of PyTorch BatchNorm are implemented.
        super(MELTBatchNorm, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        self.average_type = average_type

        # Initialize learnable affine parameters.
        if self.affine:
            self.weight = nn.Parameter(torch.empty(num_features))
            self.bias = nn.Parameter(torch.empty(num_features))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

        # Initialize running statistics.
        if self.track_running_stats:
            self.register_buffer("running_mean", torch.zeros(num_features))
            self.register_buffer("running_var", torch.ones(num_features))
            self.register_buffer(
                "num_batches_tracked", torch.tensor(0, dtype=torch.long)
            )
        else:
            self.register_parameter("running_mean", None)
            self.register_parameter("running_var", None)
            self.register_parameter("num_batches_tracked", None)

        self.reset_parameters()
    def reset_parameters(self):
        """Reset the parameters and running statistics of the layer."""
        if self.track_running_stats:
            self.running_mean.zero_()
            self.running_var.fill_(1)
            self.num_batches_tracked.zero_()
        if self.affine:
            nn.init.ones_(self.weight)
            nn.init.zeros_(self.bias)
    def forward(self, input: torch.Tensor):
        """
        Perform the forward pass of the batch normalization layer.

        Args:
            input (torch.Tensor): Input tensor to be normalized.
        """
        # Use batch statistics in training mode, running statistics otherwise.
        if self.training:
            mean = input.mean(dim=0)
            var = input.var(dim=0, unbiased=False)
        else:
            mean = self.running_mean
            var = self.running_var

        # Update running statistics (training mode only, detached from the graph).
        if self.training and self.track_running_stats:
            with torch.no_grad():
                if self.average_type == "ema":
                    self.running_mean = (
                        1 - self.momentum
                    ) * self.running_mean + self.momentum * mean
                    self.running_var = (
                        1 - self.momentum
                    ) * self.running_var + self.momentum * var
                elif self.average_type == "simple":
                    self.running_mean = mean.detach().clone()
                    self.running_var = var.detach().clone()
                self.num_batches_tracked += 1

        # Normalize
        if self.training:
            input = (input - mean) / (var + self.eps).sqrt()
        else:
            input = (input - self.running_mean) / (self.running_var + self.eps).sqrt()

        # Scale and shift
        if self.affine:
            input = input * self.weight + self.bias
        return input
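

# Illustrative usage sketch, not part of the original module: MELTBatchNorm normalizes
# with batch statistics (and updates its running statistics) in training mode, and
# reuses the stored running statistics in eval mode. Shapes below are placeholders.
def _example_batchnorm_modes():
    """Compare training-mode and eval-mode outputs of MELTBatchNorm."""
    bn = MELTBatchNorm(num_features=3, average_type="ema")
    x = torch.randn(16, 3)
    bn.train()
    out_train = bn(x)  # Uses the batch mean/var and updates running_mean/running_var.
    bn.eval()
    out_eval = bn(x)  # Uses the stored running_mean/running_var.
    return out_train, out_eval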
class MELTBatchRenorm(MELTBatchNorm):
    """
    Custom Batch Renormalization Layer for PT-MELT.

    Supports different types of moving averages for the batch norm statistics.

    Args:
        num_features (int): Number of features in the input tensor.
        eps (float, optional): Small value to avoid division by zero. Defaults to
            1e-5.
        momentum (float, optional): Momentum for moving average. Defaults to 0.1.
        affine (bool, optional): Apply affine transformation. Defaults to True.
        track_running_stats (bool, optional): Track running statistics. Defaults to
            True.
        average_type (str, optional): Type of moving average. Defaults to "ema".
        rmax (float, optional): Maximum value for r. Defaults to 1.0.
        dmax (float, optional): Maximum value for d. Defaults to 0.0.
    """

    def __init__(
        self,
        num_features: int,
        eps: Optional[float] = 1e-5,
        momentum: Optional[float] = 0.1,
        affine: Optional[bool] = True,
        track_running_stats: Optional[bool] = True,
        average_type: Optional[str] = "ema",
        rmax: Optional[float] = 1.0,
        dmax: Optional[float] = 0.0,
    ):
        # TODO: Verify accuracy of renorm implementation.
        super().__init__(
            num_features, eps, momentum, affine, track_running_stats, average_type
        )
        # Clipping bounds and the most recent correction factors, stored as buffers.
        self.register_buffer("rmax", torch.tensor(rmax))
        self.register_buffer("dmax", torch.tensor(dmax))
        self.register_buffer("r", torch.ones(1))
        self.register_buffer("d", torch.zeros(1))
    def forward(self, input: torch.Tensor):
        """Perform the forward pass of the batch renormalization layer."""
        # Calculate batch statistics and renorm corrections in training mode.
        if self.training:
            mean = input.mean(dim=0)
            var = input.var(dim=0, unbiased=False)
            std = torch.sqrt(var + self.eps)
            # r and d correct the batch statistics toward the running statistics and
            # are treated as constants (stop-gradient) in the renormalization scheme.
            r = std.detach() / torch.sqrt(self.running_var + self.eps)
            r = torch.clamp(r, 1 / self.rmax, self.rmax)
            d = (mean.detach() - self.running_mean) / torch.sqrt(
                self.running_var + self.eps
            )
            d = torch.clamp(d, -self.dmax, self.dmax)
            self.r = r
            self.d = d
        else:
            mean = self.running_mean
            var = self.running_var

        # Update running statistics (training mode only, detached from the graph).
        if self.training and self.track_running_stats:
            with torch.no_grad():
                if self.average_type == "ema":
                    self.running_mean = (
                        1 - self.momentum
                    ) * self.running_mean + self.momentum * mean
                    self.running_var = (
                        1 - self.momentum
                    ) * self.running_var + self.momentum * var
                elif self.average_type == "simple":
                    self.running_mean = mean.detach().clone()
                    self.running_var = var.detach().clone()
                self.num_batches_tracked += 1

        # Apply batch renormalization.
        if self.training:
            x_hat = (input - mean) * r / std + d
        else:
            x_hat = (input - self.running_mean) / torch.sqrt(
                self.running_var + self.eps
            )

        # Scale and shift.
        if self.affine:
            x_hat = x_hat * self.weight + self.bias
        return x_hat
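

# Illustrative usage sketch, not part of the original module: with the default
# rmax=1.0 and dmax=0.0, the correction factors are clamped to r=1 and d=0, so the
# layer reduces to standard batch normalization; relaxing the bounds (placeholder
# values below) activates the renormalization terms.
def _example_batchrenorm_step():
    """Run MELTBatchRenorm in training mode with relaxed clipping bounds."""
    brn = MELTBatchRenorm(num_features=3, rmax=3.0, dmax=5.0)
    x = torch.randn(16, 3)
    brn.train()
    return brn(x)  # r and d are computed against the running statistics, then clamped.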