Adversarial IRL
Classes and Functions
Adversarial Inverse Reinforcement Learning (AIRL). Code adapted from https://github.com/HumanCompatibleAI/imitation.git
- class airl.AIRL(*, demonstrations: Union[Iterable[types_unique.Trajectory], Iterable[Mapping[str, Union[numpy.ndarray, torch.Tensor]]], TransitionKind], demo_batch_size: int, venv: stable_baselines3.common.vec_env.base_vec_env.VecEnv, gen_algo: stable_baselines3.common.base_class.BaseAlgorithm, reward_net: reward_nets.RewardNet, **kwargs)
Bases: common.AdversarialTrainer
Adversarial Inverse Reinforcement Learning (AIRL).
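- logits_gen_is_high(state: torch.Tensor, action: torch.Tensor, next_state: torch.Tensor, done: torch.Tensor, log_policy_act_prob: torch.Tensor) -> torch.Tensor
Compute the discriminator’s logits for each state-action sample. As the name indicates, a high value corresponds to predicting that the sample came from the generator, and a low value to predicting that it came from the expert.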
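A minimal sketch of the sign convention, assuming the standard AIRL discriminator D(s, a, s') = exp(f(s, a, s')) / (exp(f(s, a, s')) + pi(a | s)), where f is the output of the training reward network and pi is the generator policy. Under that assumption the probability that a sample came from the generator is sigmoid(log pi(a | s) - f(s, a, s')), so a logit that is high for generator samples is log_policy_act_prob minus the reward output. The plain-Python functions below (logits_gen_is_high over lists of floats, and prob_generator) are hypothetical illustrations, not the torch-based signature of this class.

```python
import math
from typing import List, Sequence


def logits_gen_is_high(
    reward_outputs: Sequence[float],
    log_policy_act_probs: Sequence[float],
) -> List[float]:
    """Hypothetical sketch of AIRL discriminator logits (generator = high).

    reward_outputs: f(s, a, s') from the (shaped) reward network, one per sample.
    log_policy_act_probs: log pi(a | s) under the current generator policy.
    """
    if len(reward_outputs) != len(log_policy_act_probs):
        raise ValueError("expected one reward and one log-prob per sample")
    # Assumed AIRL form: P[expert] = exp(f) / (exp(f) + pi),
    # hence logit(P[generator]) = log pi - f.
    return [lp - f for f, lp in zip(reward_outputs, log_policy_act_probs)]


def prob_generator(logit: float) -> float:
    """Numerically stable sigmoid: probability the sample came from the generator."""
    if logit >= 0:
        return 1.0 / (1.0 + math.exp(-logit))
    z = math.exp(logit)
    return z / (1.0 + z)


if __name__ == "__main__":
    # High learned reward relative to the policy's log-prob looks expert-like
    # (negative logit); low learned reward looks generator-like (positive logit).
    logits = logits_gen_is_high([2.0, -1.0], [math.log(0.5), math.log(0.5)])
    print([round(x, 3) for x in logits])
    print([round(prob_generator(x), 3) for x in logits])
```

Training the discriminator with binary cross-entropy on these logits (label 1 for generator samples, 0 for expert samples) is the usual pairing for this convention.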
- property reward_test: reward_nets.RewardNet
Returns the unshaped version of the reward network, used for testing.
- property reward_train: reward_nets.RewardNet
Reward used to train the generator policy.
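In AIRL the trained reward is typically a shaped reward f(s, a, s') = g(s, a) + gamma * h(s') - h(s), where g is the base (unshaped) reward and h is a learned potential. It is an assumption here that reward_train plays the role of f and reward_test the role of g, consistent with the descriptions above. The sketch uses hypothetical toy choices of g (base_reward) and h (potential) to show how the two relate: with h(s') set to zero on terminal transitions, the shaping terms telescope, so the discounted return under f equals the return under g minus h(s_0).

```python
from typing import Sequence


def base_reward(s: int, a: int) -> float:
    """Hypothetical unshaped reward g(s, a): 1 if the action matches the state's parity."""
    return float(a == s % 2)


def potential(s: int) -> float:
    """Hypothetical learned potential h(s)."""
    return 0.5 * s + 1.0


def shaped_reward(s: int, a: int, s_next: int, done: bool, gamma: float) -> float:
    """f(s, a, s') = g(s, a) + gamma * h(s') * (1 - done) - h(s)."""
    next_potential = 0.0 if done else potential(s_next)
    return base_reward(s, a) + gamma * next_potential - potential(s)


def discounted_return(rewards: Sequence[float], gamma: float) -> float:
    return sum(gamma**t * r for t, r in enumerate(rewards))


if __name__ == "__main__":
    gamma = 0.9
    states = [0, 1, 2, 3]  # s_0 .. s_3; the last transition is terminal
    actions = [0, 1, 1]
    dones = [False, False, True]
    transitions = list(zip(states, actions, states[1:], dones))

    train_rewards = [shaped_reward(s, a, s2, d, gamma) for s, a, s2, d in transitions]
    test_rewards = [base_reward(s, a) for s, a, _, _ in transitions]

    # Shaping telescopes: shaped return == unshaped return - h(s_0).
    print(discounted_return(train_rewards, gamma))
    print(discounted_return(test_rewards, gamma) - potential(states[0]))
```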
- venv: stable_baselines3.common.vec_env.base_vec_env.VecEnv
The original vectorized environment.
- venv_train: stable_baselines3.common.vec_env.base_vec_env.VecEnv
Like self.venv, but wrapped with the training reward unless in debug mode.
If debug_use_ground_truth=True was passed to the initializer, then self.venv_train is the same as self.venv.
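A minimal, self-contained sketch of the behaviour described for venv_train: the training environment keeps the original dynamics but reports rewards computed by a learned reward function, unless debug_use_ground_truth is set, in which case the original environment object is returned unchanged. ToyVecEnv, RewardReplacingWrapper and make_venv_train are hypothetical stand-ins written for illustration; they are not the stable_baselines3 VecEnv API, and they ignore details such as auto-reset on episode end.

```python
from dataclasses import dataclass, field
from typing import Callable, List, Tuple

# (next_obs, rewards, dones) for each sub-environment
Step = Tuple[List[float], List[float], List[bool]]
RewardFn = Callable[[float, float, float, bool], float]


@dataclass
class ToyVecEnv:
    """Hypothetical vectorized env: obs is a per-env counter, true reward = action."""
    num_envs: int = 2
    obs: List[float] = field(default_factory=list)

    def reset(self) -> List[float]:
        self.obs = [0.0] * self.num_envs
        return list(self.obs)

    def step(self, actions: List[float]) -> Step:
        self.obs = [o + 1.0 for o in self.obs]
        rewards = [float(a) for a in actions]  # ground-truth reward
        dones = [False] * self.num_envs  # toy env never terminates
        return list(self.obs), rewards, dones


class RewardReplacingWrapper:
    """Hypothetical wrapper: same dynamics, rewards recomputed by reward_fn."""

    def __init__(self, venv: ToyVecEnv, reward_fn: RewardFn) -> None:
        self.venv = venv
        self.reward_fn = reward_fn
        self._last_obs: List[float] = []

    def reset(self) -> List[float]:
        self._last_obs = self.venv.reset()
        return list(self._last_obs)

    def step(self, actions: List[float]) -> Step:
        next_obs, _true_rewards, dones = self.venv.step(actions)
        # Discard the environment's reward and score each transition instead.
        rewards = [
            self.reward_fn(o, a, o2, d)
            for o, a, o2, d in zip(self._last_obs, actions, next_obs, dones)
        ]
        self._last_obs = next_obs
        return next_obs, rewards, dones


def make_venv_train(venv: ToyVecEnv, reward_fn: RewardFn, debug_use_ground_truth: bool = False):
    if debug_use_ground_truth:
        return venv  # same object, so ground-truth rewards are seen
    return RewardReplacingWrapper(venv, reward_fn)
```

Returning the very same object in debug mode mirrors the statement that self.venv_train is then the same as self.venv, so the generator trains on ground-truth rewards.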