common module

Core code for adversarial imitation learning, shared between GAIL and AIRL. Code adapted from https://github.com/HumanCompatibleAI/imitation.git

class common.AdversarialTrainer(*, demonstrations: typing.Union[typing.Iterable[types_unique.Trajectory], typing.Iterable[typing.Mapping[str, typing.Union[numpy.ndarray, torch.Tensor]]], types_unique.TransitionKind], demo_batch_size: int, venv: stable_baselines3.common.vec_env.base_vec_env.VecEnv, gen_algo: stable_baselines3.common.base_class.BaseAlgorithm, reward_net: reward_nets.RewardNet, n_disc_updates_per_round: int = 2, log_dir: str = 'output/', disc_opt_cls: typing.Type[torch.optim.optimizer.Optimizer] = <class 'torch.optim.adam.Adam'>, disc_opt_kwargs: typing.Optional[typing.Mapping] = None, gen_train_timesteps: typing.Optional[int] = None, gen_replay_buffer_capacity: typing.Optional[int] = None, custom_logger: typing.Optional[logger.HierarchicalLogger] = None, init_tensorboard: bool = False, init_tensorboard_graph: bool = False, debug_use_ground_truth: bool = False, allow_variable_horizon: bool = True)

Bases: base.DemonstrationAlgorithm[types_unique.Transitions]

Base class for adversarial imitation learning algorithms like GAIL and AIRL.
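
A minimal usage sketch, assuming a concrete subclass such as GAIL, an existing reward_nets.RewardNet instance, and a list of expert trajectories; the environment id, helper names, and hyperparameters below are illustrative:

    # Hedged sketch: GAIL is assumed to be a concrete subclass of
    # AdversarialTrainer; expert_trajectories (Iterable[types_unique.Trajectory])
    # and reward_net (reward_nets.RewardNet) are assumed to exist already.
    import stable_baselines3 as sb3
    from stable_baselines3.common.env_util import make_vec_env

    venv = make_vec_env("CartPole-v1", n_envs=4)  # vectorized training environment
    gen_algo = sb3.PPO("MlpPolicy", venv)         # generator policy

    trainer = GAIL(
        demonstrations=expert_trajectories,
        demo_batch_size=64,
        venv=venv,
        gen_algo=gen_algo,
        reward_net=reward_net,
        n_disc_updates_per_round=2,
    )
    trainer.train(total_timesteps=100_000)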

abstract logits_gen_is_high(state: torch.Tensor, action: torch.Tensor, next_state: torch.Tensor, done: torch.Tensor, log_policy_act_prob: Optional[torch.Tensor] = None) torch.Tensor

Compute the discriminator’s logits for each state-action sample.

A high value corresponds to predicting generator, and a low value corresponds to predicting expert.

Args:
state: state at time t, of shape (batch_size,) + state_shape.
action: action taken at time t, of shape (batch_size,) + action_shape.
next_state: state at time t+1, of shape (batch_size,) + state_shape.
done: binary episode completion flag after action at time t, of shape (batch_size,).
log_policy_act_prob: log probability of generator policy taking action at time t.

Returns:

Discriminator logits of shape (batch_size,). A high output indicates a generator-like transition.
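
A hedged sketch of how a subclass might implement this method; the attribute name self._reward_net and the RewardNet call signature (state, action, next_state, done) are assumptions rather than guarantees of this codebase:

    from typing import Optional

    import torch

    class ExampleAdversarial(AdversarialTrainer):  # illustrative, partial subclass
        def logits_gen_is_high(
            self,
            state: torch.Tensor,
            action: torch.Tensor,
            next_state: torch.Tensor,
            done: torch.Tensor,
            log_policy_act_prob: Optional[torch.Tensor] = None,
        ) -> torch.Tensor:
            # One GAIL-style choice: negate the reward network output so that
            # low-reward (generator-like) transitions receive high logits.
            # self._reward_net is an assumed attribute holding the reward_net
            # passed to the constructor; reward_train and reward_test would
            # also need to be implemented to make the subclass concrete.
            return -self._reward_net(state, action, next_state, done)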

property policy: stable_baselines3.common.policies.BasePolicy

Returns a policy imitating the demonstration data.

abstract property reward_test: reward_nets.RewardNet

Reward used to train policy at “test” time after adversarial training.

abstract property reward_train: reward_nets.RewardNet

Reward used to train generator policy.
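
A short sketch of scoring transitions with the learned test-time reward after training; calling the reward net directly on batched (state, action, next_state, done) tensors mirrors the logits_gen_is_high arguments and is an assumption about the RewardNet interface:

    import torch

    reward_net = trainer.reward_test  # reward_nets.RewardNet from a trained trainer
    with torch.no_grad():
        # state, action, next_state, done: batched tensors, assumed available;
        # the result is assumed to have shape (batch_size,).
        rewards = reward_net(state, action, next_state, done)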

set_demonstrations(demonstrations: Union[Iterable[types_unique.Trajectory], Iterable[Mapping[str, Union[numpy.ndarray, torch.Tensor]]], TransitionKind]) None

Sets the demonstration data.

Changing the demonstration data on-demand can be useful for interactive algorithms like DAgger.

Args:
demonstrations: Either a Torch DataLoader, any other iterator that yields dictionaries containing “obs” and “acts” Tensors or NumPy arrays, a TransitionKind instance, or a Sequence of Trajectory objects.
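
One accepted format is an iterator of dictionaries holding “obs” and “acts” arrays; a hedged sketch with illustrative shapes (32 transitions, 4-dimensional observations, discrete actions):

    import numpy as np

    # A single demonstration batch as a mapping of NumPy arrays.
    demo_batches = [
        {
            "obs": np.zeros((32, 4), dtype=np.float32),
            "acts": np.zeros((32,), dtype=np.int64),
        }
    ]
    trainer.set_demonstrations(demo_batches)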

train(total_timesteps: int, callback: Optional[Callable[[int], None]] = None) None

Alternates between training the generator and discriminator.

Every “round” consists of a call to train_gen(self.gen_train_timesteps), a call to train_disc, and finally a call to callback(round).

Training ends once an additional “round” would cause the number of transitions sampled from the environment to exceed total_timesteps.

Args:
total_timesteps: An upper bound on the number of transitions to sample from the environment during training.
callback: A function called at the end of every round which takes in a single argument, the round number. Round numbers are in range(total_timesteps // self.gen_train_timesteps).
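
A hedged usage sketch that logs the round number at the end of each round:

    def on_round_end(round_num: int) -> None:
        # Called once per round with the round index.
        print(f"finished round {round_num}")

    trainer.train(total_timesteps=200_000, callback=on_round_end)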

train_disc(*, expert_samples: Optional[Mapping] = None, gen_samples: Optional[Mapping] = None) Optional[Mapping[str, float]]

Perform a single discriminator update, optionally using provided samples.

Args:
expert_samples: Transition samples from the expert in dictionary form. If provided, must contain keys corresponding to every field of the Transitions dataclass except “infos”. All corresponding values can be either NumPy arrays or Tensors. Extra keys are ignored. Must contain self.demo_batch_size samples. If this argument is not provided, then self.demo_batch_size expert samples from self.demo_data_loader are used by default.
gen_samples: Transition samples from the generator policy in the same dictionary form as expert_samples. If provided, must contain exactly self.demo_batch_size samples. If not provided, then len(expert_samples) samples are taken from the generator replay buffer.

Returns:

Statistics for discriminator (e.g. loss, accuracy).
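
A minimal call that lets the trainer draw expert samples from self.demo_data_loader and generator samples from its replay buffer; the exact statistic keys depend on compute_train_stats and are not reproduced here:

    stats = trainer.train_disc()  # one discriminator update with default samples
    if stats is not None:
        for name, value in stats.items():  # e.g. loss and accuracy entries
            print(f"{name}: {value:.4f}")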

train_gen(total_timesteps: Optional[int] = None, learn_kwargs: Optional[Mapping] = None) None

Trains the generator to maximize the discriminator loss.

After training ends, the generator replay buffer (used in discriminator training) is populated with self.disc_batch_size transitions.

Args:
total_timesteps: The number of transitions to sample from self.venv_train during training. By default, self.gen_train_timesteps.
learn_kwargs: kwargs for the Stable Baselines RLModel.learn() method.
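
A hedged sketch of a single generator phase with extra keyword arguments forwarded to the underlying Stable Baselines learn() call:

    trainer.train_gen(
        total_timesteps=8_192,  # defaults to self.gen_train_timesteps if omitted
        learn_kwargs={"reset_num_timesteps": False},  # forwarded to gen_algo.learn()
    )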

venv: stable_baselines3.common.vec_env.base_vec_env.VecEnv

The original vectorized environment.

venv_train: stable_baselines3.common.vec_env.base_vec_env.VecEnv

Like self.venv, but wrapped with the train reward unless in debug mode.

If debug_use_ground_truth=True was passed into the initializer, then self.venv_train is the same as self.venv.

common.compute_train_stats(disc_logits_gen_is_high: torch.Tensor, labels_gen_is_one: torch.Tensor, disc_loss: torch.Tensor) Mapping[str, float]

Train statistics for GAIL/AIRL discriminator.

Args:
disc_logits_gen_is_high: discriminator logits produced by DiscrimNet.logits_gen_is_high.
labels_gen_is_one: integer labels describing whether logit was for an expert (0) or generator (1) sample.
disc_loss: final discriminator loss.

Returns:

A mapping from statistic names to float values.
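
A small usage sketch following the conventions above (high logit and label 1 both mean generator); the binary cross-entropy shown is an ordinary choice of discriminator loss, not necessarily the loss used by a particular subclass:

    import torch
    import torch.nn.functional as F

    logits = torch.tensor([2.3, -1.1, 0.7, -0.4])  # high -> predicted generator
    labels = torch.tensor([1, 0, 1, 0])            # 1 = generator, 0 = expert
    disc_loss = F.binary_cross_entropy_with_logits(logits, labels.float())

    stats = common.compute_train_stats(logits, labels, disc_loss)
    # stats is a mapping from statistic names to float values.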