sup3r.preprocessing.batch_queues.conditional.QueueMom2Sep#

class QueueMom2Sep(samplers, *, time_enhance_mode='constant', lower_models=None, s_padding=0, t_padding=0, end_t_padding=False, batch_size=16, n_batches=64, s_enhance=1, t_enhance=1, queue_cap=None, transform_kwargs=None, thread_name='training', mode='lazy', **kwargs)#

Bases: QueueMom1

Batch handling class for conditional estimation of the second moment without subtraction of the first moment.

Parameters:
  • samplers (List[Sampler]) – List of Sampler instances

  • time_enhance_mode (str) – [constant, linear] Method used to enhance the low-res data temporally when constructing the subfilter. At every temporal location, the temporally enhanced low-res data is subtracted from the predicted high-res data. constant assumes that the low-res data is constant between landmarks. linear interpolates linearly between landmarks to generate the low-res data that is removed from the high-res data.

  • lower_models (Dict[int, Sup3rCondMom] | None) – Dictionary of models that predict lower moments. For example, if this queue is part of a handler that estimates the 3rd moment, lower_models could include models that estimate the 1st and 2nd moments. These lower moments can be required in higher-order moment calculations.

  • s_padding (int | None) – Width of spatial padding to predict only middle part. If None, no padding is used

  • t_padding (int | None) – Width of temporal padding to predict only middle part. If None, no padding is used

  • end_t_padding (bool) – Zero pad the end of the temporal axis. Ensures that the loss is calculated only if a snapshot is surrounded by temporal landmarks. False by default.

  • kwargs (dict) – Keyword arguments for parent class

  • batch_size (int) – Number of observations / samples in a batch

  • n_batches (int) – Number of batches in an epoch, this sets the iteration limit for this object.

  • s_enhance (int) – Integer factor by which the spatial axes are to be enhanced.

  • t_enhance (int) – Integer factor by which the temporal axis is to be enhanced.

  • queue_cap (int) – Maximum number of batches the batch queue can store.

  • transform_kwargs (Union[Dict, None]) – Dictionary of kwargs to be passed to self.transform. This method performs smoothing / coarsening.

  • thread_name (str) – Name of the queue thread. Default is ‘training’. Used to set name to ‘validation’ for BatchHandler, which has a training and validation queue.

  • mode (str) – Loading mode. Default is ‘lazy’, which only loads data into memory as batches are queued. ‘eager’ will load all data into memory right away.
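The two time_enhance_mode options can be illustrated on a single 1D series. This is a minimal NumPy sketch, not the sup3r implementation: the helper name enhance_time is hypothetical, and it assumes that landmarks sit at every t_enhance-th high-res step.

```python
import numpy as np


def enhance_time(low_res, t_enhance, mode="constant"):
    """Upsample a 1D series of temporal landmarks by an integer factor.

    Hypothetical helper for illustration only.
    """
    low_res = np.asarray(low_res, dtype=float)
    if mode == "constant":
        # Hold each landmark value for t_enhance consecutive steps.
        return np.repeat(low_res, t_enhance)
    if mode == "linear":
        # Landmark positions expressed in high-res time steps.
        lr_times = np.arange(len(low_res)) * t_enhance
        hr_times = np.arange(len(low_res) * t_enhance)
        # np.interp holds the last value past the final landmark.
        return np.interp(hr_times, lr_times, low_res)
    raise ValueError(f"Unknown time_enhance_mode: {mode}")


landmarks = [0.0, 4.0]
constant = enhance_time(landmarks, t_enhance=2, mode="constant")
linear = enhance_time(landmarks, t_enhance=2, mode="linear")
print(constant)  # [0. 0. 4. 4.]
print(linear)    # [0. 2. 4. 4.]
```

Subtracting either enhanced series from the high-res series gives the subfilter described for time_enhance_mode.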

Methods

check_enhancement_factors()

Make sure the enhancement factors evenly divide the sample_shape.

check_features()

Make sure all samplers have the same sets of features.

check_shared_attr(attr)

Check if all containers have the same value for attr.

enqueue_batch()

Build batch and send to queue.

enqueue_batches()

Callback function for queue thread.

get_batch()

Get batch from queue or directly from a Sampler through sample_batch.

get_container_index()

Get random container index based on weights

get_queue()

Return FIFO queue for storing batches.

get_random_container()

Get random container based on container weights

log_queue_info()

Log info about queue size.

make_mask(high_res)

Make mask for output.

make_output(samples)

post_init_log([args_dict])

Log additional arguments after initialization.

post_proc(samples)

Returns normalized collection of samples / observations along with mask and target output for conditional moment estimation.

preflight()

Run checks before kicking off the queue.

sample_batch()

Get random sampler from collection and return a batch of samples from that sampler.

start()

Start thread to keep sample queue full for batches.

stop()

Stop loading batches.

transform(samples[, smoothing, ...])

Coarsen high res data to get corresponding low res batch.

wrap(data)

Return a Sup3rDataset object or tuple of such.

Attributes

container_weights

Get weights used to sample from different containers based on relative sizes

data

Return underlying data.

features

Get all features contained in data.

hr_shape

Shape of high resolution sample in a low-res / high-res pair.

lr_shape

Shape of low resolution sample in a low-res / high-res pair.

queue_shape

Shape of objects stored in the queue.

queue_thread

Get new queue thread.

running

Boolean to check whether to keep enqueueing batches.

shape

Get shape of underlying data.

make_output(samples)#
Returns:

HR**2 (Union[np.ndarray, da.core.Array]) – Square of the high-res (HR) data. 4D | 5D array with shape (batch_size, spatial_1, spatial_2, features) or (batch_size, spatial_1, spatial_2, temporal, features).
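As a rough illustration of this target, the sketch below squares a random high-res batch elementwise. The array shape and variable names are assumptions chosen for the example. It does not call sup3r.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical 5D batch: (batch_size, spatial_1, spatial_2, temporal, features)
high_res = rng.normal(size=(2, 4, 4, 3, 1))

# Target for the "separate" second moment: the elementwise square of HR.
# No first-moment estimate is subtracted before squaring.
output = high_res ** 2

print(output.shape)  # (2, 4, 4, 3, 1)
```

A model trained on this target estimates E(HR**2|LR). By the standard identity Var(HR|LR) = E(HR**2|LR) - E(HR|LR)**2, a variance estimate can then be formed by combining it with a separately trained first-moment model.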

class Batch(low_res, high_res)#

Bases: tuple

Create new instance of Batch(low_res, high_res)

__add__(value, /)#

Return self+value.

__mul__(value, /)#

Return self*value.

count(value, /)#

Return number of occurrences of value.

high_res#

Alias for field number 1

index(value, start=0, stop=9223372036854775807, /)#

Return first index of value.

Raises ValueError if the value is not present.

low_res#

Alias for field number 0

class ConditionalBatch(low_res, high_res, output, mask)#

Bases: tuple

Create new instance of ConditionalBatch(low_res, high_res, output, mask)

__add__(value, /)#

Return self+value.

__mul__(value, /)#

Return self*value.

count(value, /)#

Return number of occurrences of value.

high_res#

Alias for field number 1

index(value, start=0, stop=9223372036854775807, /)#

Return first index of value.

Raises ValueError if the value is not present.

low_res#

Alias for field number 0

mask#

Alias for field number 3

output#

Alias for field number 2
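The field numbers listed above follow the usual namedtuple behavior. The sketch below rebuilds an equivalent tuple type with collections.namedtuple, using placeholder strings instead of arrays. It is an illustration and not the class object exported by sup3r.

```python
from collections import namedtuple

# Same field order as documented: low_res=0, high_res=1, output=2, mask=3
ConditionalBatch = namedtuple(
    "ConditionalBatch", ["low_res", "high_res", "output", "mask"]
)

# Placeholder values stand in for the arrays a real batch would hold.
batch = ConditionalBatch(low_res="lr", high_res="hr", output="out", mask="m")

# Fields can be read by name, by position, or by unpacking.
print(batch.output)  # out
print(batch[3])      # m
low_res, high_res, output, mask = batch
```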

check_enhancement_factors()#

Make sure the enhancement factors evenly divide the sample_shape.
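A check of this kind could look like the sketch below. The function name check_divisible, the three-element sample_shape layout (spatial_1, spatial_2, temporal), and the returned low-res shape are assumptions made for the example, not the library's behavior.

```python
def check_divisible(sample_shape, s_enhance, t_enhance):
    """Raise if the enhancement factors do not evenly divide sample_shape.

    Hypothetical helper. sample_shape is assumed to be
    (spatial_1, spatial_2, temporal).
    """
    s1, s2, t = sample_shape
    if s1 % s_enhance or s2 % s_enhance:
        raise ValueError(
            f"s_enhance={s_enhance} does not divide spatial shape {(s1, s2)}"
        )
    if t % t_enhance:
        raise ValueError(
            f"t_enhance={t_enhance} does not divide temporal length {t}"
        )
    # Returning the implied low-res shape is a convenience for this sketch.
    return (s1 // s_enhance, s2 // s_enhance, t // t_enhance)


lr_sample_shape = check_divisible((20, 20, 24), s_enhance=5, t_enhance=4)
print(lr_sample_shape)  # (4, 4, 6)
```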

check_features()#

Make sure all samplers have the same sets of features.

check_shared_attr(attr)#

Check if all containers have the same value for attr. If they do, the collection effectively inherits those attributes.

property container_weights#

Get weights used to sample from different containers based on relative sizes

property data#

Return underlying data.

Returns:

Sup3rDataset

See also

wrap()

enqueue_batch()#

Build batch and send to queue.

enqueue_batches() → None#

Callback function for queue thread. While training, the queue is checked for empty spots and filled. In the training thread, batches are removed from the queue.
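The producer and consumer pattern described here can be sketched with the standard library alone. Everything below (the integer "batches", the queue capacity, and the running event) is a stand-in chosen for the example. The queue class itself manages its thread differently in detail.

```python
import queue
import threading

batch_queue = queue.Queue(maxsize=4)  # FIFO queue with a fixed capacity
running = threading.Event()
running.set()


def enqueue_batches():
    """Producer callback: fill empty spots in the queue while running."""
    i = 0
    while running.is_set():
        try:
            # Block briefly when the queue is full, then re-check running.
            batch_queue.put(i, timeout=0.05)
            i += 1
        except queue.Full:
            continue


thread = threading.Thread(target=enqueue_batches, name="training", daemon=True)
thread.start()

# Consumer side: the training loop removes batches from the queue.
batches = [batch_queue.get() for _ in range(10)]

running.clear()  # signal the producer to stop
thread.join()
print(batches)  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```

Bounding the queue with maxsize keeps the producer from running arbitrarily far ahead of training, which is the role queue_cap plays in the constructor.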

property features#

Get all features contained in data.

get_batch() → Batch#

Get batch from queue or directly from a Sampler through sample_batch.

get_container_index()#

Get random container index based on weights

get_queue()#

Return FIFO queue for storing batches.

get_random_container()#

Get random container based on container weights
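Size-weighted sampling of a container can be sketched as follows. The container sizes are made-up numbers, and normalizing the sizes to probabilities is an assumption about what "based on relative sizes" means.

```python
import numpy as np

# Hypothetical sizes (number of available samples) of three containers.
sizes = np.array([100, 300, 600])

# Weights proportional to relative size, normalized to sum to one.
weights = sizes / sizes.sum()

rng = np.random.default_rng(42)

# Draw a container index with probability equal to its weight.
index = rng.choice(len(sizes), p=weights)

# Over many draws, the largest container is selected most often.
draws = rng.choice(len(sizes), size=10_000, p=weights)
counts = np.bincount(draws, minlength=len(sizes))
```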

property hr_shape#

Shape of high resolution sample in a low-res / high-res pair. (e.g. (spatial_1, spatial_2, temporal, features))

log_queue_info()#

Log info about queue size.

property lr_shape#

Shape of low resolution sample in a low-res / high-res pair. (e.g. (spatial_1, spatial_2, temporal, features))

make_mask(high_res)#

Make mask for output. This is used to ensure consistency when training conditional moments.

Note

Consider the case of learning E(HR|LR), where HR is the high_res data and LR is the low_res data. In theory, conditional moment estimation works if the full LR is passed as input and the model predicts the full HR. In practice, only the LR data that overlaps and surrounds the HR data is useful, i.e. E(HR|LR) = E(HR|LR_nei), where LR_nei is the LR data that surrounds the HR data. Physically, this is equivalent to saying that data far away from a region of interest does not matter. This allows the conditional moments to be learned on spatial and temporal chunks, provided the high_res output is restricted to the region that is overlapped and surrounded by the input low_res. The role of the mask is to ensure that the input low_res always surrounds the output high_res.

Parameters:

high_res (Union[np.ndarray, da.core.Array]) – 4D | 5D array with shape (batch_size, spatial_1, spatial_2, features) or (batch_size, spatial_1, spatial_2, temporal, features)

Returns:

mask (Union[np.ndarray, da.core.Array]) – 4D | 5D array with shape (batch_size, spatial_1, spatial_2, features) or (batch_size, spatial_1, spatial_2, temporal, features)
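One way to picture such a mask is an array of ones in the interior and zeros in a border of width s_padding (spatial) and t_padding (temporal). The sketch below is an assumption about that layout for the 5D case. The helper name make_interior_mask is hypothetical, and end_t_padding is not handled.

```python
import numpy as np


def make_interior_mask(high_res, s_padding=0, t_padding=0):
    """Return a 0/1 mask that keeps only the interior of a 5D batch.

    Hypothetical helper. Expects shape
    (batch_size, spatial_1, spatial_2, temporal, features).
    """
    mask = np.zeros(high_res.shape, dtype=high_res.dtype)
    # A stop of -0 would give an empty slice, so use None when padding is 0.
    s_stop = -s_padding if s_padding else None
    t_stop = -t_padding if t_padding else None
    mask[:, s_padding:s_stop, s_padding:s_stop, t_padding:t_stop, :] = 1
    return mask


high_res = np.ones((1, 6, 6, 4, 1))
mask = make_interior_mask(high_res, s_padding=1, t_padding=1)

# Interior is 4 x 4 in space and 2 in time.
print(int(mask.sum()))  # 32
```

Multiplying both the prediction and the target by the mask before computing a loss restricts training to the region that the low-res input fully surrounds.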

post_init_log(args_dict=None)#

Log additional arguments after initialization.

post_proc(samples)#

Returns normalized collection of samples / observations along with mask and target output for conditional moment estimation. Coarsens the high-res data if the Collection consists of Sampler objects rather than DualSampler objects.

Returns:

namedtuple – Named tuple with low_res, high_res, output, and mask attributes

preflight()#

Run checks before kicking off the queue.

property queue_shape#

Shape of objects stored in the queue.

property queue_thread#

Get new queue thread.

property running#

Boolean to check whether to keep enqueueing batches.

sample_batch()#

Get random sampler from collection and return a batch of samples from that sampler.

Notes

These samples are wrapped in an np.asarray call, so they have been loaded into memory.

property shape#

Get shape of underlying data.

start() → None#

Start thread to keep sample queue full for batches.

stop() → None#

Stop loading batches.

transform(samples, smoothing=None, smoothing_ignore=None, temporal_coarsening_method='subsample')#

Coarsen high res data to get corresponding low res batch.

Parameters:
  • samples (Union[np.ndarray, da.core.Array]) – High resolution batch of samples. 4D | 5D array with shape (batch_size, spatial_1, spatial_2, features) or (batch_size, spatial_1, spatial_2, temporal, features)

  • smoothing (float | None) – Standard deviation to use for Gaussian filtering of the coarse data. This can be tuned by matching the kinetic energy of a low resolution simulation with the kinetic energy of a coarsened and smoothed high resolution simulation. If None, no smoothing is performed.

  • smoothing_ignore (list | None) – List of features to ignore for the smoothing filter. If None, all features are smoothed, provided smoothing is not None.

  • temporal_coarsening_method (str) – Method to use for temporal coarsening. Can be subsample, average, min, max, or total.

Returns:

  • low_res (Union[np.ndarray, da.core.Array]) – 4D | 5D array with shape (batch_size, spatial_1, spatial_2, features) or (batch_size, spatial_1, spatial_2, temporal, features)

  • high_res (Union[np.ndarray, da.core.Array]) – 4D | 5D array with shape (batch_size, spatial_1, spatial_2, features) or (batch_size, spatial_1, spatial_2, temporal, features)
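The coarsening step can be pictured with the simplified sketch below. It assumes block averaging over s_enhance x s_enhance spatial cells and the subsample method in time (keeping every t_enhance-th step). Smoothing and the other temporal methods are omitted, and the helper name coarsen_block_mean is hypothetical.

```python
import numpy as np


def coarsen_block_mean(high_res, s_enhance, t_enhance):
    """Coarsen a 5D batch by spatial block means and temporal subsampling.

    Hypothetical helper. Expects shape
    (batch_size, spatial_1, spatial_2, temporal, features), with the
    spatial sizes divisible by s_enhance.
    """
    b, s1, s2, t, f = high_res.shape
    # Split each spatial axis into (coarse index, index within block).
    blocks = high_res.reshape(
        b, s1 // s_enhance, s_enhance, s2 // s_enhance, s_enhance, t, f
    )
    spatial = blocks.mean(axis=(2, 4))
    # "subsample": keep every t_enhance-th time step.
    return spatial[:, :, :, ::t_enhance, :]


high_res = np.arange(2 * 4 * 4 * 6, dtype=float).reshape(2, 4, 4, 6, 1)
low_res = coarsen_block_mean(high_res, s_enhance=2, t_enhance=3)
print(low_res.shape)  # (2, 2, 2, 2, 1)
```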

wrap(data)#

Return a Sup3rDataset object or a tuple of such objects. The result is a tuple when the .data attribute belongs to a Collection object like BatchHandler. Otherwise it is a Sup3rDataset object, which is either a wrapped 2-tuple or a wrapped 1-tuple (i.e. len(data) == 2 or len(data) == 1). It is a 2-tuple when .data belongs to a dual container object like DualSampler and a 1-tuple otherwise.