# -*- coding: utf-8 -*-
"""Custom tf layers."""
import logging
import numpy as np
import tensorflow as tf
logger = logging.getLogger(__name__)
class FlexiblePadding(tf.keras.layers.Layer):
    """Layer that applies tf.pad to its input with a configurable pad spec."""

    def __init__(self, paddings, mode='REFLECT'):
        """
        Parameters
        ----------
        paddings : int array
            Integer array with shape [n,2] where n is the rank of the tensor
            and elements give the number of leading and trailing pads
        mode : str
            tf.pad() padding mode. Can be REFLECT, CONSTANT, or SYMMETRIC
        """
        super().__init__()
        self.paddings = tf.constant(paddings)
        self.rank = len(paddings)
        self.mode = mode

    def compute_output_shape(self, input_shape):
        """Compute the tensor shape after padding.

        Parameters
        ----------
        input_shape : tuple
            shape of input tensor

        Returns
        -------
        tf.TensorShape
            shape of padded tensor
        """
        dims = [input_shape[d] + sum(self.paddings[d])
                for d in range(self.rank)]
        return tf.TensorShape(dims)

    def call(self, x):
        """Pad the input tensor.

        Parameters
        ----------
        x : tf.Tensor
            tensor on which to perform padding

        Returns
        -------
        tf.Tensor
            padded tensor with shape given by compute_output_shape
        """
        return tf.pad(x, self.paddings, mode=self.mode)
class ExpandDims(tf.keras.layers.Layer):
    """Layer to add an extra dimension to a tensor."""

    def __init__(self, axis=3):
        """
        Parameters
        ----------
        axis : int
            Target axis at which to expand the shape of the input. Default is
            axis 3 based on creating a new temporal axis of the default
            spatiotemporal shape of: (n_observations, n_spatial_0, n_spatial_1,
            n_temporal, n_features)
        """
        super().__init__()
        self._target_axis = axis

    def call(self, x):
        """Insert a length-1 dimension at the configured axis.

        Parameters
        ----------
        x : tf.Tensor
            Input tensor

        Returns
        -------
        tf.Tensor
            Output tensor with an extra dimension based on the init axis arg
        """
        return tf.expand_dims(x, axis=self._target_axis)
class TileLayer(tf.keras.layers.Layer):
    """Layer to tile (repeat) data across a given axis."""

    def __init__(self, multiples):
        """
        Parameters
        ----------
        multiples : list
            This is a list with the same length as number of dimensions in the
            input tensor. Each entry in the list determines how many times to
            tile each axis in the tensor.
        """
        super().__init__()
        self._multiples = tf.constant(multiples, tf.int32)

    def call(self, x):
        """Tile the input tensor.

        Parameters
        ----------
        x : tf.Tensor
            Input tensor

        Returns
        -------
        tf.Tensor
            Output tensor with the specified axes tiled into larger shapes
            based on the multiples initialization argument.
        """
        return tf.tile(x, self._multiples)
class GaussianNoiseAxis(tf.keras.layers.Layer):
    """Layer to apply random noise along a given axis."""

    def __init__(self, axis, mean=1, stddev=0.1):
        """
        Parameters
        ----------
        axis : int
            Axis to apply random noise across. All other axis will have the
            same noise. For example, for a 5D spatiotemporal tensor with axis=3
            (the time axis), this layer will apply a single random number to
            every unique index of axis=3.
        mean : float
            The mean of the normal distribution.
        stddev : float
            The standard deviation of the normal distribution.
        """
        super().__init__()
        self._axis = axis
        # set lazily in build() once the input rank/shape is known
        self._noise_shape = None
        self._mean = tf.constant(mean, dtype=tf.dtypes.float32)
        self._stddev = tf.constant(stddev, dtype=tf.dtypes.float32)

    def build(self, input_shape):
        """Custom implementation of the tf layer build method.

        Sets the shape of the random noise: length 1 on every axis except the
        target axis, so the noise broadcasts across all other dimensions.

        Parameters
        ----------
        input_shape : tuple
            Shape tuple of the input
        """
        dims = [1] * len(input_shape)
        dims[self._axis] = input_shape[self._axis]
        self._noise_shape = tf.constant(dims, dtype=tf.dtypes.int32)

    def call(self, x):
        """Multiply the input by gaussian noise along the configured axis.

        Parameters
        ----------
        x : tf.Tensor
            Input tensor

        Returns
        -------
        tf.Tensor
            Output tensor with noise applied to the requested axis.
        """
        noise = tf.random.normal(self._noise_shape,
                                 mean=self._mean,
                                 stddev=self._stddev,
                                 dtype=tf.dtypes.float32)
        return x * noise
