Generative Adversarial Imitation Learning (GAIL)
Classes and Functions
Generative Adversarial Imitation Learning (GAIL). Code adapted from https://github.com/HumanCompatibleAI/imitation.git
- class gail.GAIL(*, demonstrations: Union[Iterable[types_unique.Trajectory], Iterable[Mapping[str, Union[numpy.ndarray, torch.Tensor]]], TransitionKind], demo_batch_size: int, venv: stable_baselines3.common.vec_env.base_vec_env.VecEnv, gen_algo: stable_baselines3.common.base_class.BaseAlgorithm, reward_net: reward_nets.RewardNet, **kwargs)
Bases: common.AdversarialTrainer
Generative Adversarial Imitation Learning (GAIL). See the usage sketch after the member list below.
- logits_gen_is_high(state: torch.Tensor, action: torch.Tensor, next_state: torch.Tensor, done: torch.Tensor, log_policy_act_prob: Optional[torch.Tensor] = None) → torch.Tensor
Compute the discriminator’s logits for each state-action sample. As the name indicates, a high value means the sample is predicted to come from the generator rather than the expert.
- property reward_test: reward_nets.RewardNet
Reward used to train policy at “test” time after adversarial training.
- property reward_train: reward_nets.RewardNet
Reward used to train generator policy.
- venv: stable_baselines3.common.vec_env.base_vec_env.VecEnv
The original vectorized environment.
- venv_train: stable_baselines3.common.vec_env.base_vec_env.VecEnv
Like self.venv, but wrapped with the train-time reward unless in debug mode.
If debug_use_ground_truth=True was passed to the initializer, then self.venv_train is the same as self.venv.
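The following is a minimal usage sketch for the class above, not a verbatim recipe: it assumes the adapted code keeps the upstream imitation-style API, so the reward_nets.BasicRewardNet constructor, the transition-mapping keys ("obs", "acts", "next_obs", "dones"), and the train(total_timesteps=...) method inherited from common.AdversarialTrainer are assumptions rather than part of the signatures documented here. The random placeholder transitions only stand in for real expert demonstrations.

```python
import gym
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv

import gail          # the module documented on this page
import reward_nets   # assumed to mirror imitation.rewards.reward_nets

# Vectorized environment the generator policy will act in.
venv = DummyVecEnv([lambda: gym.make("CartPole-v1")])

# Placeholder "expert" data in the documented mapping form
# (Iterable[Mapping[str, np.ndarray]]); replace with real expert transitions.
n = 64
obs = np.array([venv.observation_space.sample() for _ in range(n)])
acts = np.array([venv.action_space.sample() for _ in range(n)])
demonstrations = [{
    "obs": obs,
    "acts": acts,
    "next_obs": obs,                      # dummy next observations
    "dones": np.zeros(n, dtype=bool),
}]

# Discriminator/reward network and generator policy (assumed constructors).
reward_net = reward_nets.BasicRewardNet(venv.observation_space, venv.action_space)
gen_algo = PPO("MlpPolicy", venv, verbose=0)

trainer = gail.GAIL(
    demonstrations=demonstrations,
    demo_batch_size=32,
    venv=venv,
    gen_algo=gen_algo,
    reward_net=reward_net,
)

# Alternate discriminator and generator updates (assumed AdversarialTrainer API).
trainer.train(total_timesteps=10_000)
```

After construction, trainer.reward_train and trainer.reward_test expose the reward networks described above, and trainer.venv_train is venv wrapped with the train-time reward.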
- class gail.LogSigmoidRewardNet(base: reward_nets.RewardNet)
Bases: reward_nets.RewardNet
Wrapper for a reward network that takes the log sigmoid of the wrapped network’s output; a numeric sketch of the transform follows below.
- forward(state: torch.Tensor, action: torch.Tensor, next_state: torch.Tensor, done: torch.Tensor) → torch.Tensor
Computes the negative log sigmoid of the base reward network’s output.
- training: bool
Whether the module is in training mode (standard torch.nn.Module attribute).
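As a purely illustrative aside (not the wrapper’s actual implementation), the sketch below applies torch.nn.functional.logsigmoid to a few raw discriminator logits, showing the shape of the transform this wrapper builds on: the log-sigmoid output is always non-positive and approaches 0 as the logit grows.

```python
import torch
import torch.nn.functional as F

# Raw outputs that a base reward/discriminator network might produce.
logits = torch.tensor([-2.0, 0.0, 2.0])

# log sigmoid(x) = -softplus(-x): always <= 0, approaching 0 as x -> +inf.
print(F.logsigmoid(logits))   # tensor([-2.1269, -0.6931, -0.1269])
print(-F.logsigmoid(logits))  # the negated form mentioned in forward()'s docstring
```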