import random
import warnings
from typing import Any, List, Optional
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.keras import Model
from tensorflow.keras.layers import Activation, Add, BatchNormalization, Dense, Dropout
from tensorflow.keras.regularizers import Regularizer
from tfmelt.nn_utils import get_initializer
def get_kernel_divergence_fn(num_points):
"""
Get the kernel divergence function.
Args:
num_points (int): Number of points in kernel divergence.
Returns:
Callable: Kernel divergence function.
"""
def kernel_divergence_fn(q, p, _):
return tfp.distributions.kl_divergence(q, p) / (num_points * 1.0)
return kernel_divergence_fn
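# Illustrative sketch (not part of the library API): the returned callable is meant to
# be passed as ``kernel_divergence_fn`` to a TFP variational layer so that the KL term
# is averaged over the training set. ``n_train`` is a placeholder for the number of
# training points.
#
#   kl_fn = get_kernel_divergence_fn(num_points=n_train)
#   layer = tfp.layers.DenseFlipout(64, kernel_divergence_fn=kl_fn)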
class MELTBlock(Model):
"""
Base class for a MELT block. Defines a block of dense layers with optional
activation, dropout, and batch normalization. Forms the building block for
various neural network blocks.
Args:
node_list (List[int]): Number of nodes in each dense layer. The length of
the list determines the number of layers.
activation (str, optional): Activation function. If None,
no activation is applied (linear). Defaults to
"relu".
dropout (float, optional): Dropout rate (0-1). Defaults to None.
batch_norm (bool, optional): Apply batch normalization if True. Defaults
to False.
use_batch_renorm (bool, optional): Use batch renormalization. Defaults to False.
regularizer (Regularizer, optional): Kernel weights regularizer. Defaults
to None.
initializer (str, optional): String defining the kernel initializer. Defaults
to "glorot_uniform".
seed (int, optional): Random seed. Defaults to None.
**kwargs: Extra arguments passed to the base class.
"""
def __init__(
self,
node_list: List[int],
activation: Optional[str] = "relu",
dropout: Optional[float] = None,
batch_norm: Optional[bool] = False,
use_batch_renorm: Optional[bool] = False,
regularizer: Optional[Regularizer] = None,
initializer: Optional[str] = "glorot_uniform",
seed: Optional[int] = None,
**kwargs: Any,
):
super(MELTBlock, self).__init__(**kwargs)
self.node_list = node_list
self.activation = activation
self.dropout = dropout
self.batch_norm = batch_norm
self.use_batch_renorm = use_batch_renorm
self.regularizer = regularizer
self.seed = seed
# Get kernel initializer
if self.seed is None:
self.seed = random.randint(0, 2**32 - 1)
self.initializer = get_initializer(init_name=initializer, seed=self.seed)
# Number of layers in the block
self.num_layers = len(self.node_list)
        # Validate dropout value and treat None as "no dropout"
        if self.dropout is not None:
            assert 0 <= self.dropout < 1, "Dropout must be between 0 and 1"
        else:
            self.dropout = 0.0
# Activation layers
if self.activation:
self.activation_layers = [
Activation(self.activation, name=f"activation_{i}")
for i in range(self.num_layers)
]
# Optional dropout and batch norm layers
if self.dropout > 0:
self.dropout_layers = [
Dropout(self.dropout, name=f"dropout_{i}")
for i in range(self.num_layers)
]
if self.batch_norm:
self.batch_norm_layers = [
BatchNormalization(
renorm=self.use_batch_renorm,
renorm_clipping=(
{
"rmax": 3,
"rmin": 1 / 3,
"dmax": 5,
}
if self.use_batch_renorm
else None
),
name=f"batch_norm_{i}",
)
for i in range(self.num_layers)
]
# Create config dictionary for serialization
self.config = {
"node_list": self.node_list,
"activation": self.activation,
"dropout": self.dropout,
"batch_norm": self.batch_norm,
"regularizer": self.regularizer,
"initializer": self.initializer,
"seed": self.seed,
}
def get_config(self):
"""Get the config dictionary"""
config = super(MELTBlock, self).get_config()
config.update(self.config)
return config
@classmethod
def from_config(cls, config):
"""Create model from config dictionary"""
return cls(**config)
class DenseBlock(MELTBlock):
"""
A DenseBlock consists of multiple dense layers with optional activation, dropout,
and batch normalization.
Args:
**kwargs: Extra arguments passed to the base class.
Raises:
        AssertionError: If dropout is not within the range [0, 1).
"""
def __init__(
self,
**kwargs: Any,
):
super(DenseBlock, self).__init__(**kwargs)
# Initialize dense layers
self.dense_layers = [
Dense(
node,
activation=None,
kernel_regularizer=self.regularizer,
kernel_initializer=self.initializer,
name=f"dense_{i}",
)
for i, node in enumerate(self.node_list)
]
def call(self, inputs: tf.Tensor, training: bool = False):
"""Forward pass through the dense block."""
x = inputs
for i in range(self.num_layers):
# dense -> batch norm -> activation -> dropout
x = self.dense_layers[i](x, training=training)
x = (
self.batch_norm_layers[i](x, training=training)
if self.batch_norm
else x
)
x = self.activation_layers[i](x) if self.activation else x
x = self.dropout_layers[i](x, training=training) if self.dropout > 0 else x
return x
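# Illustrative usage sketch for DenseBlock (shapes and hyperparameters are example
# values only, not library defaults):
#
#   block = DenseBlock(
#       node_list=[64, 32], activation="relu", dropout=0.1, batch_norm=True
#   )
#   x = tf.random.normal((8, 16))
#   y = block(x, training=True)  # shape (8, 32): width of the last dense layer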
class ResidualBlock(MELTBlock):
"""
A ResidualBlock consists of multiple dense layers with optional activation, dropout,
and batch normalization. Residual connections are added after each block of layers.
Args:
layers_per_block (int, optional): Number of layers per residual block. Defaults
to 2.
        pre_activation (bool, optional): Apply the activation immediately after each
            dense layer (before batch norm and dropout) instead of after dropout.
            Defaults to False.
post_add_activation (bool, optional): Apply activation after adding residual
connection. Defaults to False.
**kwargs: Extra arguments passed to the base class.
"""
def __init__(
self,
layers_per_block: Optional[int] = 2,
pre_activation: Optional[bool] = False,
post_add_activation: Optional[bool] = False,
**kwargs: Any,
):
super(ResidualBlock, self).__init__(**kwargs)
self.layers_per_block = layers_per_block
self.pre_activation = pre_activation
self.post_add_activation = post_add_activation
# Warning if the number of layers is not divisible by layers_per_block
if self.num_layers % self.layers_per_block != 0:
warnings.warn(
f"Warning: Number of layers ({self.num_layers}) is not divisible by "
f"layers_per_block ({self.layers_per_block}), so the last block will "
f"have {self.num_layers % self.layers_per_block} layers."
)
# Initialize dense layers
self.dense_layers = [
Dense(
node,
activation=None,
kernel_regularizer=self.regularizer,
kernel_initializer=self.initializer,
name=f"dense_{i}",
)
for i, node in enumerate(self.node_list)
]
        # Initialize Add layers (one per residual block, rounding up so a partial
        # final block still gets a residual connection)
        self.add_layers = [
            Add(name=f"add_{i}")
            for i in range(
                (self.num_layers + self.layers_per_block - 1) // self.layers_per_block
            )
        ]
        # Optional activation applied after each residual addition (one per block,
        # counted the same way as the Add layers)
        if self.post_add_activation:
            self.post_add_activation_layers = [
                Activation(self.activation, name=f"post_add_activation_{i}")
                for i in range(
                    (self.num_layers + self.layers_per_block - 1)
                    // self.layers_per_block
                )
            ]
# Update config dictionary for serialization
self.config.update(
{
"layers_per_block": layers_per_block,
"pre_activation": pre_activation,
"post_add_activation": post_add_activation,
}
)
    def call(self, inputs: tf.Tensor, training: bool = False):
        """Forward pass through the residual block."""
        x = inputs
        for i in range(self.num_layers):
            # Capture the input at the start of each residual block
            if i % self.layers_per_block == 0:
                y = x
            # dense -> (pre-activation) -> batch norm -> dropout -> (post-activation)
            x = self.dense_layers[i](x, training=training)
            x = (
                self.activation_layers[i](x)
                if self.activation and self.pre_activation
                else x
            )
            x = (
                self.batch_norm_layers[i](x, training=training)
                if self.batch_norm
                else x
            )
            x = self.dropout_layers[i](x, training=training) if self.dropout > 0 else x
            x = (
                self.activation_layers[i](x)
                if self.activation and not self.pre_activation
                else x
            )
            # Add the residual connection at the end of each block
            if (i + 1) % self.layers_per_block == 0 or i == self.num_layers - 1:
                x = self.add_layers[i // self.layers_per_block]([x, y])
                x = (
                    self.post_add_activation_layers[i // self.layers_per_block](x)
                    if self.post_add_activation
                    else x
                )
        return x
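# Illustrative usage sketch for ResidualBlock. The Add layers require the tensors
# being summed to have matching widths, so this example keeps a constant width equal
# to the input width (all values are examples only):
#
#   block = ResidualBlock(node_list=[64, 64, 64, 64], layers_per_block=2)
#   x = tf.random.normal((8, 64))
#   y = block(x, training=False)  # shape (8, 64), skip connection every 2 layers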
class BayesianBlock(MELTBlock):
"""
A BayesianBlock consists of multiple Bayesian dense layers with optional activation,
dropout, and batch normalization. The layers are implemented using the Flipout
variational layer.
Args:
        num_points (int, optional): Number of training points used to scale the KL
            divergence. Defaults to 1.
use_batch_renorm (bool, optional): Use batch renormalization. Defaults to True.
**kwargs: Extra arguments passed to the base class.
"""
def __init__(
self,
num_points: Optional[int] = 1,
use_batch_renorm: Optional[bool] = True,
**kwargs: Any,
):
        # Forward use_batch_renorm to the base class so the batch norm layers are
        # built with the requested renormalization setting
        super(BayesianBlock, self).__init__(
            use_batch_renorm=use_batch_renorm, **kwargs
        )
        self.num_points = num_points
        # Create kernel divergence function
        kernel_divergence_fn = get_kernel_divergence_fn(self.num_points)
# Initialize Bayesian layers
self.bayesian_layers = [
tfp.layers.DenseFlipout(
node,
activation=None,
kernel_divergence_fn=kernel_divergence_fn,
activity_regularizer=self.regularizer,
name=f"bayesian_{i}",
)
for i, node in enumerate(self.node_list)
]
# Update config dictionary for serialization
self.config.update(
{
"num_points": num_points,
"use_batch_renorm": use_batch_renorm,
}
)
def call(self, inputs: tf.Tensor, training: bool = False):
"""Forward pass through the Bayesian block."""
x = inputs
for i in range(self.num_layers):
# bayesian -> batch norm -> activation -> dropout
x = self.bayesian_layers[i](x, training=training)
x = (
self.batch_norm_layers[i](x, training=training)
if self.batch_norm
else x
)
x = self.activation_layers[i](x) if self.activation else x
x = self.dropout_layers[i](x, training=training) if self.dropout > 0 else x
return x
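# Illustrative usage sketch for BayesianBlock. num_points would normally be set to the
# training-set size so the KL penalty is averaged per example; the Flipout layers
# sample weight perturbations, so repeated calls give slightly different outputs
# (values below are examples only):
#
#   block = BayesianBlock(node_list=[32, 32], num_points=10_000)
#   x = tf.random.normal((4, 8))
#   y1, y2 = block(x, training=True), block(x, training=True)  # y1 != y2 in general
#   kl_terms = block.losses  # per-layer scaled KL divergences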
class MixtureDensityOutput(Model):
"""
Output layer for a mixture density network.
Args:
num_mixtures (int): Number of mixture components.
num_outputs (int): Number of output nodes.
output_activation (str, optional): Activation function for the output layer.
Defaults to None.
        initializer (str, optional): Kernel initializer. Defaults to "glorot_uniform".
        seed (int, optional): Random seed for the initializer. Defaults to None.
        regularizer (Regularizer, optional): Kernel regularizer. Defaults to None.
**kwargs: Extra arguments passed to the base class.
"""
def __init__(
self,
num_mixtures: int,
num_outputs: int,
output_activation: Optional[str] = None,
initializer: Optional[str] = "glorot_uniform",
seed: Optional[int] = None,
regularizer: Optional[Regularizer] = None,
**kwargs,
):
super(MixtureDensityOutput, self).__init__(**kwargs)
self.num_mixtures = num_mixtures
self.num_outputs = num_outputs
self.output_activation = output_activation
self.seed = seed
self.regularizer = regularizer
# Get kernel initializer
if self.seed is None:
self.seed = random.randint(0, 2**32 - 1)
self.initializer = get_initializer(init_name=initializer, seed=self.seed)
        # Create config dictionary for serialization
        self.config = {
            "num_mixtures": self.num_mixtures,
            "num_outputs": self.num_outputs,
            "output_activation": self.output_activation,
            "initializer": initializer,  # store the name string, not the object
            "seed": self.seed,
            "regularizer": self.regularizer,
        }
self.mix_coeffs_layer = Dense(
self.num_mixtures,
activation="softmax",
kernel_initializer=self.initializer,
name="mix_coeffs",
)
self.mean_output_layer = Dense(
self.num_mixtures * self.num_outputs,
activation=self.output_activation,
kernel_initializer=self.initializer,
kernel_regularizer=self.regularizer,
name="mean_output",
)
self.log_var_output_layer = Dense(
self.num_mixtures * self.num_outputs,
activation=self.output_activation,
kernel_initializer=self.initializer,
kernel_regularizer=self.regularizer,
name="log_var_output",
)
def call(self, x, training=False):
"""Forward pass through the output layer."""
m_coeffs = self.mix_coeffs_layer(x, training=training)
mean_output = self.mean_output_layer(x, training=training)
log_var_output = self.log_var_output_layer(x, training=training)
return tf.concat([m_coeffs, mean_output, log_var_output], axis=-1)
def get_config(self):
"""Get the config dictionary"""
config = super(MixtureDensityOutput, self).get_config()
config.update(self.config)
return config
@classmethod
def from_config(cls, config):
"""Create model from config dictionary"""
return cls(**config)
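# Illustrative sketch of the MixtureDensityOutput layout: the concatenated output has
# num_mixtures + 2 * num_mixtures * num_outputs columns, which can be split back into
# mixture coefficients, means, and log-variances (values are examples only):
#
#   mdn = MixtureDensityOutput(num_mixtures=3, num_outputs=2)
#   out = mdn(tf.random.normal((8, 16)))  # shape (8, 3 + 2 * 3 * 2) = (8, 15)
#   coeffs = out[:, :3]      # softmax mixture weights
#   means = out[:, 3:9]      # component means
#   log_vars = out[:, 9:]    # component log-variances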
class DefaultOutput(Model):
"""
Default output layer with a single dense layer.
Args:
num_outputs (int): Number of output nodes.
output_activation (str, optional): Activation function for the output layer.
Defaults to None.
        initializer (str, optional): Kernel initializer. Defaults to "glorot_uniform".
        seed (int, optional): Random seed for the initializer. Defaults to None.
        regularizer (Regularizer, optional): Kernel regularizer. Defaults to None.
        bayesian (bool, optional): Use Bayesian layer if True. Defaults to False.
        num_points (int, optional): Number of training points used to scale the KL
            divergence when bayesian is True. Defaults to 1.
**kwargs: Extra arguments passed to the base class.
"""
def __init__(
self,
num_outputs,
output_activation: Optional[str] = None,
initializer: Optional[str] = "glorot_uniform",
seed: Optional[int] = None,
regularizer: Optional[Regularizer] = None,
bayesian: Optional[bool] = False,
num_points: Optional[int] = 1,
**kwargs,
):
super(DefaultOutput, self).__init__(**kwargs)
self.num_outputs = num_outputs
self.output_activation = output_activation
self.seed = seed
self.regularizer = regularizer
self.bayesian = bayesian
self.num_points = num_points
# Set number of mixture components for compatibility
self.num_mixtures = 0
# Get kernel initializer
if self.seed is None:
self.seed = random.randint(0, 2**32 - 1)
self.initializer = get_initializer(init_name=initializer, seed=self.seed)
        # Create config dictionary for serialization
        self.config = {
            "num_outputs": self.num_outputs,
            "output_activation": self.output_activation,
            "initializer": initializer,  # store the name string, not the object
            "seed": self.seed,
            "regularizer": self.regularizer,
            "bayesian": self.bayesian,
            "num_points": self.num_points,
            "num_mixtures": self.num_mixtures,
        }
        # Create the output layer: a Flipout variational layer when bayesian=True,
        # otherwise a standard Dense layer
        if bayesian:
            kernel_divergence_fn = get_kernel_divergence_fn(self.num_points)
            self.output_layer = tfp.layers.DenseFlipout(
                self.num_outputs,
                kernel_divergence_fn=kernel_divergence_fn,
                activation=self.output_activation,
                activity_regularizer=self.regularizer,
                name="bayesian_output",
            )
else:
self.output_layer = Dense(
self.num_outputs,
activation=self.output_activation,
kernel_initializer=self.initializer,
kernel_regularizer=self.regularizer,
name="output",
)
def call(self, x, training=False):
"""Forward pass through the output layer."""
return self.output_layer(x, training=training)
def get_config(self):
"""Get the config dictionary"""
config = super(DefaultOutput, self).get_config()
config.update(self.config)
return config
@classmethod
def from_config(cls, config):
"""Create model from config dictionary"""
return cls(**config)
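# Illustrative usage sketch for DefaultOutput in its deterministic and Bayesian forms
# (hyperparameters are example values only):
#
#   dense_head = DefaultOutput(num_outputs=1)
#   y = dense_head(tf.random.normal((8, 32)))  # plain Dense output, shape (8, 1)
#
#   bayes_head = DefaultOutput(num_outputs=1, bayesian=True, num_points=10_000)
#   y_sample = bayes_head(tf.random.normal((8, 32)))  # one draw of the Flipout weights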
class BayesianAleatoricOutput(Model):
"""
Output layer for a Bayesian neural network with aleatoric uncertainty.
Args:
num_outputs (int): Number of output nodes.
        num_points (int): Number of training points used to scale the KL divergence.
regularizer (Regularizer, optional): Kernel regularizer. Defaults to None.
scale_epsilon (float, optional): Epsilon value for scale parameter. Defaults to
1e-3.
aleatoric_scale_factor (float, optional): Scaling factor for aleatoric
uncertainty. Defaults to 5e-2.
**kwargs: Extra arguments passed to the base class.
"""
def __init__(
self,
num_outputs: int,
num_points: int,
regularizer: Optional[Regularizer] = None,
scale_epsilon: Optional[float] = 1e-3,
aleatoric_scale_factor: Optional[float] = 5e-2,
**kwargs,
):
super(BayesianAleatoricOutput, self).__init__(**kwargs)
self.num_outputs = num_outputs
self.num_points = num_points
self.regularizer = regularizer
self.scale_epsilon = scale_epsilon
self.aleatoric_scale_factor = aleatoric_scale_factor
# Update config dictionary for serialization
self.config = {
"num_outputs": self.num_outputs,
"num_points": self.num_points,
"regularizer": self.regularizer,
"scale_epsilon": self.scale_epsilon,
"aleatoric_scale_factor": self.aleatoric_scale_factor,
}
# Create kernel divergence function
kernel_divergence_fn = get_kernel_divergence_fn(self.num_points)
self.pre_aleatoric_layer = tfp.layers.DenseFlipout(
2 * self.num_outputs,
kernel_divergence_fn=kernel_divergence_fn,
activation=None,
activity_regularizer=self.regularizer,
name="pre_aleatoric",
)
self.output_layer = tfp.layers.DistributionLambda(
lambda t: tfp.distributions.Normal(
loc=t[..., : self.num_outputs],
scale=self.scale_epsilon
+ tf.math.softplus(
self.aleatoric_scale_factor * t[..., self.num_outputs :]
),
),
name="distribution_output",
)
def call(self, x, training=False):
"""Forward pass through the output layer."""
x = self.pre_aleatoric_layer(x, training=training)
return self.output_layer(x)
def get_config(self):
"""Get the config dictionary"""
config = super(BayesianAleatoricOutput, self).get_config()
config.update(self.config)
return config
@classmethod
def from_config(cls, config):
"""Create model from config dictionary"""
return cls(**config)
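# Illustrative usage sketch for BayesianAleatoricOutput: the layer returns a
# tfp.distributions.Normal, so the predicted mean and standard deviation (aleatoric
# uncertainty) can be read off directly, while repeated calls also vary through the
# Flipout weight sampling (epistemic uncertainty). Values are examples only:
#
#   head = BayesianAleatoricOutput(num_outputs=1, num_points=10_000)
#   dist = head(tf.random.normal((8, 32)))
#   mean, stddev = dist.mean(), dist.stddev()  # each of shape (8, 1)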