sup3r.preprocessing.batch_queues.abstract.AbstractBatchQueue#

class AbstractBatchQueue(samplers: List[Sampler] | List[DualSampler], batch_size: int = 16, n_batches: int = 64, s_enhance: int = 1, t_enhance: int = 1, queue_cap: int | None = None, transform_kwargs: dict | None = None, max_workers: int = 1, thread_name: str = 'training', mode: str = 'lazy')[source]#

Bases: Collection, ABC

Abstract BatchQueue class. This class gets batches from a dataset generator and maintains a queue of batches in a dedicated thread so the training routine can proceed as soon as batches are available.

Parameters:
  • samplers (List[Sampler]) – List of Sampler instances

  • batch_size (int) – Number of observations / samples in a batch

  • n_batches (int) – Number of batches in an epoch; this sets the iteration limit for this object.

  • s_enhance (int) – Integer factor by which the spatial axes are to be enhanced.

  • t_enhance (int) – Integer factor by which the temporal axis is to be enhanced.

  • queue_cap (int) – Maximum number of batches the batch queue can store.

  • transform_kwargs (Union[Dict, None]) – Dictionary of kwargs to be passed to self.transform. This method performs smoothing / coarsening.

  • max_workers (int) – Number of workers / threads to use for getting batches to fill the queue.

  • thread_name (str) – Name of the queue thread. Default is ‘training’. Used to set name to ‘validation’ for BatchHandler, which has a training and validation queue.

  • mode (str) – Loading mode. Default is ‘lazy’, which only loads data into memory as batches are queued. ‘eager’ will load all data into memory right away.
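As an abstract base class, AbstractBatchQueue leaves queue_shape and transform for subclasses to implement. A minimal, self-contained sketch of that contract (the ToyQueueBase / ToyBatchQueue classes below are illustrative stand-ins, not part of sup3r):

```python
from abc import ABC, abstractmethod


class ToyQueueBase(ABC):
    """Stand-in for AbstractBatchQueue's abstract surface."""

    @property
    @abstractmethod
    def queue_shape(self):
        """Shape of objects stored in the queue."""

    @abstractmethod
    def transform(self, samples, **kwargs):
        """Apply smoothing / coarsening to batch samples."""


class ToyBatchQueue(ToyQueueBase):
    """Concrete subclass providing the two required members."""

    def __init__(self, batch_size, sample_shape, n_features):
        self.batch_size = batch_size
        self.sample_shape = sample_shape
        self.n_features = n_features

    @property
    def queue_shape(self):
        # Single-dataset layout: (batch_size, *sample_shape, n_features)
        return (self.batch_size, *self.sample_shape, self.n_features)

    def transform(self, samples, **kwargs):
        # Identity transform: no smoothing / coarsening in this sketch
        return samples


q = ToyBatchQueue(batch_size=16, sample_shape=(20, 20, 24), n_features=2)
```

Attempting to instantiate a subclass without both members raises TypeError, which is how the abstract contract is enforced.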

Methods

check_enhancement_factors()

Make sure the enhancement factors evenly divide the sample_shape.

check_features()

Make sure all samplers have the same sets of features.

check_shared_attr(attr)

Check if all containers have the same value for attr.

enqueue_batch()

Build batch and send to queue.

enqueue_batches()

Callback function for queue thread.

get_batch()

Get batch from queue or directly from a Sampler through sample_batch.

get_container_index()

Get random container index based on weights.

get_queue()

Return FIFO queue for storing batches.

get_random_container()

Get random container based on container weights.

log_queue_info()

Log info about queue size.

post_init_log([args_dict])

Log additional arguments after initialization.

post_proc(samples)

Perform post-processing on dequeued samples before they are sent out for training.

preflight()

Run checks before kicking off the queue.

sample_batch()

Get random sampler from collection and return a batch of samples from that sampler.

start()

Start thread to keep sample queue full for batches.

stop()

Stop loading batches.

transform(samples, **kwargs)

Apply transform on batch samples.

wrap(data)

Return a Sup3rDataset object or a tuple of such objects.

Attributes

container_weights

Get weights used to sample from different containers based on relative sizes.

data

Return underlying data.

features

Get all features contained in data.

hr_shape

Shape of high resolution sample in a low-res / high-res pair.

lr_shape

Shape of low resolution sample in a low-res / high-res pair.

queue_shape

Shape of objects stored in the queue. For single-dataset queues this is (batch_size, *sample_shape, len(features)); for dual-dataset queues it is [(batch_size, *lr_shape), (batch_size, *hr_shape)].

queue_thread

Get new queue thread.

running

Boolean to check whether to keep enqueueing batches.

shape

Get shape of underlying data.

class Batch(low_res, high_res)#

Bases: tuple

Create new instance of Batch(low_res, high_res)

__add__(value, /)#

Return self+value.

__mul__(value, /)#

Return self*value.

count(value, /)#

Return number of occurrences of value.

high_res#

Alias for field number 1

index(value, start=0, stop=9223372036854775807, /)#

Return first index of value.

Raises ValueError if the value is not present.

low_res#

Alias for field number 0
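Since Batch is a namedtuple, a dequeued batch supports attribute access, indexing, and positional unpacking. A self-contained equivalent of the container described above:

```python
from collections import namedtuple

# Equivalent definition: field 0 is low_res, field 1 is high_res
Batch = namedtuple('Batch', ['low_res', 'high_res'])

batch = Batch(low_res=[[0.0]], high_res=[[1.0, 2.0]])
lr, hr = batch  # positional unpacking: (field 0, field 1)
```

Training loops can therefore use either `batch.low_res` / `batch.high_res` or plain tuple unpacking interchangeably.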

abstract property queue_shape#

Shape of objects stored in the queue. For single-dataset queues this is (batch_size, *sample_shape, len(features)); for dual-dataset queues it is [(batch_size, *lr_shape), (batch_size, *hr_shape)].

get_queue()[source]#

Return FIFO queue for storing batches.

preflight()[source]#

Run checks before kicking off the queue.

property queue_thread#

Get new queue thread.

check_features()[source]#

Make sure all samplers have the same sets of features.

check_enhancement_factors()[source]#

Make sure the enhancement factors evenly divide the sample_shape.
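This check amounts to verifying that s_enhance evenly divides both spatial dimensions of sample_shape and that t_enhance evenly divides the temporal dimension. A hedged sketch of that logic (standalone function; sup3r's actual implementation may differ):

```python
def check_enhancement_factors(sample_shape, s_enhance, t_enhance):
    """Raise if the enhancement factors do not evenly divide sample_shape.

    sample_shape is assumed to be (spatial_1, spatial_2, temporal).
    """
    s1, s2, t = sample_shape
    if s1 % s_enhance or s2 % s_enhance:
        raise ValueError(
            f's_enhance={s_enhance} does not evenly divide '
            f'spatial dims {(s1, s2)}')
    if t % t_enhance:
        raise ValueError(
            f't_enhance={t_enhance} does not evenly divide '
            f'temporal dim {t}')


# Passes: 4 divides 20 and 20; 6 divides 24
check_enhancement_factors((20, 20, 24), s_enhance=4, t_enhance=6)
```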

abstract transform(samples, **kwargs)[source]#

Apply transform on batch samples. This can include smoothing / coarsening depending on the type of queue. For example, coarsening could be included for a single-dataset queue, where low-res samples are coarsened versions of high-res samples. For a dual-dataset queue this will just include smoothing.
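For the single-dataset case, spatial coarsening can be implemented as a block mean over the two spatial axes. A sketch assuming numpy arrays shaped (batch, spatial_1, spatial_2, temporal, features) — an illustration of the idea, not sup3r's exact implementation:

```python
import numpy as np


def spatial_coarsen(high_res, s_enhance):
    """Block-average both spatial axes of high_res by a factor of s_enhance."""
    b, s1, s2, t, f = high_res.shape
    assert s1 % s_enhance == 0 and s2 % s_enhance == 0
    # Split each spatial axis into (blocks, block_size), then average
    # over the block_size axes (2 and 4).
    blocks = high_res.reshape(
        b, s1 // s_enhance, s_enhance, s2 // s_enhance, s_enhance, t, f)
    return blocks.mean(axis=(2, 4))


hr = np.arange(32, dtype=float).reshape(1, 4, 4, 2, 1)
lr = spatial_coarsen(hr, s_enhance=2)
```

Block averaging preserves the overall mean of the field, which is one reason it is a common choice for producing synthetic low-res training inputs.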

post_proc(samples) → Batch[source]#

Perform post-processing on dequeued samples before they are sent out for training. Post-processing can include coarsening of high-res data (if the Collection consists of Sampler objects rather than DualSampler objects), smoothing, etc.

Returns:

Batch (namedtuple) – namedtuple with low_res and high_res attributes.

start() → None[source]#

Start thread to keep sample queue full for batches.

stop() → None[source]#

Stop loading batches.

get_batch() → Batch[source]#

Get batch from queue or directly from a Sampler through sample_batch.

property running#

Boolean to check whether to keep enqueueing batches.

enqueue_batches() → None[source]#

Callback function for queue thread. While training, the queue is checked for empty spots and filled. In the training thread, batches are removed from the queue.
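The producer/consumer pattern this callback implements can be sketched with the standard library alone (TinyBatchQueue and its members are hypothetical names; sup3r's thread management differs in detail):

```python
import queue
import threading


class TinyBatchQueue:
    """Minimal producer thread keeping a bounded FIFO queue of batches full."""

    def __init__(self, make_batch, queue_cap=4):
        self._queue = queue.Queue(maxsize=queue_cap)  # FIFO, bounded
        self._make_batch = make_batch
        self._running = threading.Event()
        self._thread = threading.Thread(
            target=self._enqueue_batches, name='training', daemon=True)

    def _enqueue_batches(self):
        # Fill empty spots until stop() is called. put() blocks while the
        # queue is full; the timeout lets the loop notice the stop signal.
        while self._running.is_set():
            try:
                self._queue.put(self._make_batch(), timeout=0.1)
            except queue.Full:
                continue

    def start(self):
        self._running.set()
        self._thread.start()

    def stop(self):
        self._running.clear()
        self._thread.join()

    def get_batch(self, timeout=5):
        # Called from the training thread; removes a batch from the queue.
        return self._queue.get(timeout=timeout)


counter = iter(range(1_000_000))
q = TinyBatchQueue(make_batch=lambda: next(counter))
q.start()
batches = [q.get_batch() for _ in range(3)]
q.stop()
```

The bounded queue is what decouples batch construction from the training step: the producer blocks when the queue is full, and training proceeds as soon as any batch is available.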

get_container_index()[source]#

Get random container index based on weights.

get_random_container()[source]#

Get random container based on container weights.
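Since container weights are based on relative container sizes, larger containers are sampled proportionally more often. A standard-library sketch of that selection (helper names are illustrative, not sup3r API):

```python
import random


def container_weights(containers):
    """Weights proportional to each container's size."""
    sizes = [len(c) for c in containers]
    total = sum(sizes)
    return [s / total for s in sizes]


def get_container_index(containers, rng=random):
    """Random container index drawn according to container_weights."""
    weights = container_weights(containers)
    return rng.choices(range(len(containers)), weights=weights, k=1)[0]


containers = [list(range(10)), list(range(30)), list(range(60))]
weights = container_weights(containers)  # [0.1, 0.3, 0.6]
```

An empty container gets weight zero and is never selected, which is the behavior you want when containers correspond to datasets of very different sizes.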

sample_batch()[source]#

Get random sampler from collection and return a batch of samples from that sampler.

Notes

These samples are wrapped in an np.asarray call, so they have been loaded into memory.

log_queue_info()[source]#

Log info about queue size.

enqueue_batch()[source]#

Build batch and send to queue.

property lr_shape#

Shape of low resolution sample in a low-res / high-res pair. (e.g. (spatial_1, spatial_2, temporal, features))

property hr_shape#

Shape of high resolution sample in a low-res / high-res pair. (e.g. (spatial_1, spatial_2, temporal, features))

check_shared_attr(attr)#

Check if all containers have the same value for attr. If they do, the collection effectively inherits that attribute.

property container_weights#

Get weights used to sample from different containers based on relative sizes.

property data#

Return underlying data.

Returns:

Sup3rDataset

See also

wrap()

property features#

Get all features contained in data.

post_init_log(args_dict=None)#

Log additional arguments after initialization.

property shape#

Get shape of underlying data.

wrap(data)#

Return a Sup3rDataset object or a tuple of such objects. This is a tuple when the .data attribute belongs to a Collection object like BatchHandler. Otherwise this is a Sup3rDataset object, which is either a wrapped 2-tuple or 1-tuple (e.g. len(data) == 2 or len(data) == 1). It is a 2-tuple when .data belongs to a dual container object like DualSampler and a 1-tuple otherwise.