Adversarial IRL

Classes and Functions

Adversarial Inverse Reinforcement Learning (AIRL). Code adapted from https://github.com/HumanCompatibleAI/imitation.git

class airl.AIRL(*, demonstrations: Union[Iterable[types_unique.Trajectory], Iterable[Mapping[str, Union[numpy.ndarray, torch.Tensor]]], TransitionKind], demo_batch_size: int, venv: stable_baselines3.common.vec_env.base_vec_env.VecEnv, gen_algo: stable_baselines3.common.base_class.BaseAlgorithm, reward_net: reward_nets.RewardNet, **kwargs)

Bases: common.AdversarialTrainer

Adversarial Inverse Reinforcement Learning (AIRL).
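A typical training run instantiates a vectorized environment, a generator algorithm, and a reward network, then hands them to AIRL. The sketch below is a minimal usage example assuming the upstream imitation import paths (imitation.algorithms.adversarial.airl, BasicShapedRewardNet, RunningNorm) and a placeholder `demonstrations` variable holding expert data; the adapted module layout in this codebase may differ.

```python
import gym
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv

from imitation.algorithms.adversarial.airl import AIRL
from imitation.rewards.reward_nets import BasicShapedRewardNet
from imitation.util.networks import RunningNorm

# Vectorized environment shared by the generator and the discriminator.
venv = DummyVecEnv([lambda: gym.make("CartPole-v1")] * 4)

# Generator policy, trained against the learned reward.
gen_algo = PPO("MlpPolicy", venv, verbose=0)

# Shaped reward network; the shaping term is stripped off for reward_test.
reward_net = BasicShapedRewardNet(
    venv.observation_space,
    venv.action_space,
    normalize_input_layer=RunningNorm,
)

# `demonstrations` is assumed to be expert data collected elsewhere,
# e.g. an iterable of Trajectory objects or a transitions mapping.
trainer = AIRL(
    demonstrations=demonstrations,
    demo_batch_size=64,
    venv=venv,
    gen_algo=gen_algo,
    reward_net=reward_net,
)
trainer.train(total_timesteps=100_000)
```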

logits_gen_is_high(state: torch.Tensor, action: torch.Tensor, next_state: torch.Tensor, done: torch.Tensor, log_policy_act_prob: torch.Tensor) → torch.Tensor

Compute the discriminator’s logits for each state-action sample. A high value corresponds to predicting that the sample came from the generator, and a low value corresponds to predicting the expert.
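In the standard AIRL parameterization the discriminator is D(s, a) = exp(f(s, a)) / (exp(f(s, a)) + π(a|s)), with high values for the expert, so the "generator is high" logit reduces to log π(a|s) − f(s, a). The following is a functional sketch of that computation (the bound method uses the trainer's own reward network rather than an explicit argument):

```python
import torch as th
from imitation.rewards.reward_nets import RewardNet

def logits_gen_is_high_sketch(
    reward_net: RewardNet,
    state: th.Tensor,
    action: th.Tensor,
    next_state: th.Tensor,
    done: th.Tensor,
    log_policy_act_prob: th.Tensor,
) -> th.Tensor:
    # f(s, a, s') from the (possibly shaped) reward network used for training.
    reward_output = reward_net(state, action, next_state, done)
    # D = exp(f) / (exp(f) + pi(a|s)) is high for expert samples, so the
    # "generator is high" logit is log pi(a|s) - f.
    return log_policy_act_prob - reward_output
```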

property reward_test: reward_nets.RewardNet

Returns the unshaped version of the reward network, used for testing.

property reward_train: reward_nets.RewardNet

Reward network used to train the generator policy.
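The two properties differ only in shaping: reward_train may carry a potential-based shaping term that stabilizes discriminator training, while reward_test strips any such wrapping so that only the base reward is transferred to new settings. A sketch of the unwrapping, assuming the upstream reward_nets module with a ShapedRewardNet exposing a base attribute:

```python
from imitation.rewards import reward_nets

def unshaped(net: reward_nets.RewardNet) -> reward_nets.RewardNet:
    # Potential-based shaping cancels over complete episodes, so it only helps
    # during training; at test time we keep just the base reward network.
    while isinstance(net, reward_nets.ShapedRewardNet):
        net = net.base
    return net
```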

venv: stable_baselines3.common.vec_env.base_vec_env.VecEnv

The original vectorized environment.

venv_train: stable_baselines3.common.vec_env.base_vec_env.VecEnv

Like self.venv, but wrapped with the train reward unless in debug mode.

If debug_use_ground_truth=True was passed into the initializer then self.venv_train is the same as self.venv.
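A sketch of how venv_train could be derived from venv, assuming the upstream RewardVecEnvWrapper and RewardNet.predict_processed; the adapted code may assemble this differently inside the trainer's initializer:

```python
from stable_baselines3.common.vec_env import VecEnv
from imitation.rewards.reward_nets import RewardNet
from imitation.rewards.reward_wrapper import RewardVecEnvWrapper

def make_venv_train(
    venv: VecEnv,
    reward_net: RewardNet,
    debug_use_ground_truth: bool = False,
) -> VecEnv:
    if debug_use_ground_truth:
        # Debug mode: train the generator on the environment's own reward.
        return venv
    # Normal mode: replace the environment reward with the learned train reward.
    return RewardVecEnvWrapper(venv, reward_fn=reward_net.predict_processed)
```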