Generative Adversarial Imitation Learning (GAIL)

Classes and Functions

Generative Adversarial Imitation Learning (GAIL). Code adapted from https://github.com/HumanCompatibleAI/imitation.git

class gail.GAIL(*, demonstrations: Union[Iterable[types_unique.Trajectory], Iterable[Mapping[str, Union[numpy.ndarray, torch.Tensor]]], TransitionKind], demo_batch_size: int, venv: stable_baselines3.common.vec_env.base_vec_env.VecEnv, gen_algo: stable_baselines3.common.base_class.BaseAlgorithm, reward_net: reward_nets.RewardNet, **kwargs)

Bases: common.AdversarialTrainer

Generative Adversarial Imitation Learning (GAIL).
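
A minimal construction sketch is shown below. It assumes the adapted modules mirror the upstream imitation API: reward_nets.BasicRewardNet, the demonstration placeholder, and the train(total_timesteps=...) call inherited from common.AdversarialTrainer are assumptions, not guaranteed by this documentation.

    import gym
    from stable_baselines3 import PPO
    from stable_baselines3.common.vec_env import DummyVecEnv

    import gail
    import reward_nets

    # Vectorized environment shared by the generator and the discriminator.
    venv = DummyVecEnv([lambda: gym.make("CartPole-v1")] * 4)

    # Generator policy trained against the learned reward.
    gen_algo = PPO("MlpPolicy", venv, verbose=0)

    # Discriminator/reward network; BasicRewardNet is an assumed helper class.
    reward_net = reward_nets.BasicRewardNet(venv.observation_space, venv.action_space)

    expert_trajectories = ...  # placeholder: expert demonstrations loaded elsewhere

    trainer = gail.GAIL(
        demonstrations=expert_trajectories,
        demo_batch_size=64,
        venv=venv,
        gen_algo=gen_algo,
        reward_net=reward_net,
    )
    trainer.train(total_timesteps=100_000)  # assumed to be provided by common.AdversarialTrainer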

logits_gen_is_high(state: torch.Tensor, action: torch.Tensor, next_state: torch.Tensor, done: torch.Tensor, log_policy_act_prob: Optional[torch.Tensor] = None) → torch.Tensor

Compute the discriminator’s logits for each state-action sample. A high value corresponds to predicting generator, and a low value corresponds to predicting expert.
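
The sketch below illustrates how these logits could feed a standard binary cross-entropy discriminator objective; the label convention and loss form are assumptions about the adversarial objective, not a copy of common.AdversarialTrainer's internals.

    import torch
    import torch.nn.functional as F

    def discriminator_loss(trainer, gen_batch, expert_batch):
        # The method name implies logits are high for generator samples,
        # so generator samples get label 1 and expert samples label 0.
        gen_logits = trainer.logits_gen_is_high(
            gen_batch["state"], gen_batch["action"],
            gen_batch["next_state"], gen_batch["done"],
        )
        expert_logits = trainer.logits_gen_is_high(
            expert_batch["state"], expert_batch["action"],
            expert_batch["next_state"], expert_batch["done"],
        )
        logits = torch.cat([gen_logits, expert_logits])
        labels = torch.cat(
            [torch.ones_like(gen_logits), torch.zeros_like(expert_logits)]
        )
        return F.binary_cross_entropy_with_logits(logits, labels)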

property reward_test: reward_nets.RewardNet

Reward used to train policy at “test” time after adversarial training.

property reward_train: reward_nets.RewardNet

Reward used to train generator policy.
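
As a usage sketch, both properties can be queried like any reward network; the no-gradient predict(state, action, next_state, done) helper and the array shapes below are assumptions modeled on the upstream imitation RewardNet API.

    import numpy as np

    # Placeholder batch of transitions (shapes assume CartPole-style spaces).
    obs = np.zeros((32, 4), dtype=np.float32)
    acts = np.zeros((32,), dtype=np.int64)
    next_obs = np.zeros((32, 4), dtype=np.float32)
    dones = np.zeros((32,), dtype=bool)

    # Shaped reward used to train the generator during adversarial training.
    train_rewards = trainer.reward_train.predict(obs, acts, next_obs, dones)

    # Reward intended for evaluation or transfer after adversarial training ends.
    test_rewards = trainer.reward_test.predict(obs, acts, next_obs, dones)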

venv: stable_baselines3.common.vec_env.base_vec_env.VecEnv

The original vectorized environment.

venv_train: stable_baselines3.common.vec_env.base_vec_env.VecEnv

Like self.venv, but wrapped with train reward unless in debug mode.

If debug_use_ground_truth=True was passed into the initializer then self.venv_train is the same as self.venv.
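
The relationship between venv and venv_train can be sketched as follows; RewardVecEnvWrapper and the predict-based reward_fn are assumptions modeled on the upstream imitation package, and the real wrapping happens inside the trainer's initializer rather than in user code.

    from reward_wrapper import RewardVecEnvWrapper  # assumed adapted module

    if debug_use_ground_truth:
        # Debug mode: the generator sees the environment's ground-truth reward.
        venv_train = venv
    else:
        # Normal mode: environment rewards are replaced by the learned train reward.
        venv_train = RewardVecEnvWrapper(venv, reward_fn=trainer.reward_train.predict)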

class gail.LogSigmoidRewardNet(base: reward_nets.RewardNet)

Bases: reward_nets.RewardNet

Wrapper for reward network that takes log sigmoid of wrapped network.

forward(state: torch.Tensor, action: torch.Tensor, next_state: torch.Tensor, done: torch.Tensor) → torch.Tensor

Computes negative log sigmoid of base reward network.
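
A sketch of the forward pass follows. Note that the class docstring describes a log sigmoid while this method describes a negative log sigmoid; the version below follows the method description and should be read as an assumption about the adapted code rather than a verified copy of it.

    import torch
    import torch.nn.functional as F

    def forward(self, state, action, next_state, done) -> torch.Tensor:
        # Treat the base network's output as a logit and transform it with
        # -log(sigmoid(.)), yielding a non-negative reward signal.
        logits = self.base.forward(state, action, next_state, done)
        return -F.logsigmoid(logits)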

training: bool