class GaussianKernelInit2D(tf.keras.initializers.Initializer):
    """Convolutional kernel initializer that creates a symmetric 2D array with
    a gaussian distribution. This can be used with Conv2D as a gaussian average
    pooling layer if trainable=False
    """

    def __init__(self, stdev=1):
        """
        Parameters
        ----------
        stdev : float
            Standard deviation of the gaussian distribution defining the kernel
            values
        """
        self.stdev = stdev

    def get_config(self):
        """Return the initializer configuration so that layers using this
        initializer can be serialized/deserialized with the keras model
        save/load utilities.

        Returns
        -------
        dict
            Keyword arguments that reconstruct this initializer.
        """
        return {'stdev': self.stdev}

    def __call__(self, shape, dtype=tf.float32):
        """
        Parameters
        ----------
        shape : tuple
            Shape of the input tensor, typically (y, x, n_features, n_obs)
        dtype : None | tf.DType
            Tensorflow datatype e.g., tf.float32

        Returns
        -------
        kernel : tf.Tensor
            Kernel tensor of shape (y, x, n_features, n_obs) for use in a
            Conv2D layer.
        """
        # 1D gaussian evaluated on a grid centered at zero, then outer-product
        # into a symmetric 2D kernel normalized to sum to 1
        ax = np.linspace(-(shape[0] - 1) / 2., (shape[0] - 1) / 2., shape[0])
        kernel = np.exp(-0.5 * np.square(ax) / np.square(self.stdev))
        kernel = np.outer(kernel, kernel)
        kernel = kernel / np.sum(kernel)
        # broadcast the same 2D kernel across all feature/observation channels
        kernel = np.expand_dims(kernel, (2, 3))
        kernel = np.repeat(kernel, shape[2], axis=2)
        kernel = np.repeat(kernel, shape[3], axis=3)
        kernel = tf.convert_to_tensor(kernel, dtype=dtype)
        return kernel
class FlattenAxis(tf.keras.layers.Layer):
    """Layer to flatten an axis from a 5D spatiotemporal Tensor into axis-0
    observations."""

    def __init__(self, axis=3):
        """
        Parameters
        ----------
        axis : int
            Target axis that holds the dimension to be flattened into the
            axis-0 dimension. Default is axis 3 based on flattening the
            temporal axis of the default spatiotemporal shape of:
            (n_observations, n_spatial_0, n_spatial_1, n_temporal, n_features)
        """
        super().__init__()
        self._axis = axis

    @staticmethod
    def _check_shape(input_shape):
        """Assert that the shape of the input tensor is the expected 5D
        spatiotemporal shape

        Parameters
        ----------
        input_shape : tuple
            Shape tuple of the input
        """
        msg = ('Input to FlattenAxis must be 5D with dimensions: '
               '(n_observations, n_spatial_0, n_spatial_1, n_temporal, '
               'n_features), but received shape: {}'.format(input_shape))
        assert len(input_shape) == 5, msg

    def call(self, x):
        """Flatten the target axis into axis 0.

        Parameters
        ----------
        x : tf.Tensor
            5D spatiotemporal tensor with dimensions:
            (n_observations, n_spatial_0, n_spatial_1, n_temporal, n_features)

        Returns
        -------
        tf.Tensor
            4D spatiotemporal tensor with target axis flattened into axis 0
        """
        self._check_shape(x.shape)
        slices = tf.unstack(x, axis=self._axis)
        return tf.concat(slices, axis=0)
