networks module

Helper methods to build and run neural networks. Code adapted from https://github.com/HumanCompatibleAI/imitation.git

class networks.BaseNorm(num_features: int, eps: float = 1e-05)

Bases: torch.nn.modules.module.Module, abc.ABC

Base class for layers that try to normalize the input to mean 0 and variance 1.

Similar to BatchNorm, LayerNorm, etc. but whereas they only use statistics from the current batch at train time, we use statistics from all batches.

count: torch.Tensor
forward(x: torch.Tensor) → torch.Tensor

Updates statistics if in training mode. Returns normalized x.
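
The normalization itself is the usual shift-and-scale by the stored statistics. A minimal sketch of the transformation forward applies, assuming the mean/variance normalization described above (the helper name is illustrative):

    import torch

    def normalize(x: torch.Tensor, running_mean: torch.Tensor,
                  running_var: torch.Tensor, eps: float = 1e-05) -> torch.Tensor:
        # Shift by the running mean and scale by the running standard deviation;
        # eps guards against division by zero when the variance is tiny.
        return (x - running_mean) / torch.sqrt(running_var + eps)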

reset_running_stats() → None

Resets running stats to defaults, yielding the identity transformation.

running_mean: torch.Tensor
running_var: torch.Tensor
abstract update_stats(batch: torch.Tensor) → None

Update self.running_mean, self.running_var and self.count.

class networks.EMANorm(num_features: int, decay: float = 0.99, eps: float = 1e-05)

Bases: networks.BaseNorm

Similar to RunningNorm but uses an exponential weighting.

count: torch.Tensor
running_mean: torch.Tensor
running_var: torch.Tensor
training: bool
update_stats(batch: torch.Tensor) → None

Update self.running_mean and self.running_var.

Reference: Finch (2009), “Incremental calculation of weighted mean and variance” (https://fanf2.user.srcf.net/hermes/doc/antiforgery/stats.pdf).

Args:

batch: A batch of data to use to update the running mean and variance.
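
As a rough illustration, one way to realize such an exponentially weighted update is sketched below (the helper is hypothetical and the actual EMANorm implementation may differ in detail):

    import torch

    def ema_update(running_mean: torch.Tensor, running_var: torch.Tensor,
                   batch: torch.Tensor, decay: float = 0.99):
        # Hypothetical sketch of an exponentially weighted mean/variance update
        # in the spirit of Finch (2009); not necessarily the exact EMANorm code.
        alpha = 1.0 - decay
        batch_mean = batch.mean(dim=0)
        batch_var = batch.var(dim=0, unbiased=False)
        delta = batch_mean - running_mean
        new_mean = running_mean + alpha * delta
        # Blend in the batch variance and account for the shift of the mean.
        new_var = (1.0 - alpha) * (running_var + alpha * delta**2) + alpha * batch_var
        return new_mean, new_var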

class networks.RunningNorm(num_features: int, eps: float = 1e-05)

Bases: networks.BaseNorm

Normalizes input to mean 0 and standard deviation 1 using a running average.

Similar to BatchNorm, LayerNorm, etc. but whereas they only use statistics from the current batch at train time, we use statistics from all batches.

This should closely replicate the common practice in RL of normalizing environment observations, such as using VecNormalize in Stable Baselines.

count: torch.Tensor
running_mean: torch.Tensor
running_var: torch.Tensor
training: bool
update_stats(batch: torch.Tensor) → None

Update self.running_mean, self.running_var and self.count.

Uses Chan et al. (1979), “Updating Formulae and a Pairwise Algorithm for Computing Sample Variances”, to update the running moments in a numerically stable fashion.

Args:

batch: A batch of data to use to update the running mean and variance.
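
For illustration, the pairwise combination of moments from Chan et al. can be written as follows (a sketch with a hypothetical helper; the class itself stores these quantities as buffers and updates them in place):

    import torch

    def combine_moments(running_mean: torch.Tensor, running_var: torch.Tensor,
                        count: int, batch: torch.Tensor):
        # Chan et al. pairwise update: merge the stored moments (over `count`
        # samples) with the moments of the new batch.
        batch_mean = batch.mean(dim=0)
        batch_var = batch.var(dim=0, unbiased=False)
        batch_count = batch.shape[0]

        delta = batch_mean - running_mean
        total = count + batch_count

        new_mean = running_mean + delta * batch_count / total
        m2 = (running_var * count
              + batch_var * batch_count
              + delta.square() * count * batch_count / total)
        new_var = m2 / total
        return new_mean, new_var, total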

class networks.SqueezeLayer

Bases: torch.nn.modules.module.Module

Torch module that squeezes a B*1 tensor down into a size-B vector.

forward(x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
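
A functionally equivalent module is tiny (a sketch; the real class may add shape assertions):

    import torch
    from torch import nn

    class SqueezeSketch(nn.Module):
        # Stand-in for networks.SqueezeLayer: maps a (B, 1) tensor to shape (B,).
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x.squeeze(1)
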
networks.build_mlp(in_size: int, hid_sizes: typing.Iterable[int], out_size: int = 1, name: typing.Optional[str] = None, activation: typing.Type[torch.nn.modules.module.Module] = torch.nn.modules.activation.ReLU, dropout_prob: float = 0.0, squeeze_output: bool = False, flatten_input: bool = False, normalize_input_layer: typing.Optional[typing.Type[torch.nn.modules.module.Module]] = None) → torch.nn.modules.module.Module

Constructs a Torch MLP.

Args:

in_size: size of individual input vectors; input to the MLP will be of shape (batch_size, in_size).
hid_sizes: sizes of hidden layers. If this is an empty iterable, then we build a linear function approximator.
out_size: required size of output vector.
name: Name to use as a prefix for the layers ID.
activation: activation to apply after hidden layers.
dropout_prob: Dropout probability to use after each hidden layer. If 0, no dropout layers are added to the network.
squeeze_output: if out_size=1, then squeeze_output=True ensures that the MLP output is of size (B,) instead of (B, 1).
flatten_input: should input be flattened along axes 1, 2, 3, …? Useful if you want to, e.g., process small image inputs with an MLP.
normalize_input_layer: if specified, module to use to normalize inputs; e.g. nn.BatchNorm1d or RunningNorm.

Returns:

nn.Module: an MLP mapping from inputs of size (batch_size, in_size) to (batch_size, out_size), unless out_size=1 and squeeze_output=True, in which case the output is of size (batch_size,).

Raises:

ValueError: if squeeze_output was supplied with out_size!=1.
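
For example, a two-hidden-layer network with input normalization could be built like this (a usage sketch; the sizes are illustrative):

    import torch
    import networks

    net = networks.build_mlp(
        in_size=8,
        hid_sizes=[32, 32],
        out_size=1,
        squeeze_output=True,                      # output shape (batch_size,)
        normalize_input_layer=networks.RunningNorm,
    )
    out = net(torch.zeros(16, 8))                 # -> tensor of shape (16,)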

networks.evaluating(m: torch.nn.modules.module.Module, *, mode: bool = False)

Temporarily switch module m to specified training mode.

Args:

m: The module to switch the mode of.
mode: whether to set training mode (True) or evaluation (False).

Yields:

The module m.
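
Typical usage as a context manager (a sketch; `model` is a placeholder module, and the prior mode is restored on exit, as “temporarily” suggests):

    from torch import nn
    import networks

    model = nn.Linear(4, 2)
    model.train()
    with networks.evaluating(model):
        assert not model.training   # evaluation mode inside the block
    assert model.training           # original training mode restored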

networks.training(m: torch.nn.modules.module.Module, *, mode: bool = True)

Temporarily switch module m to specified training mode.

Args:

m: The module to switch the mode of.
mode: whether to set training mode (True) or evaluation (False).

Yields:

The module m.

networks.training_mode(m: torch.nn.modules.module.Module, mode: bool = False)

Temporarily switch module m to specified training mode.

Args:

m: The module to switch the mode of.
mode: whether to set training mode (True) or evaluation (False).

Yields:

The module m.
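
networks.training is the training-mode counterpart of networks.evaluating, while networks.training_mode takes the target mode as an argument. A usage sketch of the general form (placeholder module; the prior mode is restored on exit):

    from torch import nn
    import networks

    model = nn.Linear(4, 2)
    model.eval()
    with networks.training_mode(model, mode=True):
        assert model.training        # temporarily in training mode
    assert not model.training        # evaluation mode restored afterwards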