sup3r.utilities.pytest.helpers.BatchHandlerTesterDC

class BatchHandlerTesterDC(train_containers, *, val_containers=None, sample_shape=None, batch_size=16, n_batches=64, s_enhance=1, t_enhance=1, means=None, stds=None, queue_cap=None, transform_kwargs=None, mode='lazy', feature_sets=None, spatial_weights=None, temporal_weights=None, n_space_bins=1, n_time_bins=1)#

Bases: BatchHandlerDC

Data-centric batch handler that keeps a record of sample indices, for use in tests.

Parameters:
  • train_containers (List[Container]) – List of objects with a .data attribute, which will be used to initialize Sampler objects and then used to initialize a batch queue of training data. The data can be a Sup3rX or Sup3rDataset object.

  • val_containers (List[Container]) – List of objects with a .data attribute, which will be used to initialize Sampler objects and then used to initialize a batch queue of validation data. The data can be a Sup3rX or a Sup3rDataset object.

  • batch_size (int) – Number of observations / samples in a batch

  • n_batches (int) – Number of batches in an epoch. This sets the iteration limit for this object.

  • s_enhance (int) – Integer factor by which the spatial axes are to be enhanced.

  • t_enhance (int) – Integer factor by which the temporal axis is to be enhanced.

  • means (str | dict | None) – Usually a file path for loading / saving the means, or None to calculate the stats without saving them. Can also be a dict.

  • stds (str | dict | None) – Usually a file path for loading / saving the standard deviations, or None to calculate the stats without saving them. Can also be a dict.

  • queue_cap (int) – Maximum number of batches the batch queue can store.

  • transform_kwargs (Union[Dict, None]) – Dictionary of kwargs to be passed to self.transform. This method performs smoothing / coarsening.

  • mode (str) – Loading mode. Default is ‘lazy’, which only loads data into memory as batches are queued. ‘eager’ will load all data into memory right away.

  • feature_sets (Optional[dict]) – Optional dictionary describing how the full set of features is split between lr_only_features and hr_exo_features. See Sampler

  • sample_shape (tuple) – Size of arrays to sample from the contained data.

  • spatial_weights (Union[np.ndarray, da.core.Array] | List | None) – Set of weights used to initialize the spatial sampling, e.g. to start off sampling evenly across 2 spatial bins this should be [0.5, 0.5]. During training these weights will be updated based on performance across the bins associated with these weights.

  • temporal_weights (Union[np.ndarray, da.core.Array] | List | None) – Set of weights used to initialize the temporal sampling, e.g. to start off sampling only the first season of the year this should be [1, 0, 0, 0]. During training these weights will be updated based on performance across the bins associated with these weights.

  • n_space_bins (int) – Number of spatial bins to use for weighted sampling, e.g. if this is 4 the spatial domain will be divided into 4 equal regions, and losses will be calculated across these regions during training in order to adaptively sample from lower-performing regions.

  • n_time_bins (int) – Number of time bins to use for weighted sampling, e.g. if this is 4 the temporal domain will be divided into 4 equal periods, and losses will be calculated across these periods during training in order to adaptively sample from lower-performing time periods.
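
As a rough illustration of how the bin weights described above could drive sampling, the following sketch draws bin indices in proportion to a weight vector using only the standard library. The helper name sample_bin_index is hypothetical and is not part of sup3r; the actual samplers may select bins differently.

```python
import random


def sample_bin_index(weights, rng=None):
    """Draw one bin index with probability proportional to its weight.

    Hypothetical helper for illustration only; not part of sup3r.
    """
    rng = rng or random.Random()
    if not weights or any(w < 0 for w in weights) or sum(weights) <= 0:
        raise ValueError("weights must be non-negative with a positive sum")
    # random.choices normalizes the relative weights internally
    return rng.choices(range(len(weights)), weights=weights, k=1)[0]


rng = random.Random(0)

# temporal_weights=[1, 0, 0, 0] restricts sampling to the first of 4 time bins
first_bin_only = [sample_bin_index([1, 0, 0, 0], rng) for _ in range(100)]

# spatial_weights=[0.5, 0.5] samples two spatial bins roughly evenly
two_bins = [sample_bin_index([0.5, 0.5], rng) for _ in range(1000)]
```

With n_space_bins=2 and n_time_bins=4, one such draw per axis would pick the spatial region and the time period from which the next sample is taken.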

Methods

check_enhancement_factors()

Make sure the enhancement factors evenly divide the sample_shape.

check_features()

Make sure all samplers have the same sets of features.

check_shared_attr(attr)

Check if all containers have the same value for attr.

enqueue_batch()

Build batch and send to queue.

enqueue_batches()

Callback function for queue thread.

get_batch()

Get batch from queue or directly from a Sampler through sample_batch.

get_container_index()

Get random container index based on weights

get_queue()

Return FIFO queue for storing batches.

get_random_container()

Get random container based on container weights

init_samplers(train_containers, ...)

Initialize samplers from given data containers.

log_queue_info()

Log info about queue size.

post_init_log([args_dict])

Log additional arguments after initialization.

post_proc(samples)

Perform post-processing on dequeued samples before sending them out for training.

preflight()

Run checks before kicking off the queue.

sample_batch()

Override get_samples to track sample indices.

start()

Start the val data batch queue in addition to the train batch queue.

stop()

Stop the val data batch queue in addition to the train batch queue.

transform(samples[, smoothing, ...])

Coarsen high res data to get corresponding low res batch.

update_record()

Reset records for a new epoch.

update_weights(spatial_weights, temporal_weights)

Set weights used to sample spatial and temporal bins.

wrap(data)

Return a Sup3rDataset object or tuple of such.

Attributes

container_weights

Get weights used to sample from different containers based on relative sizes

data

Return underlying data.

features

Get all features contained in data.

hr_shape

Shape of high resolution sample in a low-res / high-res pair.

lr_shape

Shape of low resolution sample in a low-res / high-res pair.

queue_shape

Shape of objects stored in the queue.

queue_thread

Get new queue thread.

running

Boolean to check whether to keep enqueueing batches.

shape

Get shape of underlying data.

spatial_weights

Get weights used to sample spatial bins.

temporal_weights

Get weights used to sample temporal bins.

SAMPLER#

alias of SamplerTester

sample_batch()#

Override get_samples to track sample indices.

update_record()#

Reset records for a new epoch.

class Batch(low_res, high_res)#

Bases: tuple

Create new instance of Batch(low_res, high_res)

__add__(value, /)#

Return self+value.

__mul__(value, /)#

Return self*value.

count(value, /)#

Return number of occurrences of value.

high_res#

Alias for field number 1

index(value, start=0, stop=9223372036854775807, /)#

Return first index of value.

Raises ValueError if the value is not present.

low_res#

Alias for field number 0
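
The Batch entries above are the standard members of a Python namedtuple. A minimal sketch of an equivalent type, built with collections.namedtuple and filled with placeholder strings instead of real arrays, shows how the two fields can be accessed:

```python
from collections import namedtuple

# Same field layout as the documented Batch(low_res, high_res)
Batch = namedtuple("Batch", ["low_res", "high_res"])

# Placeholder values standing in for real low-res / high-res arrays
batch = Batch(low_res="lr_array", high_res="hr_array")

# Fields are reachable by name or by position (0 -> low_res, 1 -> high_res)
assert batch.low_res == batch[0]
assert batch.high_res == batch[1]

# Tuple unpacking works because Batch is a tuple subclass
low_res, high_res = batch
```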

TRAIN_QUEUE#

alias of BatchQueueDC

VAL_QUEUE#

alias of ValBatchQueueDC

check_enhancement_factors()#

Make sure the enhancement factors evenly divide the sample_shape.
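
A check of this kind can be sketched as a standalone function. The sketch assumes sample_shape is ordered as (spatial_1, spatial_2, temporal); that layout and the error messages are illustrative assumptions, and the method in sup3r may be implemented differently.

```python
def check_enhancement_factors(sample_shape, s_enhance, t_enhance):
    """Raise if the enhancement factors do not evenly divide sample_shape.

    Assumes sample_shape is (spatial_1, spatial_2, temporal); this layout
    and the standalone function form are illustrative assumptions.
    """
    spatial_1, spatial_2, temporal = sample_shape
    if spatial_1 % s_enhance or spatial_2 % s_enhance:
        raise ValueError(
            f"s_enhance={s_enhance} must evenly divide the spatial "
            f"dimensions of sample_shape={sample_shape}"
        )
    if temporal % t_enhance:
        raise ValueError(
            f"t_enhance={t_enhance} must evenly divide the temporal "
            f"dimension of sample_shape={sample_shape}"
        )


# Passes silently: 20 is divisible by 2 and 24 is divisible by 4
check_enhancement_factors((20, 20, 24), s_enhance=2, t_enhance=4)
```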

check_features()#

Make sure all samplers have the same sets of features.

check_shared_attr(attr)#

Check if all containers have the same value for attr. If they do, the collection effectively inherits that attribute.

property container_weights#

Get weights used to sample from different containers based on relative sizes
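
One simple reading of "based on relative sizes" is that each container's weight is its share of the total size. The sketch below assumes exactly that; the function name relative_size_weights and the definition of "size" are hypothetical and not taken from sup3r.

```python
def relative_size_weights(sizes):
    """Return one weight per container, proportional to its size.

    Hypothetical helper; how sup3r measures container size is not
    specified here.
    """
    total = sum(sizes)
    if total <= 0:
        raise ValueError("total container size must be positive")
    return [size / total for size in sizes]


# Three containers holding 100, 300 and 600 elements respectively
weights = relative_size_weights([100, 300, 600])
```

Larger containers then receive proportionally more draws when a random container is selected.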

property data#

Return underlying data.

Returns:

Sup3rDataset

See also

wrap()

enqueue_batch()#

Build batch and send to queue.

enqueue_batches() → None#

Callback function for queue thread. While training, the queue is checked for empty spots and filled. In the training thread, batches are removed from the queue.
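
The producer / consumer pattern described here can be sketched with the standard library alone. This is not the sup3r implementation: the function signature, the running event, and the integer "batches" are stand-ins chosen for illustration.

```python
import queue
import threading


def enqueue_batches(batch_queue, make_batch, running):
    """Keep filling the queue while the `running` event is set."""
    while running.is_set():
        batch = make_batch()
        # Retry the same batch until there is a free spot or we are stopped
        while running.is_set():
            try:
                batch_queue.put(batch, timeout=0.1)
                break
            except queue.Full:
                continue


batch_queue = queue.Queue(maxsize=4)  # maxsize plays the role of queue_cap
running = threading.Event()
running.set()

counter = iter(range(1_000_000))  # integers stand in for real batches
producer = threading.Thread(
    target=enqueue_batches,
    args=(batch_queue, lambda: next(counter), running),
    daemon=True,
)
producer.start()

# The "training" side removes batches from the queue in FIFO order
consumed = [batch_queue.get(timeout=5) for _ in range(10)]

# Stop enqueueing and wait for the producer thread to exit
running.clear()
producer.join(timeout=5)
```

Bounding the queue keeps memory use predictable: the producer blocks when the queue is full and resumes as soon as the consumer frees a spot.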

property features#

Get all features contained in data.

get_batch() → Batch#

Get batch from queue or directly from a Sampler through sample_batch.

get_container_index()#

Get random container index based on weights

get_queue()#

Return FIFO queue for storing batches.

get_random_container()#

Get random container based on container weights

property hr_shape#

Shape of high resolution sample in a low-res / high-res pair. (e.g. (spatial_1, spatial_2, temporal, features))

init_samplers(train_containers, val_containers, sample_shape, feature_sets, batch_size, sampler_kwargs)#

Initialize samplers from given data containers.

log_queue_info()#

Log info about queue size.

property lr_shape#

Shape of low resolution sample in a low-res / high-res pair. (e.g. (spatial_1, spatial_2, temporal, features))
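
For orientation, the relationship between the two shapes can be sketched as follows, assuming the low-res shape is the high-res shape divided by the enhancement factors along the spatial and temporal axes. The helper lr_shape_from_hr is hypothetical, and it assumes the same number of features in both samples, which may not hold when lr_only_features or hr_exo_features are set.

```python
def lr_shape_from_hr(hr_shape, s_enhance, t_enhance):
    """Derive a low-res sample shape from a high-res one.

    Hypothetical helper. Assumes hr_shape is
    (spatial_1, spatial_2, temporal, features) and that both samples
    carry the same number of features.
    """
    spatial_1, spatial_2, temporal, n_features = hr_shape
    return (
        spatial_1 // s_enhance,
        spatial_2 // s_enhance,
        temporal // t_enhance,
        n_features,
    )


lr_shape = lr_shape_from_hr((20, 20, 24, 3), s_enhance=2, t_enhance=4)
# lr_shape is (10, 10, 6, 3)
```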

post_init_log(args_dict=None)#

Log additional arguments after initialization.

post_proc(samples) → Batch#

Perform post-processing on dequeued samples before sending them out for training. Post-processing can include coarsening of high-res data (if the Collection consists of Sampler objects and not DualSampler objects), smoothing, etc.

Returns:

Batch (namedtuple) – namedtuple with low_res and high_res attributes

preflight()#

Run checks before kicking off the queue.

property queue_shape#

Shape of objects stored in the queue.

property queue_thread#

Get new queue thread.

property running#

Boolean to check whether to keep enqueueing batches.

property shape#

Get shape of underlying data.

property spatial_weights#

Get weights used to sample spatial bins.

start()#

Start the val data batch queue in addition to the train batch queue.

stop()#

Stop the val data batch queue in addition to the train batch queue.

property temporal_weights#

Get weights used to sample temporal bins.

transform(samples, smoothing=None, smoothing_ignore=None, temporal_coarsening_method='subsample')#

Coarsen high res data to get corresponding low res batch.

Parameters:
  • samples (Union[np.ndarray, da.core.Array]) – High resolution batch of samples. 4D | 5D array with shape (batch_size, spatial_1, spatial_2, features) or (batch_size, spatial_1, spatial_2, temporal, features)

  • smoothing (float | None) – Standard deviation to use for Gaussian filtering of the coarse data. This can be tuned by matching the kinetic energy of a low resolution simulation with the kinetic energy of a coarsened and smoothed high resolution simulation. If None, no smoothing is performed.

  • smoothing_ignore (list | None) – List of features to ignore for the smoothing filter. If None, all features are smoothed, provided the smoothing kwarg is not None.

  • temporal_coarsening_method (str) – Method to use for temporal coarsening. Can be subsample, average, min, max, or total.

Returns:

  • low_res (Union[np.ndarray, da.core.Array]) – 4D | 5D array with shape (batch_size, spatial_1, spatial_2, features) or (batch_size, spatial_1, spatial_2, temporal, features)

  • high_res (Union[np.ndarray, da.core.Array]) – 4D | 5D array with shape (batch_size, spatial_1, spatial_2, features) or (batch_size, spatial_1, spatial_2, temporal, features)
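
To make the temporal_coarsening_method options concrete, the sketch below applies each one to a plain 1D list. It is an illustrative assumption of what the method names mean; the real transform works on 4D / 5D arrays, also coarsens spatially, and may differ in detail. The function name coarsen_time is hypothetical.

```python
def coarsen_time(series, t_enhance, method="subsample"):
    """Coarsen a 1D time series by an integer factor.

    Illustrative only: each method name is interpreted in the most
    literal way, which may not match sup3r exactly.
    """
    if len(series) % t_enhance:
        raise ValueError("t_enhance must evenly divide the series length")
    if method == "subsample":
        # Keep every t_enhance-th value
        return list(series[::t_enhance])
    reducers = {
        "average": lambda window: sum(window) / len(window),
        "min": min,
        "max": max,
        "total": sum,
    }
    if method not in reducers:
        raise ValueError(f"unknown temporal coarsening method: {method!r}")
    # Split into consecutive windows of length t_enhance and reduce each one
    windows = [
        series[i:i + t_enhance] for i in range(0, len(series), t_enhance)
    ]
    return [reducers[method](window) for window in windows]


high_res_series = [1, 2, 3, 4, 5, 6, 7, 8]
subsampled = coarsen_time(high_res_series, 4, "subsample")  # [1, 5]
averaged = coarsen_time(high_res_series, 4, "average")      # [2.5, 6.5]
totals = coarsen_time(high_res_series, 4, "total")          # [10, 26]
```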

update_weights(spatial_weights, temporal_weights)#

Set weights used to sample spatial and temporal bins. This is called by Sup3rGanDC after an epoch to update weights based on model performance across validation samples.
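
One way such an update could work is to make each bin's weight proportional to its validation loss, so that poorly performing bins are sampled more often. This is only a sketch under that assumption; the function weights_from_losses is hypothetical, and the rule Sup3rGanDC actually applies is not described here.

```python
def weights_from_losses(bin_losses):
    """Turn per-bin losses into sampling weights that sum to 1.

    Hypothetical rule: higher loss means higher sampling weight.
    Falls back to uniform weights when all losses are zero.
    """
    if not bin_losses:
        raise ValueError("bin_losses must not be empty")
    total = sum(bin_losses)
    if total <= 0:
        return [1 / len(bin_losses)] * len(bin_losses)
    return [loss / total for loss in bin_losses]


# Two spatial bins; the second performed worse on validation data
new_spatial_weights = weights_from_losses([1.0, 3.0])  # [0.25, 0.75]
```

The resulting lists have the same form as the spatial_weights and temporal_weights arguments, one entry per bin.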

wrap(data)#

Return a Sup3rDataset object or tuple of such. This is a tuple when the .data attribute belongs to a Collection object like BatchHandler. Otherwise this is a Sup3rDataset object, which is either a wrapped 2-tuple or 1-tuple (e.g. len(data) == 2 or len(data) == 1). This is a 2-tuple when .data belongs to a dual container object like DualSampler and a 1-tuple otherwise.