class SpatialExpansion(tf.keras.layers.Layer):
    """Class to expand the spatial dimensions of tensors with shape:
    (n_observations, n_spatial_0, n_spatial_1, n_features)
    """

    def __init__(self, spatial_mult=1):
        """
        Parameters
        ----------
        spatial_mult : int
            Number of times to multiply the spatial dimensions. Note that the
            spatial expansion is an un-packing of the feature dimension. For
            example, if the input layer has shape (123, 5, 5, 16) with
            multiplier=2 the output shape will be (123, 10, 10, 4). The
            input feature dimension must be divisible by the spatial multiplier
            squared.
        """
        super().__init__()
        self._spatial_mult = int(spatial_mult)

    @staticmethod
    def _check_shape(input_shape):
        """Assert that the shape of the input tensor is the expected 4D
        spatiotemporal shape

        Parameters
        ----------
        input_shape : tuple
            Shape tuple of the input
        """
        msg = ('Input to SpatialExpansion must be 4D with dimensions: '
               '(n_observations, n_spatial_0, n_spatial_1, n_features), '
               'but received shape: {}'.format(input_shape))
        assert len(input_shape) == 4, msg

    def build(self, input_shape):
        """Custom implementation of the tf layer build method.

        Parameters
        ----------
        input_shape : tuple
            Shape tuple of the input
        """
        self._check_shape(input_shape)

    def _spatial_expand(self, x):
        """Expand the two spatial dimensions (axis=1,2) of a 4D tensor using
        data from the last axes"""
        remainder = x.shape[-1] % self._spatial_mult**2
        if remainder:
            msg = ('Spatial expansion of factor {} is being attempted on '
                   'input tensor of shape {}, but the last dimension of the '
                   'input tensor ({}) must be divisible by the spatial '
                   'factor squared ({}).'
                   .format(self._spatial_mult, x.shape, x.shape[-1],
                           self._spatial_mult**2))
            logger.error(msg)
            raise RuntimeError(msg)
        return tf.nn.depth_to_space(x, self._spatial_mult)

    def call(self, x):
        """Call the custom SpatialExpansion layer

        Parameters
        ----------
        x : tf.Tensor
            4D spatial tensor
            (n_observations, n_spatial_0, n_spatial_1, n_features)

        Returns
        -------
        tf.Tensor
            4D spatiotemporal tensor with axes 1,2 expanded (if spatial_mult>1)
        """
        self._check_shape(x.shape)
        if self._spatial_mult <= 1:
            return x
        return self._spatial_expand(x)
