sup3r.preprocessing.batch_queues.dc.BatchQueueDC

class BatchQueueDC(samplers, *, n_space_bins=1, n_time_bins=1, batch_size=16, n_batches=64, s_enhance=1, t_enhance=1, queue_cap=None, transform_kwargs=None, thread_name='training', mode='lazy', **kwargs)

Bases: SingleBatchQueue

Sample from data based on spatial and temporal weights. These weights can be derived from losses on validation data and updated during training, or set a priori to construct a validation queue.

Parameters:
  • samplers (List[Sampler]) – List of Sampler instances

  • n_space_bins (int) – Number of spatial bins to use for weighted sampling. For example, if this is 4, the spatial domain will be divided into 4 equal regions, and losses will be calculated across these regions during training in order to adaptively sample from lower-performing regions.

  • n_time_bins (int) – Number of time bins to use for weighted sampling. For example, if this is 4, the temporal domain will be divided into 4 equal periods, and losses will be calculated across these periods during training in order to adaptively sample from lower-performing time periods.

  • kwargs (dict) – Keyword arguments for the parent class.

  • batch_size (int) – Number of observations / samples in a batch

  • n_batches (int) – Number of batches in an epoch. This sets the iteration limit for this object.

  • s_enhance (int) – Integer factor by which the spatial axes are to be enhanced.

  • t_enhance (int) – Integer factor by which the temporal axis is to be enhanced.

  • queue_cap (int) – Maximum number of batches the batch queue can store.

  • transform_kwargs (Union[Dict, None]) – Dictionary of kwargs to be passed to self.transform, which performs smoothing / coarsening.

  • thread_name (str) – Name of the queue thread. Default is ‘training’. BatchHandler, which has both a training and a validation queue, uses this to set the name of its validation queue to ‘validation’.

  • mode (str) – Loading mode. Default is ‘lazy’, which only loads data into memory as batches are queued. ‘eager’ will load all data into memory right away.
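The weighted-bin behavior described for n_space_bins and n_time_bins can be pictured with the minimal NumPy sketch below. It is not sup3r's implementation: the helpers bin_bounds and choose_bin are hypothetical, each axis is treated as a single 1D extent for simplicity, and weights are assumed to be nonnegative.

```python
import numpy as np


def bin_bounds(length, n_bins):
    """Hypothetical helper: split an axis of ``length`` cells into ``n_bins``
    contiguous, (nearly) equal index ranges returned as (start, stop) pairs."""
    edges = np.linspace(0, length, n_bins + 1).astype(int)
    return [(int(a), int(b)) for a, b in zip(edges[:-1], edges[1:])]


def choose_bin(weights, rng):
    """Hypothetical helper: pick a bin index with probability proportional
    to its (nonnegative) weight."""
    weights = np.asarray(weights, dtype=float)
    return int(rng.choice(len(weights), p=weights / weights.sum()))


rng = np.random.default_rng(0)
spatial_bins = bin_bounds(length=100, n_bins=4)  # 4 equal spatial regions
temporal_bins = bin_bounds(length=24, n_bins=4)  # 4 equal time periods

# The last spatial bin is weighted heaviest, e.g. because it performed worst.
s_idx = choose_bin([0.1, 0.1, 0.1, 0.7], rng)
t_idx = choose_bin([0.25, 0.25, 0.25, 0.25], rng)
print(spatial_bins[s_idx], temporal_bins[t_idx])
```

Setting all weights equal reduces this to uniform sampling, which corresponds to the default of a single bin per axis.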

Methods

check_enhancement_factors()

Make sure the enhancement factors evenly divide the sample_shape.

check_features()

Make sure all samplers have the same sets of features.

check_shared_attr(attr)

Check if all containers have the same value for attr.

enqueue_batch()

Build batch and send to queue.

enqueue_batches()

Callback function for queue thread.

get_batch()

Get batch from queue or directly from a Sampler through sample_batch.

get_container_index()

Get random container index based on weights.

get_queue()

Return FIFO queue for storing batches.

get_random_container()

Get random container based on container weights.

log_queue_info()

Log info about queue size.

post_init_log([args_dict])

Log additional arguments after initialization.

post_proc(samples)

Perform post-processing on dequeued samples before sending them out for training.

preflight()

Run checks before kicking off the queue.

sample_batch()

Update weights and get batch of samples from sampled container.

start()

Start thread to keep sample queue full for batches.

stop()

Stop loading batches.

transform(samples[, smoothing, ...])

Coarsen high res data to get corresponding low res batch.

update_weights(spatial_weights, temporal_weights)

Set weights used to sample spatial and temporal bins.

wrap(data)

Return a Sup3rDataset object or tuple of such.

Attributes

container_weights

Get weights used to sample from different containers based on relative sizes.

data

Return underlying data.

features

Get all features contained in data.

hr_shape

Shape of high resolution sample in a low-res / high-res pair.

lr_shape

Shape of low resolution sample in a low-res / high-res pair.

queue_shape

Shape of objects stored in the queue.

queue_thread

Get new queue thread.

running

Boolean to check whether to keep enqueueing batches.

shape

Get shape of underlying data.

spatial_weights

Get weights used to sample spatial bins.

temporal_weights

Get weights used to sample temporal bins.

sample_batch()

Update weights and get batch of samples from sampled container.

property spatial_weights

Get weights used to sample spatial bins.

property temporal_weights

Get weights used to sample temporal bins.

update_weights(spatial_weights, temporal_weights)

Set weights used to sample spatial and temporal bins. This is called by Sup3rGanDC after an epoch to update weights based on model performance across validation samples.
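As a rough illustration of how per-bin validation losses might become sampling weights, the hypothetical helper below normalizes the losses so they sum to 1, which makes higher-loss bins more likely to be drawn. The normalization scheme is an assumption; the weights Sup3rGanDC actually passes to update_weights may be derived differently.

```python
import numpy as np


def losses_to_weights(bin_losses):
    """Hypothetical helper: turn per-bin validation losses into sampling
    weights that sum to 1, so higher-loss bins are sampled more often."""
    losses = np.asarray(bin_losses, dtype=float)
    total = losses.sum()
    if total <= 0:
        # No signal yet: fall back to uniform sampling.
        return np.full(losses.shape, 1.0 / losses.size)
    return losses / total


spatial_weights = losses_to_weights([1.0, 1.0, 2.0, 1.0])  # ~[0.2, 0.2, 0.4, 0.2]
temporal_weights = losses_to_weights([1.0, 3.0])           # [0.25, 0.75]

# With a real BatchQueueDC instance, these would then be applied via:
# batch_queue.update_weights(spatial_weights, temporal_weights)
```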

class Batch(low_res, high_res)

Bases: tuple

Create new instance of Batch(low_res, high_res)

__add__(value, /)

Return self+value.

__mul__(value, /)

Return self*value.

count(value, /)

Return number of occurrences of value.

high_res

Alias for field number 1

index(value, start=0, stop=9223372036854775807, /)

Return first index of value.

Raises ValueError if the value is not present.

low_res

Alias for field number 0
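Because Batch is a namedtuple, its fields can be read by name or by position. The sketch below rebuilds an equivalent namedtuple locally (it does not import sup3r) and uses made-up array shapes that assume s_enhance=2 and t_enhance=4.

```python
from collections import namedtuple

import numpy as np

# Same field layout as the documented Batch (field 0 = low_res,
# field 1 = high_res), defined locally for illustration only.
Batch = namedtuple('Batch', ['low_res', 'high_res'])

high_res = np.zeros((16, 20, 20, 24, 2))  # (batch, s1, s2, time, features)
low_res = np.zeros((16, 10, 10, 6, 2))    # assumes s_enhance=2, t_enhance=4
batch = Batch(low_res=low_res, high_res=high_res)

# Attribute access and positional access refer to the same arrays.
assert batch.low_res is batch[0]
assert batch.high_res is batch[1]
lr, hr = batch  # tuple unpacking also works
```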

check_enhancement_factors()

Make sure the enhancement factors evenly divide the sample_shape.
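A minimal version of this divisibility check could look like the following. The helper is hypothetical and assumes sample_shape is a 3-tuple (spatial_1, spatial_2, temporal); sup3r's own check may differ in its messages and exception type.

```python
def check_enhancement_factors(sample_shape, s_enhance, t_enhance):
    """Hypothetical sketch: raise ValueError if the enhancement factors do
    not evenly divide sample_shape = (spatial_1, spatial_2, temporal)."""
    s1, s2, t = sample_shape
    problems = []
    if s1 % s_enhance or s2 % s_enhance:
        problems.append(f's_enhance={s_enhance} does not divide {(s1, s2)}')
    if t % t_enhance:
        problems.append(f't_enhance={t_enhance} does not divide {t}')
    if problems:
        raise ValueError('; '.join(problems))


check_enhancement_factors((20, 20, 24), s_enhance=2, t_enhance=4)  # passes
```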

check_features()

Make sure all samplers have the same sets of features.

check_shared_attr(attr)

Check if all containers have the same value for attr. If they do, the collection effectively inherits that attribute.
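One way to express the shared-attribute rule in plain Python is sketched below. shared_attr is a hypothetical helper that returns the common value or None; it compares values with ==, so it suits simple attributes such as lists or ints (NumPy arrays would need np.array_equal instead).

```python
from types import SimpleNamespace


def shared_attr(containers, attr):
    """Hypothetical sketch: return the common value of ``attr`` if every
    container has the same one, otherwise None (nothing is inherited)."""
    values = [getattr(c, attr) for c in containers]
    first = values[0]
    return first if all(v == first for v in values[1:]) else None


# Stand-ins for samplers; only the attributes matter here.
samplers = [SimpleNamespace(features=['u', 'v'], s_enhance=2),
            SimpleNamespace(features=['u', 'v'], s_enhance=4)]
common_features = shared_attr(samplers, 'features')  # ['u', 'v']
common_s_enhance = shared_attr(samplers, 's_enhance')  # None
```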

property container_weights

Get weights used to sample from different containers based on relative sizes.

property data

Return underlying data.

Returns:

Sup3rDataset

See also

wrap()

enqueue_batch()

Build batch and send to queue.

enqueue_batches() → None

Callback function for the queue thread. While training is running, this thread checks the queue for empty spots and fills them; the training thread, in turn, removes batches from the queue.
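The producer/consumer split described above can be sketched with Python's standard queue and threading modules. This is an illustration under assumptions, not sup3r's code: make_batch is a hypothetical stand-in for sample_batch, a threading.Event plays the role of the running flag, and queue_cap is mapped to the queue's maxsize. The queue type sup3r actually uses may differ.

```python
import queue
import threading

import numpy as np


def make_batch(rng):
    """Hypothetical stand-in for building one batch (cf. sample_batch)."""
    return rng.random((4, 8, 8, 3))


def enqueue_batches(q, running, rng):
    """Producer loop: keep the bounded queue topped up while running is set."""
    while running.is_set():
        try:
            # put() blocks while the queue is full; the timeout lets the loop
            # re-check the running flag so that stopping takes effect.
            q.put(make_batch(rng), timeout=0.1)
        except queue.Full:
            continue


queue_cap = 5  # maximum number of batches held at once
q = queue.Queue(maxsize=queue_cap)
running = threading.Event()
running.set()
thread = threading.Thread(target=enqueue_batches,
                          args=(q, running, np.random.default_rng(0)),
                          name='training', daemon=True)
thread.start()  # analogous to start()

# Consumer side (the training loop) removes batches from the queue.
batches = [q.get(timeout=5) for _ in range(10)]

running.clear()  # analogous to stop()
thread.join(timeout=5)
```

The bounded queue keeps memory use fixed: the producer can never get more than queue_cap batches ahead of training.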

property features

Get all features contained in data.

get_batch() → Batch

Get batch from queue or directly from a Sampler through sample_batch.

get_container_index()

Get random container index based on weights.

get_queue()

Return FIFO queue for storing batches.

get_random_container()

Get random container based on container weights.
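Size-proportional container sampling can be sketched as follows, assuming a container's "size" is its number of elements and that the weights are the sizes normalized to sum to 1. Both helpers are hypothetical stand-ins for container_weights and get_container_index, not the library's code.

```python
import numpy as np


def container_weights(sizes):
    """Hypothetical sketch: weight each container by its share of the total size."""
    sizes = np.asarray(sizes, dtype=float)
    return sizes / sizes.sum()


def get_container_index(weights, rng):
    """Hypothetical sketch: draw one container index with probability
    equal to its weight."""
    return int(rng.choice(len(weights), p=weights))


rng = np.random.default_rng(42)
weights = container_weights([1000, 3000])  # [0.25, 0.75]
draws = [get_container_index(weights, rng) for _ in range(10_000)]
# The larger container (index 1) is picked roughly three times as often.
```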

property hr_shape

Shape of the high-resolution sample in a low-res / high-res pair, e.g. (spatial_1, spatial_2, temporal, features).

log_queue_info()

Log info about queue size.

property lr_shape

Shape of the low-resolution sample in a low-res / high-res pair, e.g. (spatial_1, spatial_2, temporal, features).
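For a 4D sample shape, the low-res shape follows from dividing each spatial axis by s_enhance and the temporal axis by t_enhance. The hypothetical helper below assumes the low-res sample has the same number of features as the high-res one unless told otherwise, and relies on the factors dividing the shape evenly, which check_enhancement_factors enforces.

```python
def lr_shape_from_hr(hr_shape, s_enhance, t_enhance, n_lr_features=None):
    """Hypothetical sketch: derive the low-res sample shape from a high-res
    shape (spatial_1, spatial_2, temporal, features)."""
    s1, s2, t, n_features = hr_shape
    if n_lr_features is None:
        n_lr_features = n_features  # assume the same features in both
    return (s1 // s_enhance, s2 // s_enhance, t // t_enhance, n_lr_features)


lr_shape = lr_shape_from_hr((20, 20, 24, 2), s_enhance=2, t_enhance=4)
# lr_shape == (10, 10, 6, 2)
```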

post_init_log(args_dict=None)

Log additional arguments after initialization.

post_proc(samples) → Batch

Perform post-processing on dequeued samples before sending them out for training. Post-processing can include coarsening of high-res data (if the Collection consists of Sampler objects rather than DualSampler objects), smoothing, etc.

Returns:

Batch (namedtuple) – namedtuple with low_res and high_res attributes

preflight()

Run checks before kicking off the queue.

property queue_shape

Shape of objects stored in the queue.

property queue_thread

Get new queue thread.

property running

Boolean to check whether to keep enqueueing batches.

property shape

Get shape of underlying data.

start() → None

Start thread to keep sample queue full for batches.

stop() → None

Stop loading batches.

transform(samples, smoothing=None, smoothing_ignore=None, temporal_coarsening_method='subsample')

Coarsen high res data to get corresponding low res batch.

Parameters:
  • samples (Union[np.ndarray, da.core.Array]) – High resolution batch of samples: either a 4D array (batch_size, spatial_1, spatial_2, features) or a 5D array (batch_size, spatial_1, spatial_2, temporal, features).

  • smoothing (float | None) – Standard deviation to use for Gaussian filtering of the coarse data. This can be tuned by matching the kinetic energy of a low resolution simulation with the kinetic energy of a coarsened and smoothed high resolution simulation. If None, no smoothing is performed.

  • smoothing_ignore (list | None) – List of features to ignore for the smoothing filter. If None, all features are smoothed (provided the smoothing kwarg is not None).

  • temporal_coarsening_method (str) – Method to use for temporal coarsening. Can be ‘subsample’, ‘average’, ‘min’, ‘max’, or ‘total’.

Returns:

  • low_res (Union[np.ndarray, da.core.Array]) – 4D array (batch_size, spatial_1, spatial_2, features) or 5D array (batch_size, spatial_1, spatial_2, temporal, features).

  • high_res (Union[np.ndarray, da.core.Array]) – 4D array (batch_size, spatial_1, spatial_2, features) or 5D array (batch_size, spatial_1, spatial_2, temporal, features).
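The coarsening step can be sketched with NumPy reshapes on a 5D batch. Assumptions: spatial coarsening is done by block averaging over s_enhance x s_enhance cells, ‘subsample’ keeps every t_enhance-th time step starting at the first, and smoothing is omitted. These are illustrative choices, not necessarily what transform does internally; the helper names are hypothetical.

```python
import numpy as np


def spatial_coarsen(hr, s_enhance):
    """Hypothetical sketch: block-average both spatial axes of a 5D array
    (batch, spatial_1, spatial_2, temporal, features)."""
    b, s1, s2, t, f = hr.shape
    blocks = hr.reshape(b, s1 // s_enhance, s_enhance,
                        s2 // s_enhance, s_enhance, t, f)
    return blocks.mean(axis=(2, 4))


def temporal_coarsen(hr, t_enhance, method='subsample'):
    """Hypothetical sketch: coarsen the temporal axis of a 5D array using one
    of the documented methods: subsample, average, min, max, or total."""
    if method == 'subsample':
        return hr[:, :, :, ::t_enhance, :]
    b, s1, s2, t, f = hr.shape
    blocks = hr.reshape(b, s1, s2, t // t_enhance, t_enhance, f)
    reducers = {'average': np.mean, 'min': np.min, 'max': np.max,
                'total': np.sum}
    return reducers[method](blocks, axis=4)


hr = np.arange(2 * 4 * 4 * 6 * 1, dtype=float).reshape(2, 4, 4, 6, 1)
lr = temporal_coarsen(spatial_coarsen(hr, s_enhance=2), t_enhance=3,
                      method='average')
print(lr.shape)  # (2, 2, 2, 2, 1)
```

‘total’ is the natural choice for accumulated quantities (for example, precipitation), while ‘average’ or ‘subsample’ suit instantaneous fields.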

wrap(data)

Return a Sup3rDataset object or tuple of such. This is a tuple when the .data attribute belongs to a Collection object like BatchHandler. Otherwise this is a Sup3rDataset object, which is either a wrapped 2-tuple or 1-tuple (i.e. len(data) == 2 or len(data) == 1). This is a 2-tuple when .data belongs to a dual container object like DualSampler and a 1-tuple otherwise.