base_policy module¶
Custom policy classes and convenience methods. Code adapted from https://github.com/HumanCompatibleAI/imitation.git
- class base_policy.FeedForward32Policy(*args, **kwargs)¶
Bases: stable_baselines3.common.policies.ActorCriticPolicy
A feed forward policy network with two hidden layers of 32 units.
This matches the IRL policies in the original AIRL paper.
Note: This differs from stable_baselines3 ActorCriticPolicy in two ways: by having 32 rather than 64 units, and by having policy and value networks share weights except at the final layer, where there are different linear heads.
- training: bool¶
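A minimal usage sketch (the PPO setup and the "CartPole-v1" environment are assumptions, not part of this module): FeedForward32Policy can be passed directly as the policy class of a stable_baselines3 algorithm, which then trains the 32-unit shared-weight network described above in place of the default MlpPolicy.

import gym
from stable_baselines3 import PPO

from base_policy import FeedForward32Policy

# Any environment works; CartPole-v1 is just an assumed example.
env = gym.make("CartPole-v1")
# PPO accepts a policy class as its first argument.
model = PPO(FeedForward32Policy, env, verbose=0)
model.learn(total_timesteps=1_000)  # short run, just to exercise the policy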
- class base_policy.HardCodedPolicy(observation_space: gym.spaces.space.Space, action_space: gym.spaces.space.Space)¶
Bases: stable_baselines3.common.policies.BasePolicy, abc.ABC
Abstract class for hard-coded (non-trainable) policies.
- forward(*args)¶
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool¶
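A minimal sketch of a hard-coded policy subclass. Only forward() is documented above; the per-observation _choose_action(obs) hook used below follows the upstream imitation implementation and is an assumption about this codebase, as is the ConstantPolicy class itself.

import gym
import numpy as np

from base_policy import HardCodedPolicy

class ConstantPolicy(HardCodedPolicy):
    """Hypothetical hard-coded policy that always returns one fixed action."""

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        constant_action: np.ndarray,
    ):
        super().__init__(observation_space, action_space)
        self._constant_action = np.asarray(constant_action)

    def _choose_action(self, obs: np.ndarray) -> np.ndarray:
        # Ignore the observation; hard-coded policies are non-trainable.
        return self._constant_action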
- class base_policy.NormalizeFeaturesExtractor(observation_space: gym.spaces.space.Space, normalize_class: typing.Type[torch.nn.modules.module.Module] = <class 'networks.RunningNorm'>)¶
Bases: stable_baselines3.common.torch_layers.FlattenExtractor
Feature extractor that flattens then normalizes input.
- forward(observations: torch.Tensor) → torch.Tensor¶
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool¶
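A usage sketch (the PPO setup, "Pendulum-v1", and the default networks.RunningNorm normalizer are assumptions): the extractor is plugged in through stable_baselines3's policy_kwargs, so observations are flattened and normalized before reaching the policy and value networks.

import gym
from stable_baselines3 import PPO

from base_policy import NormalizeFeaturesExtractor

env = gym.make("Pendulum-v1")
model = PPO(
    "MlpPolicy",
    env,
    # features_extractor_class is the standard stable_baselines3 hook
    # for swapping in a custom feature extractor.
    policy_kwargs={"features_extractor_class": NormalizeFeaturesExtractor},
    verbose=0,
)
model.learn(total_timesteps=1_000)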
- class base_policy.RandomPolicy(observation_space: gym.spaces.space.Space, action_space: gym.spaces.space.Space)¶
Bases: base_policy.HardCodedPolicy
Returns random actions.
- training: bool¶
- class base_policy.SAC1024Policy(*args, **kwargs)¶
Bases: stable_baselines3.sac.policies.SACPolicy
Actor and value networks, each with two hidden layers of 1024 units.
This matches the implementation of SAC policies in the PEBBLE paper. See: https://arxiv.org/pdf/2106.05091.pdf and https://github.com/denisyarats/pytorch_sac/blob/master/config/agent/sac.yaml
Note: This differs from stable_baselines3 SACPolicy by having 1024 hidden units in each layer instead of the default value of 256.
- training: bool¶
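A usage sketch (the SAC setup and the "Pendulum-v1" environment are assumptions): SAC1024Policy slots in as the policy class of stable_baselines3's SAC, so the 1024-unit actor and value networks described above are trained instead of the 256-unit defaults.

import gym
from stable_baselines3 import SAC

from base_policy import SAC1024Policy

# SAC needs a continuous action space; Pendulum-v1 is an assumed example.
env = gym.make("Pendulum-v1")
model = SAC(SAC1024Policy, env, verbose=0)
model.learn(total_timesteps=1_000)  # short run, just to exercise the policy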
- class base_policy.ZeroPolicy(observation_space: gym.spaces.space.Space, action_space: gym.spaces.space.Space)¶
Bases: base_policy.HardCodedPolicy
Returns a constant zero action.
- training: bool¶
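A sketch of the hard-coded policies used as evaluation baselines (the "Pendulum-v1" environment is an assumption): both classes expose the standard BasePolicy.predict() interface, so they can be rolled out like any trained stable_baselines3 policy.

import gym

from base_policy import RandomPolicy, ZeroPolicy

env = gym.make("Pendulum-v1")
random_policy = RandomPolicy(env.observation_space, env.action_space)
zero_policy = ZeroPolicy(env.observation_space, env.action_space)

obs = env.reset()
action, _state = random_policy.predict(obs)  # sampled from the action space
action, _state = zero_policy.predict(obs)    # an all-zeros action
env.step(action)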