common module¶
Core code for adversarial imitation learning, shared between GAIL and AIRL. Code adapted from https://github.com/HumanCompatibleAI/imitation.git.
- class common.AdversarialTrainer(*, demonstrations: typing.Union[typing.Iterable[types_unique.Trajectory], typing.Iterable[typing.Mapping[str, typing.Union[numpy.ndarray, torch.Tensor]]], types_unique.TransitionKind], demo_batch_size: int, venv: stable_baselines3.common.vec_env.base_vec_env.VecEnv, gen_algo: stable_baselines3.common.base_class.BaseAlgorithm, reward_net: reward_nets.RewardNet, n_disc_updates_per_round: int = 2, log_dir: str = 'output/', disc_opt_cls: typing.Type[torch.optim.optimizer.Optimizer] = <class 'torch.optim.adam.Adam'>, disc_opt_kwargs: typing.Optional[typing.Mapping] = None, gen_train_timesteps: typing.Optional[int] = None, gen_replay_buffer_capacity: typing.Optional[int] = None, custom_logger: typing.Optional[logger.HierarchicalLogger] = None, init_tensorboard: bool = False, init_tensorboard_graph: bool = False, debug_use_ground_truth: bool = False, allow_variable_horizon: bool = True)¶
Bases: base.DemonstrationAlgorithm[types_unique.Transitions]
Base class for adversarial imitation learning algorithms like GAIL and AIRL.
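A minimal construction sketch, assuming a concrete subclass named GAIL with this constructor signature (AdversarialTrainer itself is abstract); expert_transitions and reward_net are placeholders for demonstration data and a reward_nets.RewardNet built elsewhere.

```python
import gym
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv

venv = DummyVecEnv([lambda: gym.make("CartPole-v1")])

trainer = GAIL(                              # hypothetical concrete subclass of AdversarialTrainer
    demonstrations=expert_transitions,       # expert Transitions / Trajectory data built elsewhere
    demo_batch_size=64,
    venv=venv,
    gen_algo=PPO("MlpPolicy", venv, verbose=0),  # generator policy to be trained
    reward_net=reward_net,                   # a reward_nets.RewardNet instance
    n_disc_updates_per_round=2,
    allow_variable_horizon=True,
)
```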
- abstract logits_gen_is_high(state: torch.Tensor, action: torch.Tensor, next_state: torch.Tensor, done: torch.Tensor, log_policy_act_prob: Optional[torch.Tensor] = None) torch.Tensor ¶
Compute the discriminator’s logits for each state-action sample.
A high value corresponds to predicting generator, and a low value corresponds to predicting expert.
- Args:
- state: state at time t, of shape (batch_size,) + state_shape.
- action: action taken at time t, of shape (batch_size,) + action_shape.
- next_state: state at time t+1, of shape (batch_size,) + state_shape.
- done: binary episode completion flag after action at time t, of shape (batch_size,).
- log_policy_act_prob: log probability of generator policy taking action at time t.
- Returns:
Discriminator logits of shape (batch_size,). A high output indicates a generator-like transition.
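A sketch of how a subclass might satisfy this contract, using the AIRL-style discriminator D = exp(f) / (exp(f) + π), for which the "generator is high" logit is log π(a|s) − f(s, a, s'). The self._reward_net attribute is an assumption about the subclass, not part of this base class.

```python
import torch as th
from typing import Optional

def logits_gen_is_high(
    self,
    state: th.Tensor,
    action: th.Tensor,
    next_state: th.Tensor,
    done: th.Tensor,
    log_policy_act_prob: Optional[th.Tensor] = None,
) -> th.Tensor:
    # AIRL-style discriminator needs the generator's action log-probabilities.
    if log_policy_act_prob is None:
        raise TypeError("This discriminator requires log_policy_act_prob.")
    # Assumed reward network f(s, a, s', done) returning shape (batch_size,).
    reward_output = self._reward_net(state, action, next_state, done)
    # High value = predicted generator, low value = predicted expert.
    return log_policy_act_prob - reward_output
```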
- property policy: stable_baselines3.common.policies.BasePolicy¶
Returns a policy imitating the demonstration data.
- abstract property reward_test: reward_nets.RewardNet¶
Reward used to train policy at “test” time after adversarial training.
- abstract property reward_train: reward_nets.RewardNet¶
Reward used to train generator policy.
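A sketch of reusing the learned reward after adversarial training; calling the RewardNet as a torch module with (state, action, next_state, done) batches mirrors the logits_gen_is_high argument contract above, though a concrete RewardNet may require its own preprocessing.

```python
import torch as th

reward_net = trainer.reward_test            # learned reward for "test"-time reuse
obs = th.zeros((32, 4))                     # placeholder batch, (batch_size,) + state_shape
acts = th.zeros((32, 1))                    # placeholder actions
next_obs = th.zeros((32, 4))
dones = th.zeros((32,))
with th.no_grad():
    rewards = reward_net(obs, acts, next_obs, dones)   # shape (batch_size,)
```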
- set_demonstrations(demonstrations: Union[Iterable[types_unique.Trajectory], Iterable[Mapping[str, Union[numpy.ndarray, torch.Tensor]]], TransitionKind]) None ¶
Sets the demonstration data.
Changing the demonstration data on-demand can be useful for interactive algorithms like DAgger.
- Args:
- demonstrations: Either a Torch DataLoader or any other iterator that
yields dictionaries containing “obs” and “acts” Tensors or NumPy arrays, a TransitionKind instance, or a Sequence of Trajectory objects.
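A sketch of swapping demonstrations on an existing trainer; the dictionary form shown here mirrors the Transitions fields (everything except “infos”) that train_disc below expects, and new_trajectories stands in for a Sequence of Trajectory objects collected elsewhere.

```python
import numpy as np

demo_batches = [
    {
        "obs": np.zeros((64, 4), dtype=np.float32),       # placeholder observations
        "acts": np.zeros((64, 1), dtype=np.float32),      # placeholder actions
        "next_obs": np.zeros((64, 4), dtype=np.float32),
        "dones": np.zeros((64,), dtype=bool),
    }
]
trainer.set_demonstrations(iter(demo_batches))

# Equivalently, pass trajectories directly:
trainer.set_demonstrations(new_trajectories)   # hypothetical Sequence[Trajectory]
```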
- train(total_timesteps: int, callback: Optional[Callable[[int], None]] = None) None ¶
Alternates between training the generator and discriminator.
Every “round” consists of a call to train_gen(self.gen_train_timesteps), a call to train_disc, and finally a call to callback(round).
Training ends once an additional “round” would cause the number of transitions sampled from the environment to exceed total_timesteps.
- Args:
- total_timesteps: An upper bound on the number of transitions to sample
from the environment during training.
- callback: A function called at the end of every round which takes in a
single argument, the round number. Round numbers are in range(total_timesteps // self.gen_train_timesteps).
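A sketch of a full adversarial run with a per-round callback; trainer is an already-constructed concrete subclass (see the constructor sketch above), and the callback receives round numbers from range(total_timesteps // self.gen_train_timesteps).

```python
def log_round(round_num: int) -> None:
    # Called once at the end of every generator/discriminator round.
    print(f"finished round {round_num}")

trainer.train(total_timesteps=500_000, callback=log_round)
```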
- train_disc(*, expert_samples: Optional[Mapping] = None, gen_samples: Optional[Mapping] = None) Optional[Mapping[str, float]] ¶
Perform a single discriminator update, optionally using provided samples.
- Args:
- expert_samples: Transition samples from the expert in dictionary form.
If provided, must contain keys corresponding to every field of the Transitions dataclass except “infos”. All corresponding values can be either NumPy arrays or Tensors. Extra keys are ignored. Must contain self.demo_batch_size samples. If this argument is not provided, then self.demo_batch_size expert samples from self.demo_data_loader are used by default.
- gen_samples: Transition samples from the generator policy in the same dictionary
form as expert_samples. If provided, must contain exactly self.demo_batch_size samples. If not provided, then take len(expert_samples) samples from the generator replay buffer.
- Returns:
Statistics for discriminator (e.g. loss, accuracy).
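A sketch of a manual discriminator update with explicitly provided samples; both dictionaries follow the Transitions-minus-“infos” format described above, and the placeholder arrays assume a 4-dimensional observation and 1-dimensional action space.

```python
import numpy as np

def placeholder_batch(n: int) -> dict:
    # Shapes are illustrative; real samples come from expert data / rollouts.
    return {
        "obs": np.zeros((n, 4), dtype=np.float32),
        "acts": np.zeros((n, 1), dtype=np.float32),
        "next_obs": np.zeros((n, 4), dtype=np.float32),
        "dones": np.zeros((n,), dtype=bool),
    }

n = trainer.demo_batch_size
stats = trainer.train_disc(
    expert_samples=placeholder_batch(n),
    gen_samples=placeholder_batch(n),
)
print(stats)   # mapping of discriminator statistic names (loss, accuracy, ...) to floats
```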
- train_gen(total_timesteps: Optional[int] = None, learn_kwargs: Optional[Mapping] = None) None ¶
Trains the generator to maximize the discriminator loss.
After training ends, this populates the generator replay buffer (used in discriminator training) with self.disc_batch_size transitions.
- Args:
- total_timesteps: The number of transitions to sample from
self.venv_train during training. By default, self.gen_train_timesteps.
- learn_kwargs: kwargs for the Stable Baselines RLModel.learn()
method.
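A sketch of running a standalone generator phase; the tb_log_name entry only illustrates the kind of keyword argument that learn_kwargs forwards to gen_algo.learn().

```python
trainer.train_gen(
    total_timesteps=10_000,                      # defaults to self.gen_train_timesteps if omitted
    learn_kwargs={"tb_log_name": "gen_policy"},  # forwarded to the Stable Baselines learn() call
)
```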
- venv: stable_baselines3.common.vec_env.base_vec_env.VecEnv¶
The original vectorized environment.
- venv_train: stable_baselines3.common.vec_env.base_vec_env.VecEnv¶
Like self.venv, but wrapped with the train reward unless in debug mode.
If debug_use_ground_truth=True was passed into the initializer then self.venv_train is the same as self.venv.
- common.compute_train_stats(disc_logits_gen_is_high: torch.Tensor, labels_gen_is_one: torch.Tensor, disc_loss: torch.Tensor) Mapping[str, float] ¶
Train statistics for GAIL/AIRL discriminator.
- Args:
- disc_logits_gen_is_high: discriminator logits produced by
AdversarialTrainer.logits_gen_is_high.
- labels_gen_is_one: integer labels describing whether logit was for an
expert (0) or generator (1) sample.
- disc_loss: final discriminator loss.
- Returns:
A mapping from statistic names to float values.
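A hypothetical call illustrating the expected shapes and label convention (0 = expert, 1 = generator); in practice the logits come from logits_gen_is_high and the loss from the discriminator update.

```python
import torch as th
# `common` refers to this module; the exact import path depends on how the package is laid out.

logits = th.randn(64)                                   # disc_logits_gen_is_high, shape (batch_size,)
labels = th.cat([th.zeros(32), th.ones(32)]).long()     # 0 = expert, 1 = generator
loss = th.nn.functional.binary_cross_entropy_with_logits(logits, labels.float())

stats = common.compute_train_stats(logits, labels, loss)
for name, value in stats.items():
    print(f"{name}: {value:.3f}")
```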