base_policy module¶
Custom policy classes and convenience methods. Code adapted from https://github.com/HumanCompatibleAI/imitation.git
- class base_policy.FeedForward32Policy(*args, **kwargs)¶
Bases: stable_baselines3.common.policies.ActorCriticPolicy

A feed forward policy network with two hidden layers of 32 units.
This matches the IRL policies in the original AIRL paper.
Note: This differs from the default stable_baselines3 ActorCriticPolicy in two ways: it uses 32 rather than 64 hidden units, and the policy and value networks share weights except at the final layer, where each has its own linear head.
- training: bool¶
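A minimal usage sketch: because FeedForward32Policy subclasses ActorCriticPolicy, it can be passed directly as the policy of an on-policy stable_baselines3 algorithm such as PPO. The environment name below is only a placeholder, not part of this module.

    from stable_baselines3 import PPO

    from base_policy import FeedForward32Policy

    # PPO accepts a policy class in place of the usual "MlpPolicy" string;
    # the 32-unit shared architecture is configured inside FeedForward32Policy itself.
    model = PPO(FeedForward32Policy, "CartPole-v1", verbose=0)
    model.learn(total_timesteps=10_000)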
- class base_policy.HardCodedPolicy(observation_space: gym.spaces.space.Space, action_space: gym.spaces.space.Space)¶
Bases: stable_baselines3.common.policies.BasePolicy, abc.ABC

Abstract class for hard-coded (non-trainable) policies.
- forward(*args)¶
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool¶
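A minimal subclassing sketch, following the pattern of RandomPolicy and ZeroPolicy below. It assumes, as in the upstream imitation implementation, that subclasses implement a per-observation _choose_action(obs) hook returning a NumPy action; ConstantPolicy is a hypothetical example class, not part of this module.

    import numpy as np

    from base_policy import HardCodedPolicy

    class ConstantPolicy(HardCodedPolicy):
        """Hypothetical hard-coded policy that always returns the same action."""

        def __init__(self, observation_space, action_space, constant_action):
            super().__init__(observation_space, action_space)
            self._constant_action = np.asarray(constant_action)

        def _choose_action(self, obs: np.ndarray) -> np.ndarray:
            # Ignore the observation and return the stored constant action.
            return self._constant_action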
- class base_policy.NormalizeFeaturesExtractor(observation_space: gym.spaces.space.Space, normalize_class: typing.Type[torch.nn.modules.module.Module] = <class 'networks.RunningNorm'>)¶
Bases: stable_baselines3.common.torch_layers.FlattenExtractor

Feature extractor that flattens then normalizes input.
- forward(observations: torch.Tensor) → torch.Tensor¶
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool¶
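A minimal sketch of plugging the extractor into a stable_baselines3 policy via policy_kwargs; the algorithm and environment below are placeholders only.

    from stable_baselines3 import PPO

    from base_policy import NormalizeFeaturesExtractor

    # Observations are flattened and normalized before reaching the policy
    # and value networks.
    model = PPO(
        "MlpPolicy",
        "Pendulum-v1",
        policy_kwargs=dict(features_extractor_class=NormalizeFeaturesExtractor),
    )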
- class base_policy.RandomPolicy(observation_space: gym.spaces.space.Space, action_space: gym.spaces.space.Space)¶
Bases: base_policy.HardCodedPolicy

Returns random actions.
- training: bool¶
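A minimal sketch of querying the policy through the standard predict interface inherited from BasePolicy; the environment is only a placeholder, and the reset call assumes the older gym API that returns the observation alone.

    import gym

    from base_policy import RandomPolicy

    env = gym.make("CartPole-v1")
    policy = RandomPolicy(env.observation_space, env.action_space)

    obs = env.reset()
    # predict returns (action, state); the state is unused for this policy.
    action, _state = policy.predict(obs)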
- class base_policy.SAC1024Policy(*args, **kwargs)¶
Bases: stable_baselines3.sac.policies.SACPolicy

Actor and value networks, each with two hidden layers of 1024 units.
This matches the implementation of SAC policies in the PEBBLE paper. See: https://arxiv.org/pdf/2106.05091.pdf https://github.com/denisyarats/pytorch_sac/blob/master/config/agent/sac.yaml
Note: This differs from stable_baselines3 SACPolicy by having 1024 hidden units in each layer instead of the default value of 256.
- training: bool¶
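A minimal usage sketch: like FeedForward32Policy, this class can be passed as the policy of the matching off-policy algorithm, here stable_baselines3 SAC. The environment name is only a placeholder.

    from stable_baselines3 import SAC

    from base_policy import SAC1024Policy

    # SAC accepts a policy class in place of the usual "MlpPolicy" string;
    # the 1024-unit layers are configured inside SAC1024Policy itself.
    model = SAC(SAC1024Policy, "Pendulum-v1")
    model.learn(total_timesteps=10_000)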
- class base_policy.ZeroPolicy(observation_space: gym.spaces.space.Space, action_space: gym.spaces.space.Space)¶
Bases: base_policy.HardCodedPolicy

Returns constant zero action.
- training: bool¶