buffer module
Buffers for storing NumPy arrays and transitions. Code adapted from https://github.com/HumanCompatibleAI/imitation.git
- class buffer.Buffer(capacity: int, sample_shapes: Mapping[str, Tuple[int, ...]], dtypes: Mapping[str, numpy.dtype])
Bases: object
A FIFO ring buffer for NumPy arrays of a fixed shape and dtype.
Supports random sampling with replacement.
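The FIFO replacement policy comes down to modular arithmetic on a write pointer. The sketch below is illustrative only: `ring_slots` and its `idx` argument (the next slot to be written) are hypothetical names, not part of this module.

```python
import numpy as np


def ring_slots(idx: int, n_samples: int, capacity: int) -> np.ndarray:
    """Hypothetical helper: slots that n_samples new samples would occupy.

    Writing starts at `idx` and wraps around to slot 0 at the end of the
    buffer, so the oldest samples are the first to be overwritten (FIFO).
    """
    if not 0 < n_samples <= capacity:
        raise ValueError("n_samples must be in [1, capacity]")
    return (idx + np.arange(n_samples)) % capacity


# A capacity-5 buffer whose write pointer is at slot 3:
print(ring_slots(3, 4, 5))  # → [3 4 0 1]
```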
- capacity: int
The number of data samples that can be stored in this buffer.
- classmethod from_data(data: Mapping[str, numpy.ndarray], capacity: Optional[int] = None, truncate_ok: bool = False) → buffer.Buffer
Construct and return a Buffer containing the provided data.
Shapes and dtypes are automatically inferred.
- Args:
- data: A dictionary mapping keys to data arrays. The arrays may differ
in their shape, but should agree in the first axis.
- capacity: The Buffer capacity. If not provided, then this is automatically
set to the size of the data, so that the returned Buffer is at full capacity.
- truncate_ok: Whether to allow truncation when capacity < the number of samples in
data. If False, raise an error; if True, store only the last capacity samples from data.
- Examples:
In the following examples, suppose the arrays in data have length 1000.
Buffer with the same capacity as the arrays in data:
Buffer.from_data(data)
Buffer with a larger capacity than the arrays in data:
Buffer.from_data(data, 10000)
Buffer with a smaller capacity than the arrays in data. Without truncate_ok=True, from_data will error:
Buffer.from_data(data, 5, truncate_ok=True)
- Returns:
Buffer of specified capacity containing provided data.
- Raises:
ValueError: data is empty.
ValueError: data has items mapping to arrays differing in the length of their first axis.
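The following sketch mirrors the behaviour documented above for from_data: reject empty input, require a common first-axis length, default capacity to that length, and truncate to the newest samples only when truncate_ok=True. The function `prepare_from_data` is hypothetical; instead of constructing a Buffer, it returns the effective capacity and the samples that would be stored.

```python
from typing import Dict, Mapping, Optional, Tuple

import numpy as np


def prepare_from_data(
    data: Mapping[str, np.ndarray],
    capacity: Optional[int] = None,
    truncate_ok: bool = False,
) -> Tuple[int, Dict[str, np.ndarray]]:
    """Hypothetical sketch of the checks documented for Buffer.from_data."""
    if not data:
        raise ValueError("data is empty")
    lengths = {len(arr) for arr in data.values()}
    if len(lengths) != 1:
        raise ValueError("arrays differ in the length of their first axis")
    (n_samples,) = lengths
    if capacity is None:
        capacity = n_samples  # the returned buffer starts at full capacity
    if capacity < 1:
        raise ValueError("capacity must be positive")
    if n_samples > capacity:
        if not truncate_ok:
            raise ValueError("capacity is smaller than the number of samples")
        # Keep only the most recent `capacity` samples.
        return capacity, {k: arr[-capacity:] for k, arr in data.items()}
    return capacity, dict(data)


data = {"obs": np.zeros((1000, 4)), "rew": np.zeros(1000)}
capacity, kept = prepare_from_data(data, 5, truncate_ok=True)
print(capacity, kept["obs"].shape)  # → 5 (5, 4)
```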
- sample(n_samples: int) → Mapping[str, numpy.ndarray]
Uniformly sample n_samples samples from the buffer with replacement.
- Args:
n_samples: The number of samples to randomly sample.
- Returns:
- samples: A dictionary mapping each key k to an array with shape
(n_samples,) + self.sample_shapes[k].
- Raises:
ValueError: The buffer is empty.
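Because the data is held as one array per key, sampling must reuse a single index vector across all keys so that the i-th sample of every array refers to the same stored item. A minimal sketch under that assumption: `sample_aligned` and its `size` argument (the number of valid slots, which stays below capacity until the buffer fills) are hypothetical, and it uses NumPy's Generator API rather than whatever random source the module itself uses.

```python
from typing import Dict, Mapping

import numpy as np


def sample_aligned(
    arrays: Mapping[str, np.ndarray],
    size: int,
    n_samples: int,
    rng: np.random.Generator,
) -> Dict[str, np.ndarray]:
    """Hypothetical sketch: uniform sampling with replacement over `size` slots."""
    if size == 0:
        raise ValueError("The buffer is empty.")
    # One shared index vector keeps the keys aligned; repeats are allowed.
    ind = rng.integers(0, size, size=n_samples)
    return {k: arr[ind] for k, arr in arrays.items()}


rng = np.random.default_rng(0)
stored = {"obs": np.arange(10).reshape(5, 2), "acts": np.arange(5)}
batch = sample_aligned(stored, size=5, n_samples=8, rng=rng)
print(batch["obs"].shape, batch["acts"].shape)  # → (8, 2) (8,)
```

In this toy data, row i of obs is [2i, 2i + 1] and acts[i] is i, so an aligned batch always satisfies obs[:, 0] == 2 * acts.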
- sample_shapes: Mapping[str, Tuple[int, ...]]
The shapes of each data sample stored in this buffer.
- size() → Optional[int]
Returns the number of samples stored in the buffer.
- store(data: Mapping[str, numpy.ndarray], truncate_ok: bool = False) → None
Stores new data samples, replacing old samples with FIFO priority.
- Args:
- data: A dictionary mapping keys k to arrays with shape
(n_samples,) + self.sample_shapes[k], where n_samples is less than or equal to self.capacity.
- truncate_ok: If False, then error if n_samples is
greater than self.capacity. Otherwise, store only the final self.capacity samples.
- Raises:
ValueError: data is empty.
ValueError: n_samples is greater than self.capacity and truncate_ok is False.
ValueError: data is the wrong shape.
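A store call that runs past the end of the ring can be done as at most two slice assignments: fill up to the end, then wrap to the front. This is a hypothetical minimal sketch, not this module's implementation: `RingStoreSketch` and its `arrays`, `idx` and `n_data` attributes are illustrative names, and every array is allocated as float64 instead of honouring per-key dtypes.

```python
from typing import Dict, Mapping, Tuple

import numpy as np


class RingStoreSketch:
    """Hypothetical minimal store() for a dict-of-arrays FIFO buffer."""

    def __init__(self, capacity: int, sample_shapes: Mapping[str, Tuple[int, ...]]):
        self.capacity = capacity
        self.sample_shapes = dict(sample_shapes)
        # float64 for brevity; the real Buffer takes per-key dtypes.
        self.arrays: Dict[str, np.ndarray] = {
            k: np.zeros((capacity,) + shape) for k, shape in sample_shapes.items()
        }
        self.idx = 0  # next slot to write
        self.n_data = 0  # number of valid samples

    def store(self, data: Mapping[str, np.ndarray], truncate_ok: bool = False) -> None:
        if not data:
            raise ValueError("data is empty")
        if set(data) != set(self.sample_shapes):
            raise ValueError("data keys do not match sample_shapes")
        n_samples = len(next(iter(data.values())))
        for k, arr in data.items():
            if arr.shape != (n_samples,) + self.sample_shapes[k]:
                raise ValueError(f"data[{k!r}] has the wrong shape")
        if n_samples > self.capacity:
            if not truncate_ok:
                raise ValueError("n_samples is greater than capacity")
            data = {k: v[-self.capacity:] for k, v in data.items()}
            n_samples = self.capacity
        # First slice runs up to the end of the ring; the rest wraps to slot 0.
        first = min(n_samples, self.capacity - self.idx)
        for k, arr in data.items():
            self.arrays[k][self.idx:self.idx + first] = arr[:first]
            self.arrays[k][:n_samples - first] = arr[first:]
        self.idx = (self.idx + n_samples) % self.capacity
        self.n_data = min(self.n_data + n_samples, self.capacity)


buf = RingStoreSketch(capacity=4, sample_shapes={"x": ()})
buf.store({"x": np.array([1.0, 2.0, 3.0])})
buf.store({"x": np.array([4.0, 5.0])})
print(buf.arrays["x"])  # → [5. 2. 3. 4.]
```

The second call overwrites only the oldest sample (1.0), which is the FIFO behaviour described above.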
- class buffer.ReplayBuffer(capacity: int, venv: Optional[stable_baselines3.common.vec_env.base_vec_env.VecEnv] = None, *, obs_shape: Optional[Tuple[int, ...]] = None, act_shape: Optional[Tuple[int, ...]] = None, obs_dtype: Optional[numpy.dtype] = None, act_dtype: Optional[numpy.dtype] = None)
Bases: object
Buffer for Transitions.
- capacity: int
The number of data samples that can be stored in this buffer.
- classmethod from_data(transitions: types_unique.Transitions, capacity: Optional[int] = None, truncate_ok: bool = False) → buffer.ReplayBuffer
Construct and return a ReplayBuffer containing the provided data.
Shapes and dtypes are automatically inferred, and the returned ReplayBuffer is ready for sampling.
- Args:
- transitions: Transitions to store.
- capacity: The ReplayBuffer capacity. If not provided, then this is
automatically set to the number of transitions, so that the returned ReplayBuffer is at full capacity.
- truncate_ok: Whether to allow truncation when capacity < the number of
transitions. If False, raise an error; if True, store only the last capacity transitions.
- Examples:
ReplayBuffer with the same capacity as the number of transitions:
ReplayBuffer.from_data(transitions)
ReplayBuffer with a larger capacity than the number of transitions:
ReplayBuffer.from_data(transitions, 10000)
ReplayBuffer with a smaller capacity than the number of transitions. Without truncate_ok=True, from_data will error:
ReplayBuffer.from_data(transitions, 5, truncate_ok=True)
- Returns:
A new ReplayBuffer.
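The automatic inference mentioned above can be pictured as reading each field's per-sample shape (everything after the batch axis) and its dtype. The sketch assumes a Transitions type with only obs, acts and next_obs fields; the real types_unique.Transitions may carry more fields, and `infer_layout` is a hypothetical helper.

```python
from typing import Dict, NamedTuple, Tuple

import numpy as np


class Transitions(NamedTuple):
    """Hypothetical stand-in for types_unique.Transitions (obs-act-obs triples)."""

    obs: np.ndarray
    acts: np.ndarray
    next_obs: np.ndarray


def infer_layout(
    transitions: Transitions,
) -> Tuple[Dict[str, Tuple[int, ...]], Dict[str, np.dtype]]:
    """Hypothetical helper: per-sample shapes and dtypes for each field."""
    fields = transitions._asdict()
    shapes = {k: arr.shape[1:] for k, arr in fields.items()}  # drop the batch axis
    dtypes = {k: arr.dtype for k, arr in fields.items()}
    return shapes, dtypes


t = Transitions(
    obs=np.zeros((100, 3), dtype=np.float32),
    acts=np.zeros((100,), dtype=np.int64),
    next_obs=np.zeros((100, 3), dtype=np.float32),
)
shapes, dtypes = infer_layout(t)
print(shapes["obs"], shapes["acts"], dtypes["acts"])  # → (3,) () int64
```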
- sample(n_samples: int) → types_unique.Transitions
Sample obs-act-obs triples.
- Args:
n_samples: The number of samples.
- Returns:
A Transitions named tuple containing n_samples transitions.
- size() → Optional[int]
Returns the number of samples stored in the buffer.
- store(transitions: types_unique.Transitions, truncate_ok: bool = True) → None
Store obs-act-obs triples.
- Args:
- transitions: Transitions to store.
- truncate_ok: If False, then error if the number of transitions is
greater than self.capacity. Otherwise, store only the final self.capacity transitions.
- Raises:
ValueError: The arguments didn’t have the same length.
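The length check behind this ValueError can be sketched as follows, again with a hypothetical three-field stand-in for types_unique.Transitions. How ReplayBuffer then handles truncate_ok (for example by forwarding it to an underlying Buffer.store) is an assumption, so the sketch stops at validating lengths and converting the fields to a dict of arrays; `transitions_to_dict` is a hypothetical name.

```python
from typing import Dict, NamedTuple

import numpy as np


class Transitions(NamedTuple):
    """Hypothetical stand-in for types_unique.Transitions."""

    obs: np.ndarray
    acts: np.ndarray
    next_obs: np.ndarray


def transitions_to_dict(transitions: Transitions) -> Dict[str, np.ndarray]:
    """Hypothetical helper: validate lengths, then map field name -> array."""
    fields = transitions._asdict()
    if len({len(arr) for arr in fields.values()}) != 1:
        raise ValueError("The arguments didn't have the same length.")
    return dict(fields)


ok = Transitions(np.zeros((3, 2)), np.zeros(3), np.zeros((3, 2)))
print(sorted(transitions_to_dict(ok)))  # → ['acts', 'next_obs', 'obs']
```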