class SpatioTemporalExpansion(tf.keras.layers.Layer):
    """Class to expand the spatiotemporal dimensions of tensors with shape:
    (n_observations, n_spatial_0, n_spatial_1, n_temporal, n_features)
    """

    def __init__(self, spatial_mult=1, temporal_mult=1,
                 temporal_method='nearest', t_roll=0):
        """
        Parameters
        ----------
        spatial_mult : int
            Number of times to multiply the spatial dimensions. Note that the
            spatial expansion is an un-packing of the feature dimension. For
            example, if the input layer has shape (123, 5, 5, 24, 16) with
            multiplier=2 the output shape will be (123, 10, 10, 24, 4). The
            input feature dimension must be divisible by the spatial multiplier
            squared.
        temporal_mult : int
            Number of times to multiply the temporal dimension. For example,
            if the input layer has shape (123, 5, 5, 24, 2) with multiplier=2
            the output shape will be (123, 5, 5, 48, 2).
        temporal_method : str
            Interpolation method for tf.image.resize(). Can also be
            "depth_to_time" for an operation similar to tf.nn.depth_to_space
            where the feature axis is unpacked into the temporal axis.
        t_roll : int
            Option to roll the temporal axis after expanding. When using
            temporal_method="depth_to_time", the default (t_roll=0) will
            add temporal steps after the input steps such that if input
            temporal shape is 3 and the temporal_mult is 24x, the output will
            have the original timesteps at idt=0,24,48 but if t_roll=12, the
            output will have the original timesteps at idt=12,36,60
        """
        super().__init__()
        self._spatial_mult = int(spatial_mult)
        self._temporal_mult = int(temporal_mult)
        self._temporal_meth = temporal_method
        self._t_roll = t_roll

    @staticmethod
    def _check_shape(input_shape):
        """Assert that the shape of the input tensor is the expected 5D
        spatiotemporal shape

        Parameters
        ----------
        input_shape : tuple
            Shape tuple of the input
        """
        msg = ('Input to SpatioTemporalExpansion must be 5D with dimensions: '
               '(n_observations, n_spatial_0, n_spatial_1, n_temporal, '
               'n_features), but received shape: {}'.format(input_shape))
        assert len(input_shape) == 5, msg

    def build(self, input_shape):
        """Custom implementation of the tf layer build method.

        Parameters
        ----------
        input_shape : tuple
            Shape tuple of the input
        """
        self._check_shape(input_shape)

    def _temporal_expand(self, x):
        """Expand the temporal dimension (axis=3) of a 5D tensor"""
        if self._temporal_meth == 'depth_to_time':
            check_shape = x.shape[-1] % self._temporal_mult
            if check_shape != 0:
                msg = ('Temporal expansion of factor {} is being attempted on '
                       'input tensor of shape {}, but the last dimension of '
                       'the input tensor ({}) must be divisible by the '
                       'temporal factor ({}).'
                       .format(self._temporal_mult, x.shape, x.shape[-1],
                               self._temporal_mult))
                logger.error(msg)
                raise RuntimeError(msg)
            # use -1 for the observation axis so this works when the batch
            # dimension is dynamic (None) at graph-build time; x.shape[0]
            # would make tf.reshape fail in that case
            shape = (-1, x.shape[1], x.shape[2],
                     x.shape[3] * self._temporal_mult,
                     x.shape[4] // self._temporal_mult)
            out = tf.reshape(x, shape)
            out = tf.roll(out, self._t_roll, axis=3)
        else:
            # after unstacking axis=1, each slice is (obs, spatial_1,
            # temporal, features); resize keeps spatial_1 fixed and expands
            # the temporal dim. tf.shape handles a dynamic spatial dim; the
            # temporal dim must be statically known.
            temp_expand_shape = tf.stack(
                [tf.shape(x)[2], x.shape[3] * self._temporal_mult])
            out = []
            for x_unstack in tf.unstack(x, axis=1):
                out.append(tf.image.resize(x_unstack, temp_expand_shape,
                                           method=self._temporal_meth))
            out = tf.stack(out, axis=1)
        return out

    def _spatial_expand(self, x):
        """Expand the two spatial dimensions (axis=1,2) of a 5D tensor using
        data from the last axes"""
        check_shape = x.shape[-1] % self._spatial_mult**2
        if check_shape != 0:
            msg = ('Spatial expansion of factor {} is being attempted on '
                   'input tensor of shape {}, but the last dimension of the '
                   'input tensor ({}) must be divisible by the spatial '
                   'factor squared ({}).'
                   .format(self._spatial_mult, x.shape, x.shape[-1],
                           self._spatial_mult**2))
            logger.error(msg)
            raise RuntimeError(msg)
        # depth_to_space only supports 4D input, so apply it per time step
        out = []
        for x_unstack in tf.unstack(x, axis=3):
            out.append(tf.nn.depth_to_space(x_unstack, self._spatial_mult))
        return tf.stack(out, axis=3)

    def call(self, x):
        """Call the custom SpatioTemporalExpansion layer

        Parameters
        ----------
        x : tf.Tensor
            5D spatiotemporal tensor.

        Returns
        -------
        x : tf.Tensor
            5D spatiotemporal tensor with axes 1,2 expanded (if spatial_mult>1)
            and axes 3 expanded (if temporal_mult>1).
        """
        self._check_shape(x.shape)
        if self._temporal_mult > 1:
            x = self._temporal_expand(x)
        if self._spatial_mult > 1:
            x = self._spatial_expand(x)
        return x
class SkipConnection(tf.keras.layers.Layer):
    """Custom layer to implement a skip connection. This layer should be
    initialized and referenced in a layer list by the same name as both the
    skip start and skip end.
    """

    def __init__(self, name):
        """
        Parameters
        ----------
        name : str
            Unique string identifier of the skip connection. The skip endpoint
            should have the same name.
        """
        super().__init__(name=name)
        # holds the skip-start tensor between the two calls of this layer
        self._cache = None

    def call(self, x):
        """Call the custom SkipConnection layer

        Parameters
        ----------
        x : tf.Tensor
            Input tensor.

        Returns
        -------
        x : tf.Tensor
            Output tensor. If this is the skip start, the input will be cached
            and returned without manipulation. If this is the skip endpoint,
            the output will be the input x added to the tensor cached at the
            skip start.
        """
        if self._cache is None:
            # skip start: stash the tensor and pass it through untouched
            self._cache = x
            return x

        try:
            out = tf.add(x, self._cache)
        except Exception as e:
            msg = ('Could not add SkipConnection "{}" data cache of '
                   'shape {} to input of shape {}.'
                   .format(self._name, self._cache.shape, x.shape))
            logger.error(msg)
            raise RuntimeError(msg) from e

        # clear the cache so the layer can be reused as a new skip start
        self._cache = None
        return out
class SqueezeAndExcitation(tf.keras.layers.Layer):
    """Custom layer for squeeze and excitation block for convolutional networks

    Note that this is only set up to take a channels-last conv output

    References
    ----------
    1. Hu, Jie, et al. Squeeze-and-Excitation Networks. arXiv:1709.01507,
       arXiv, 16 May 2019, http://arxiv.org/abs/1709.01507.
    2. Pröve, Paul-Louis. “Squeeze-and-Excitation Networks.” Medium, 18 Oct.
       2017,
       https://towardsdatascience.com/squeeze-and-excitation-networks-9ef5e71eacd7
    """

    def __init__(self, ratio=16):
        """
        Parameters
        ----------
        ratio : int
            Number of convolutional channels/filters divided by the number of
            dense connections in the SE block.
        """
        super().__init__()
        self._ratio = ratio
        self._n_channels = None
        self._dense_units = None
        self._hidden_layers = None

    def build(self, input_shape):
        """Build the SqueezeAndExcitation layer based on an input shape

        Parameters
        ----------
        input_shape : tuple
            Shape tuple of the input tensor
        """
        self._n_channels = input_shape[-1]
        self._dense_units = int(np.ceil(self._n_channels / self._ratio))

        ndim = len(input_shape)
        if ndim == 4:
            pool_layer = tf.keras.layers.GlobalAveragePooling2D()
        elif ndim == 5:
            pool_layer = tf.keras.layers.GlobalAveragePooling3D()
        else:
            msg = ('SqueezeAndExcitation layer can only accept 4D or 5D data '
                   'for image or video input but received input shape: {}'
                   .format(input_shape))
            logger.error(msg)
            raise RuntimeError(msg)

        # squeeze (pool), excite (bottleneck dense pair), then re-weight
        self._hidden_layers = [
            pool_layer,
            tf.keras.layers.Dense(self._dense_units, activation='relu'),
            tf.keras.layers.Dense(self._n_channels, activation='sigmoid'),
            tf.keras.layers.Multiply()]

    def call(self, x):
        """Call the custom SqueezeAndExcitation layer

        Parameters
        ----------
        x : tf.Tensor
            Input tensor.

        Returns
        -------
        x : tf.Tensor
            Output tensor, this is the squeeze-and-excitation weights
            multiplied by the original input tensor x
        """
        weights = x
        for layer in self._hidden_layers[:-1]:
            weights = layer(weights)
        # final layer multiplies the channel weights into the original input
        return self._hidden_layers[-1]([x, weights])
class FNO(tf.keras.layers.Layer):
    """Custom layer for fourier neural operator block

    Note that this is only set up to take a channels-last input

    References
    ----------
    1. FourCastNet: A Global Data-driven High-resolution Weather Model using
       Adaptive Fourier Neural Operators. http://arxiv.org/abs/2202.11214
    2. Adaptive Fourier Neural Operators: Efficient Token Mixers for
       Transformers. http://arxiv.org/abs/2111.13587
    """

    def __init__(self, filters, sparsity_threshold=0.5, activation='relu'):
        """
        Parameters
        ----------
        filters : int
            Number of dense connections in the FNO block.
        sparsity_threshold : float
            Parameter to control sparsity and shrinkage in the softshrink
            activation function following the MLP layers.
        activation : str
            Activation function used in MLP layers.
        """
        super().__init__()
        self._filters = filters
        self._fft_layer = None
        self._ifft_layer = None
        self._mlp_layers = None
        self._activation = activation
        self._n_channels = None
        self._perms_in = None
        self._perms_out = None
        self._lambd = sparsity_threshold

    def _softshrink(self, x):
        """Softshrink activation function
        https://pytorch.org/docs/stable/generated/torch.nn.Softshrink.html
        """
        below = tf.where(x < -self._lambd, x + self._lambd, 0)
        above = tf.where(self._lambd < x, x - self._lambd, 0)
        return below + above

    def _transform(self, x, op):
        """Move channels first, apply the given complex spectral op, then
        move channels back to the last axis."""
        x = tf.transpose(x, perm=self._perms_in)
        x = op(tf.cast(x, tf.complex64))
        return tf.transpose(x, perm=self._perms_out)

    def _fft(self, x):
        """Apply needed transpositions and fft operation."""
        return self._transform(x, self._fft_layer)

    def _ifft(self, x):
        """Apply needed transpositions and ifft operation."""
        return self._transform(x, self._ifft_layer)

    def build(self, input_shape):
        """Build the FNO layer based on an input shape

        Parameters
        ----------
        input_shape : tuple
            Shape tuple of the input tensor
        """
        self._n_channels = input_shape[-1]
        axes = list(range(len(input_shape)))
        # channels-first permutation and its inverse
        self._perms_in = [axes[-1], *axes[:-1]]
        self._perms_out = [*axes[1:], axes[0]]

        if len(input_shape) == 4:
            self._fft_layer = tf.signal.fft2d
            self._ifft_layer = tf.signal.ifft2d
        elif len(input_shape) == 5:
            self._fft_layer = tf.signal.fft3d
            self._ifft_layer = tf.signal.ifft3d
        else:
            msg = ('FNO layer can only accept 4D or 5D data '
                   'for image or video input but received input shape: {}'
                   .format(input_shape))
            logger.error(msg)
            raise RuntimeError(msg)

        self._mlp_layers = [
            tf.keras.layers.Dense(self._filters, activation=self._activation),
            tf.keras.layers.Dense(self._n_channels)]

    def _mlp_block(self, x):
        """Run mlp layers on input"""
        for dense in self._mlp_layers:
            x = dense(x)
        return x

    def call(self, x):
        """Call the custom FourierNeuralOperator layer

        Parameters
        ----------
        x : tf.Tensor
            Input tensor.

        Returns
        -------
        x : tf.Tensor
            Output tensor, this is the FNO weights added to the original input
            tensor.
        """
        residual = x
        x = self._fft(x)
        x = self._mlp_block(x)
        x = self._softshrink(x)
        x = self._ifft(x)
        x = tf.cast(x, dtype=residual.dtype)
        return x + residual
class Sup3rAdder(tf.keras.layers.Layer):
    """Layer to add high-resolution data to a sup3r model in the middle of a
    super resolution forward pass."""

    def __init__(self, name=None):
        """
        Parameters
        ----------
        name : str | None
            Unique str identifier of the adder layer. Usually the name of the
            hi-resolution feature used in the addition.
        """
        super().__init__(name=name)

    def call(self, x, hi_res_adder):
        """Adds hi-resolution data to the input tensor x in the middle of a
        sup3r resolution network.

        Parameters
        ----------
        x : tf.Tensor
            Input tensor
        hi_res_adder : tf.Tensor | np.ndarray
            This should be a 4D array for spatial enhancement model or 5D array
            for a spatiotemporal enhancement model (obs, spatial_1, spatial_2,
            (temporal), features) that can be added to x.

        Returns
        -------
        tf.Tensor
            Output tensor with the hi_res_adder added to x.
        """
        return tf.add(x, hi_res_adder)
class Sup3rConcat(tf.keras.layers.Layer):
    """Layer to concatenate a high-resolution feature to a sup3r model in the
    middle of a super resolution forward pass."""

    def __init__(self, name=None):
        """
        Parameters
        ----------
        name : str | None
            Unique str identifier for the concat layer. Usually the name of the
            hi-resolution feature used in the concatenation.
        """
        super().__init__(name=name)

    def call(self, x, hi_res_feature):
        """Concatenates a hi-resolution feature to the input tensor x in the
        middle of a sup3r resolution network.

        Parameters
        ----------
        x : tf.Tensor
            Input tensor
        hi_res_feature : tf.Tensor | np.ndarray
            This should be a 4D array for spatial enhancement model or 5D array
            for a spatiotemporal enhancement model (obs, spatial_1, spatial_2,
            (temporal), features) that can be concatenated to x.

        Returns
        -------
        tf.Tensor
            Output tensor with the hi_res_feature concatenated to x on the
            feature (last) axis.
        """
        return tf.concat([x, hi_res_feature], axis=-1)
class FunctionalLayer(tf.keras.layers.Layer):
    """Custom layer to implement the tensorflow layer functions (e.g., add,
    subtract, multiply, maximum, and minimum) with a constant value. These
    cannot be implemented in phygnn as normal layers because they need to
    operate on two tensors of equal shape."""

    def __init__(self, name, value):
        """
        Parameters
        ----------
        name : str
            Name of the tensorflow layer function to be implemented, options
            are (all lower-case): add, subtract, multiply, maximum, and minimum
        value : float
            Constant value to use in the function operation
        """
        options = ('add', 'subtract', 'multiply', 'maximum', 'minimum')
        msg = (f'FunctionalLayer input `name` must be one of "{options}" '
               f'but received "{name}"')
        assert name in options, msg
        super().__init__(name=name)
        self.value = value
        # resolve the functional interface (tf.keras.layers.add etc.) by the
        # keras layer name set in super().__init__
        self.fun = getattr(tf.keras.layers, self.name)

    def call(self, x):
        """Operates on x with the specified function

        Parameters
        ----------
        x : tf.Tensor
            Input tensor

        Returns
        -------
        x : tf.Tensor
            Output tensor operated on by the specified function
        """
        # Build the constant from the dynamic runtime shape. The previous
        # tf.constant(value, shape=x.shape) fails when the batch dimension is
        # unknown (None) at graph-build time; tf.fill with tf.shape(x) works
        # for both static and dynamic shapes.
        const = tf.fill(tf.shape(x), tf.constant(self.value, dtype=x.dtype))
        return self.fun((x, const))