base module
Module of base classes and helper methods for imitation learning algorithms. Code adapted from https://github.com/HumanCompatibleAI/imitation.git
- class base.BaseImitationAlgorithm(*, custom_logger: Optional[logger.HierarchicalLogger] = None, allow_variable_horizon: bool = False)
Bases: abc.ABC
Base class for all imitation learning algorithms.
- allow_variable_horizon: bool
If True, allow variable-horizon trajectories; otherwise, raise an error if one is detected.
- property logger
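The allow_variable_horizon flag can be pictured as a consistency check over trajectory lengths. The sketch below is a minimal, dependency-free illustration under stated assumptions: Trajectory and HorizonChecker are hypothetical stand-ins rather than classes from this module, the horizon is taken to be the number of actions in a trajectory, and the check is assumed to persist across calls.

```python
from dataclasses import dataclass
from typing import Optional, Sequence


@dataclass
class Trajectory:
    """Hypothetical stand-in for types_unique.Trajectory (one episode)."""
    obs: Sequence[object]   # assumed: len(acts) + 1 observations
    acts: Sequence[object]  # one action per timestep


class HorizonChecker:
    """Hypothetical helper illustrating the allow_variable_horizon semantics."""

    def __init__(self, allow_variable_horizon: bool = False) -> None:
        self.allow_variable_horizon = allow_variable_horizon
        self.horizon: Optional[int] = None  # fixed horizon observed so far

    def check(self, trajectories: Sequence[Trajectory]) -> None:
        if self.allow_variable_horizon:
            return  # variable horizons are explicitly allowed: nothing to check
        horizons = {len(traj.acts) for traj in trajectories}
        if self.horizon is not None:
            horizons.add(self.horizon)  # also compare against earlier calls
        if len(horizons) > 1:
            raise ValueError(f"Variable horizon trajectories detected: {sorted(horizons)}")
        if horizons:
            (self.horizon,) = horizons  # record the single horizon seen
```

With the default allow_variable_horizon=False, a first call with two length-2 trajectories records a horizon of 2, and a later call containing a length-1 trajectory raises ValueError; with allow_variable_horizon=True, mixed lengths pass silently.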
- class base.DemonstrationAlgorithm(*, demonstrations: Optional[Union[Iterable[types_unique.Trajectory], Iterable[Mapping[str, Union[numpy.ndarray, torch.Tensor]]], TransitionKind]], custom_logger: Optional[logger.HierarchicalLogger] = None, allow_variable_horizon: bool = False)
Bases: base.BaseImitationAlgorithm, Generic[TransitionKind]
An algorithm that learns from demonstrations, such as behavioral cloning (BC) or inverse reinforcement learning (IRL).
- allow_variable_horizon: bool
If True, allow variable-horizon trajectories; otherwise, raise an error if one is detected.
- abstract property policy: stable_baselines3.common.policies.BasePolicy
Returns a policy imitating the demonstration data.
- abstract set_demonstrations(demonstrations: Union[Iterable[types_unique.Trajectory], Iterable[Mapping[str, Union[numpy.ndarray, torch.Tensor]]], TransitionKind]) → None
Sets the demonstration data.
Changing the demonstration data on demand can be useful for interactive algorithms such as DAgger.
- Args:
- demonstrations: Either a Torch DataLoader, any other iterator that yields dictionaries containing “obs” and “acts” Tensors or NumPy arrays, a TransitionKind instance, or a Sequence of Trajectory objects.
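To show how a subclass might satisfy the two abstract members, here is a toy sketch. Everything in it is hypothetical: ToyDemonstrationAlgorithm is a dependency-free analogue of base.DemonstrationAlgorithm, Trajectory stands in for types_unique.Trajectory, and the policy is a plain callable rather than a stable_baselines3 BasePolicy. MajorityVoteBC is a deliberately simple tabular stand-in for behavioral cloning whose set_demonstrations accepts either trajectories or “obs”/“acts” batches.

```python
import abc
from collections import Counter
from dataclasses import dataclass
from typing import Callable, Dict, Iterable, List, Mapping, Sequence, Tuple, Union


@dataclass
class Trajectory:
    """Hypothetical stand-in for types_unique.Trajectory."""
    obs: Sequence[int]   # assumed: len(acts) + 1 observations
    acts: Sequence[int]


Batch = Mapping[str, Sequence[int]]  # e.g. {"obs": [...], "acts": [...]}


class ToyDemonstrationAlgorithm(abc.ABC):
    """Hypothetical, dependency-free analogue of base.DemonstrationAlgorithm."""

    def __init__(self, *, demonstrations=None) -> None:
        if demonstrations is not None:
            self.set_demonstrations(demonstrations)

    @abc.abstractmethod
    def set_demonstrations(self, demonstrations) -> None: ...

    @property
    @abc.abstractmethod
    def policy(self) -> Callable[[int], int]: ...


class MajorityVoteBC(ToyDemonstrationAlgorithm):
    """Toy imitation: take the action the demonstrator chose most often per observation."""

    def set_demonstrations(self, demonstrations: Union[Iterable[Trajectory], Iterable[Batch]]) -> None:
        pairs: List[Tuple[int, int]] = []
        for item in demonstrations:
            if isinstance(item, Trajectory):
                # Drop the terminal observation so obs and acts line up.
                pairs.extend(zip(item.obs[:-1], item.acts))
            else:
                pairs.extend(zip(item["obs"], item["acts"]))
        # Replace (not extend) previous data, as on-demand updates would.
        self._counts: Dict[int, Counter] = {}
        for o, a in pairs:
            self._counts.setdefault(o, Counter())[a] += 1

    @property
    def policy(self) -> Callable[[int], int]:
        counts = self._counts
        # Raises KeyError for observations never seen in the demonstrations.
        return lambda o: counts[o].most_common(1)[0][0]
```

Calling set_demonstrations again replaces the stored data, which is the kind of on-demand update described for interactive algorithms like DAgger.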
- base.make_data_loader(transitions: Union[Iterable[types_unique.Trajectory], Iterable[Mapping[str, Union[numpy.ndarray, torch.Tensor]]], TransitionKind], batch_size: int, data_loader_kwargs: Optional[Mapping[str, Any]] = None) → Iterable[Mapping[str, Union[numpy.ndarray, torch.Tensor]]]
Converts demonstration data to a Torch data loader.
- Args:
- transitions: Transitions expressed directly as a types.TransitionsMinimal object, a sequence of trajectories, or an iterable of transition batches (mappings from keywords to arrays containing observations, etc.).
- batch_size: The size of the batches to create. Does not change the batch size if transitions is already an iterable of transition batches.
- data_loader_kwargs: Arguments to pass to th_data.DataLoader (torch.utils.data.DataLoader).
- Returns:
An iterable of transition batches.
- Raises:
- ValueError: if transitions is an iterable over transition batches whose batch size is not equal to batch_size, or if transitions is a TransitionsMinimal object or a sequence of trajectories with fewer total timesteps than batch_size.
- TypeError: if transitions is an unsupported type.
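The input dispatch described for make_data_loader can be sketched without PyTorch. This is a hypothetical re-creation, not the module's implementation: make_batches, Trajectory and TransitionsMinimal are stand-ins, a plain single-pass generator replaces th_data.DataLoader, and shuffling while dropping the final incomplete batch is an assumption rather than documented behaviour. The batch-size check on pre-batched iterables listed under Raises is not reproduced.

```python
import random
from dataclasses import dataclass
from typing import Any, Dict, Iterable, Iterator, List, Mapping, Optional, Sequence, Union


@dataclass
class Trajectory:
    """Hypothetical stand-in for types_unique.Trajectory."""
    obs: Sequence[Any]   # assumed: len(acts) + 1 observations
    acts: Sequence[Any]


@dataclass
class TransitionsMinimal:
    """Hypothetical stand-in for types.TransitionsMinimal: aligned obs/acts."""
    obs: List[Any]
    acts: List[Any]

    def __len__(self) -> int:
        return len(self.acts)


def _iter_batches(data: TransitionsMinimal, batch_size: int, shuffle: bool,
                  seed: Optional[int]) -> Iterator[Dict[str, List[Any]]]:
    indices = list(range(len(data)))
    if shuffle:
        random.Random(seed).shuffle(indices)
    # Yield only full batches; a trailing partial batch is dropped.
    for start in range(0, len(indices) - batch_size + 1, batch_size):
        idx = indices[start:start + batch_size]
        yield {"obs": [data.obs[i] for i in idx], "acts": [data.acts[i] for i in idx]}


def make_batches(transitions: Union[Sequence[Trajectory], TransitionsMinimal, Iterable[Mapping[str, Any]]],
                 batch_size: int, shuffle: bool = True,
                 seed: Optional[int] = None) -> Iterable[Mapping[str, Any]]:
    if batch_size <= 0:  # guard added in this sketch
        raise ValueError(f"batch_size={batch_size} must be positive")
    # A sequence of trajectories is flattened into aligned transitions.
    if isinstance(transitions, Sequence) and len(transitions) > 0 and isinstance(transitions[0], Trajectory):
        transitions = TransitionsMinimal(
            obs=[o for t in transitions for o in t.obs[:-1]],
            acts=[a for t in transitions for a in t.acts],
        )
    if isinstance(transitions, TransitionsMinimal):
        if len(transitions) < batch_size:
            raise ValueError(f"{len(transitions)} timesteps < batch_size={batch_size}")
        return _iter_batches(transitions, batch_size, shuffle, seed)
    if isinstance(transitions, Iterable):
        return transitions  # already batched: batch size left unchanged
    raise TypeError(f"Unsupported transitions type: {type(transitions)}")
```

Because make_batches itself is not a generator function, the ValueError and TypeError cases are raised eagerly at call time instead of when the first batch is requested.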