base module

Module of base classes and helper methods for imitation learning algorithms. Code adapted from https://github.com/HumanCompatibleAI/imitation.git

class base.BaseImitationAlgorithm(*, custom_logger: Optional[logger.HierarchicalLogger] = None, allow_variable_horizon: bool = False)

Bases: abc.ABC

Base class for all imitation learning algorithms.

allow_variable_horizon: bool

If True, allow variable horizon trajectories; otherwise error if detected.
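The effect of this flag can be illustrated with a small sketch. The `HorizonChecker` class and its `check` method below are hypothetical names made up for illustration and are not part of this module; the sketch only assumes that "variable horizon" means episodes of differing length.

```python
from typing import Iterable, Optional


class HorizonChecker:
    """Hypothetical helper: rejects variable-horizon data unless allowed."""

    def __init__(self, allow_variable_horizon: bool = False) -> None:
        self.allow_variable_horizon = allow_variable_horizon
        self._horizon: Optional[int] = None  # length of the first episode seen

    def check(self, episode_lengths: Iterable[int]) -> None:
        """Raises ValueError if episode lengths differ and that is not allowed."""
        if self.allow_variable_horizon:
            return  # nothing to enforce
        for length in episode_lengths:
            if self._horizon is None:
                self._horizon = length  # remember the reference horizon
            elif length != self._horizon:
                raise ValueError(
                    f"Variable horizon detected: {self._horizon} vs {length}"
                )


strict = HorizonChecker()
strict.check([5, 5])  # passes: all episodes have length 5

lenient = HorizonChecker(allow_variable_horizon=True)
lenient.check([3, 7, 2])  # passes: the check is disabled
```

With the default `allow_variable_horizon=False`, a later call such as `strict.check([4])` raises `ValueError` in this sketch.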

property logger

class base.DemonstrationAlgorithm(*, demonstrations: Optional[Union[Iterable[types_unique.Trajectory], Iterable[Mapping[str, Union[numpy.ndarray, torch.Tensor]]], TransitionKind]], custom_logger: Optional[logger.HierarchicalLogger] = None, allow_variable_horizon: bool = False)

Bases: base.BaseImitationAlgorithm, Generic[TransitionKind]

An algorithm that learns from demonstration: BC, IRL, etc.

allow_variable_horizon: bool

If True, allow variable horizon trajectories; otherwise error if detected.

abstract property policy: stable_baselines3.common.policies.BasePolicy

Returns a policy imitating the demonstration data.

abstract set_demonstrations(demonstrations: Union[Iterable[types_unique.Trajectory], Iterable[Mapping[str, Union[numpy.ndarray, torch.Tensor]]], TransitionKind]) → None

Sets the demonstration data.

Changing the demonstration data on-demand can be useful for interactive algorithms like DAgger.

Args:
demonstrations: Either a Torch DataLoader, any other iterator that yields dictionaries containing "obs" and "acts" Tensors or NumPy arrays, a TransitionKind instance, or a Sequence of Trajectory objects.
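As a rough illustration of the interface described above, the following sketch uses a stand-in abstract class rather than the real `base.DemonstrationAlgorithm`. `DemoAlgorithmSketch` and `StoringAlgorithm` are hypothetical names, plain lists stand in for Tensors or NumPy arrays, and calling `set_demonstrations` from the constructor is an assumption of this sketch.

```python
import abc
from typing import Any, Iterable, List, Mapping, Optional

Batch = Mapping[str, Any]


class DemoAlgorithmSketch(abc.ABC):
    """Stand-in for the documented interface; not the library class."""

    def __init__(self, *, demonstrations: Optional[Iterable[Batch]] = None) -> None:
        # Assumption: initial demonstrations are routed through the setter.
        if demonstrations is not None:
            self.set_demonstrations(demonstrations)

    @abc.abstractmethod
    def set_demonstrations(self, demonstrations: Iterable[Batch]) -> None:
        """Sets the demonstration data."""


class StoringAlgorithm(DemoAlgorithmSketch):
    """Toy subclass that stores batches containing "obs" and "acts" keys."""

    def __init__(self, *, demonstrations: Optional[Iterable[Batch]] = None) -> None:
        self.batches: List[Batch] = []
        super().__init__(demonstrations=demonstrations)

    def set_demonstrations(self, demonstrations: Iterable[Batch]) -> None:
        batches = list(demonstrations)
        for batch in batches:
            if "obs" not in batch or "acts" not in batch:
                raise KeyError('each batch needs "obs" and "acts" entries')
        self.batches = batches  # replace the old data rather than append


algo = StoringAlgorithm(demonstrations=[{"obs": [0, 1], "acts": [1, 0]}])
# An interactive algorithm can swap in freshly collected data later on.
algo.set_demonstrations([{"obs": [2], "acts": [1]}, {"obs": [3], "acts": [0]}])
```

Replacing the stored data on each call, instead of accumulating it, keeps the caller in control of what the algorithm trains on next.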

base.make_data_loader(transitions: Union[Iterable[types_unique.Trajectory], Iterable[Mapping[str, Union[numpy.ndarray, torch.Tensor]]], TransitionKind], batch_size: int, data_loader_kwargs: Optional[Mapping[str, Any]] = None) → Iterable[Mapping[str, Union[numpy.ndarray, torch.Tensor]]]

Converts demonstration data to a Torch data loader.

Args:
transitions: Transitions expressed directly as a types.TransitionsMinimal object, a sequence of trajectories, or an iterable of transition batches (mappings from keywords to arrays containing observations, etc.).

batch_size: The size of the batch to create. Does not change the batch size if transitions is already an iterable of transition batches.

data_loader_kwargs: Keyword arguments to pass to th_data.DataLoader (that is, torch.utils.data.DataLoader).

Returns:

An iterable of transition batches.

Raises:
ValueError: if transitions is an iterable over transition batches with batch size not equal to batch_size; or if transitions is a TransitionsMinimal object or a sequence of trajectories with fewer total timesteps than batch_size.

TypeError: if transitions is an unsupported type.
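The batch-size condition under Raises can be sketched in plain Python. This is not the implementation of `make_data_loader`: the function `batch_transitions` is a hypothetical stand-in that takes lists instead of a TransitionsMinimal object, does no shuffling, and assumes that a trailing incomplete batch is dropped.

```python
from typing import Any, Dict, List, Sequence


def batch_transitions(
    obs: Sequence[Any], acts: Sequence[Any], batch_size: int
) -> List[Dict[str, List[Any]]]:
    """Hypothetical sketch: splits transitions into fixed-size batches."""
    if batch_size <= 0:
        raise ValueError(f"batch_size={batch_size} must be positive")
    if len(obs) != len(acts):
        raise ValueError("obs and acts must have the same length")
    if len(obs) < batch_size:
        # Mirrors the documented error: too few timesteps for a single batch.
        raise ValueError(
            f"Provided {len(obs)} timesteps, fewer than batch_size={batch_size}"
        )
    batches = []
    # Assumption: an incomplete final batch is dropped, so every batch is full.
    for start in range(0, len(obs) - batch_size + 1, batch_size):
        stop = start + batch_size
        batches.append({"obs": list(obs[start:stop]), "acts": list(acts[start:stop])})
    return batches


batches = batch_transitions(obs=[0, 1, 2, 3, 4], acts=[1, 0, 1, 0, 1], batch_size=2)
# Two full batches are produced; the fifth transition is dropped.
```

Calling `batch_transitions` with `batch_size=10` on the same five transitions raises `ValueError` in this sketch, matching the second condition listed under Raises.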