base_policy module

Custom policy classes and convenience methods. Code adapted from https://github.com/HumanCompatibleAI/imitation.git

class base_policy.FeedForward32Policy(*args, **kwargs)

Bases: stable_baselines3.common.policies.ActorCriticPolicy

A feed-forward policy network with two hidden layers of 32 units each.

This matches the IRL policies in the original AIRL paper.

Note: This differs from the stable_baselines3 ActorCriticPolicy default in two ways: it uses 32 rather than 64 hidden units, and the policy and value networks share weights except at the final layer, where each has its own linear head.

training: bool
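
Example (a minimal sketch, not part of the documented API): an ActorCriticPolicy subclass such as FeedForward32Policy can be passed to a stable_baselines3 algorithm in place of the usual "MlpPolicy" string. The environment name and step budget below are placeholders.

import gym
from stable_baselines3 import PPO

from base_policy import FeedForward32Policy

env = gym.make("CartPole-v1")  # placeholder environment choice

# A policy class may be supplied instead of the "MlpPolicy" string,
# so the 32-unit shared architecture is used as-is.
model = PPO(FeedForward32Policy, env, verbose=1)
model.learn(total_timesteps=10_000)
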
class base_policy.HardCodedPolicy(observation_space: gym.spaces.space.Space, action_space: gym.spaces.space.Space)

Bases: stable_baselines3.common.policies.BasePolicy, abc.ABC

Abstract class for hard-coded (non-trainable) policies.

forward(*args)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance itself rather than this method, since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
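
Example (a hedged sketch of a custom subclass): the upstream imitation implementation this module is adapted from exposes a per-observation _choose_action hook for subclasses; the method name and signature below follow that assumption and should be checked against the source.

import numpy as np

from base_policy import HardCodedPolicy


class ConstantPolicy(HardCodedPolicy):
    """Always returns the same fixed action (illustrative only)."""

    def __init__(self, observation_space, action_space, constant_action):
        super().__init__(observation_space, action_space)
        self._constant_action = np.asarray(constant_action)

    def _choose_action(self, obs: np.ndarray) -> np.ndarray:
        # Assumed abstract hook (as in upstream imitation); ignores the observation.
        return self._constant_action
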
class base_policy.NormalizeFeaturesExtractor(observation_space: gym.spaces.space.Space, normalize_class: typing.Type[torch.nn.modules.module.Module] = <class 'networks.RunningNorm'>)

Bases: stable_baselines3.common.torch_layers.FlattenExtractor

Feature extractor that flattens then normalizes input.

forward(observations: torch.Tensor) → torch.Tensor

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance itself rather than this method, since the former takes care of running the registered hooks while the latter silently ignores them.

training: bool
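
Example (a minimal sketch): a features extractor class is typically wired into a stable_baselines3 policy through policy_kwargs. The environment and hyperparameters below are placeholders.

import gym
from stable_baselines3 import PPO

from base_policy import NormalizeFeaturesExtractor

env = gym.make("Pendulum-v1")  # placeholder environment choice

model = PPO(
    "MlpPolicy",
    env,
    # Flatten-then-normalize observations before the policy/value MLPs.
    policy_kwargs=dict(features_extractor_class=NormalizeFeaturesExtractor),
    verbose=1,
)
model.learn(total_timesteps=10_000)
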
class base_policy.RandomPolicy(observation_space: gym.spaces.space.Space, action_space: gym.spaces.space.Space)

Bases: base_policy.HardCodedPolicy

Returns random actions.

training: bool
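
Example (a minimal sketch): RandomPolicy requires no training and can be queried through the standard stable_baselines3 predict() interface. The environment below is a placeholder.

import gym

from base_policy import RandomPolicy

env = gym.make("Pendulum-v1")  # placeholder environment choice
policy = RandomPolicy(env.observation_space, env.action_space)

obs = env.reset()
# predict() returns (action, state); the state is only used by recurrent policies.
action, _state = policy.predict(obs)
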
class base_policy.SAC1024Policy(*args, **kwargs)

Bases: stable_baselines3.sac.policies.SACPolicy

Actor and value networks, each with two hidden layers of 1024 units.

This matches the implementation of SAC policies in the PEBBLE paper. See: https://arxiv.org/pdf/2106.05091.pdf https://github.com/denisyarats/pytorch_sac/blob/master/config/agent/sac.yaml

Note: This differs from stable_baselines3 SACPolicy by having 1024 hidden units in each layer instead of the default value of 256.

training: bool
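
Example (a minimal sketch): as with the feed-forward policy above, a SACPolicy subclass can be passed directly to the SAC constructor. The environment name and step budget are placeholders.

import gym
from stable_baselines3 import SAC

from base_policy import SAC1024Policy

env = gym.make("Pendulum-v1")  # placeholder environment choice

# A policy class may be supplied instead of the "MlpPolicy" string.
model = SAC(SAC1024Policy, env, verbose=1)
model.learn(total_timesteps=10_000)
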
class base_policy.ZeroPolicy(observation_space: gym.spaces.space.Space, action_space: gym.spaces.space.Space)

Bases: base_policy.HardCodedPolicy

Returns a constant zero action.

training: bool
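
Example (a minimal sketch): ZeroPolicy makes a convenient do-nothing baseline; stable_baselines3's evaluate_policy helper only needs an object exposing predict(). The environment below is a placeholder.

import gym
from stable_baselines3.common.evaluation import evaluate_policy

from base_policy import ZeroPolicy

env = gym.make("Pendulum-v1")  # placeholder environment choice
policy = ZeroPolicy(env.observation_space, env.action_space)

# Score the always-zero action as a reference point for learned policies.
mean_reward, std_reward = evaluate_policy(policy, env, n_eval_episodes=10)
print(f"zero-action baseline return: {mean_reward:.2f} +/- {std_reward:.2f}")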