sup3r.preprocessing.conditional_moment_batch_handling.BatchHandlerMom1

class BatchHandlerMom1(data_handlers, batch_size=8, s_enhance=3, t_enhance=1, means=None, stds=None, norm=True, n_batches=10, temporal_coarsening_method='subsample', temporal_enhancing_method='constant', stdevs_file=None, means_file=None, overwrite_stats=False, smoothing=None, smoothing_ignore=None, stats_workers=None, norm_workers=None, load_workers=None, max_workers=None, model_mom1=None, s_padding=None, t_padding=None, end_t_padding=False)[source]

Bases: BatchHandler

Sup3r batch handling class for the first conditional moment

Parameters:
  • data_handlers (list[DataHandler]) – List of DataHandler instances

  • batch_size (int) – Number of observations in a batch

  • s_enhance (int) – Factor by which to coarsen spatial dimensions of the high resolution data to generate low res data

  • t_enhance (int) – Factor by which to coarsen temporal dimension of the high resolution data to generate low res data

  • means (np.ndarray) – Array of means with shape (features,) in the same order as the data features. If not None and norm is True, these will be used for normalization

  • stds (np.ndarray) – Array of standard deviations with shape (features,) in the same order as the data features. If not None and norm is True, these will be used for normalization

  • norm (bool) – Whether to normalize the data or not

  • n_batches (int) – Number of batches in an epoch, this sets the iteration limit for this object.

  • temporal_coarsening_method (str) – [subsample, average, total] Subsample will take every t_enhance-th time step, average will average over t_enhance time steps, total will sum over t_enhance time steps

  • temporal_enhancing_method (str) – [constant, linear] Method of temporal enhancement used when constructing the subfilter. At every temporal location, the low-res temporal data is subtracted from the predicted high-res temporal data. constant assumes the low-res temporal data is constant between landmarks; linear linearly interpolates between landmarks to generate the low-res data to remove from the high-res.

  • stdevs_file (str | None) – Path to stdevs data or where to save data after calling get_stats

  • means_file (str | None) – Path to means data or where to save data after calling get_stats

  • overwrite_stats (bool) – Whether to overwrite stats cache files.

  • smoothing (float | None) – Standard deviation to use for gaussian filtering of the coarse data. This can be tuned by matching the kinetic energy of a low resolution simulation with the kinetic energy of a coarsened and smoothed high resolution simulation. If None no smoothing is performed.

  • smoothing_ignore (list | None) – List of features to ignore for the smoothing filter. None will smooth all features if smoothing kwarg is not None

  • max_workers (int | None) – If not None, this value is used to set norm_workers, stats_workers, and load_workers. If max_workers == 1 then all processes will be serialized. If None, stats_workers, load_workers, and norm_workers will each use their own provided values.

  • load_workers (int | None) – Max number of workers to use for loading data handlers.

  • norm_workers (int | None) – Max number of workers to use for normalizing data handlers.

  • stats_workers (int | None) – Max number of workers to use for computing stats across data handlers.

  • model_mom1 (Sup3rCondMom | None) – Model that predicts the first conditional moment. Useful for preparing data to learn the second conditional moment.

  • s_padding (int | None) – Width of spatial padding used to predict only the middle part. If None, no padding is used

  • t_padding (int | None) – Width of temporal padding used to predict only the middle part. If None, no padding is used

  • end_t_padding (bool) – Whether to zero-pad the end of the temporal dimension. Ensures that the loss is calculated only when a snapshot is surrounded by temporal landmarks. False by default
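Example (illustrative only): a minimal sketch of constructing a BatchHandlerMom1 from pre-built data handlers and iterating over training batches. The data handler setup, enhancement factors, and other keyword values below are assumptions chosen for illustration, not values prescribed by this documentation.

    from sup3r.preprocessing.conditional_moment_batch_handling import BatchHandlerMom1

    # `data_handlers` is assumed to be a list of already-initialized sup3r
    # DataHandler instances prepared elsewhere (hypothetical setup).
    batch_handler = BatchHandlerMom1(
        data_handlers,
        batch_size=16,                           # observations per batch
        s_enhance=5,                             # spatial coarsening factor
        t_enhance=4,                             # temporal coarsening factor
        n_batches=32,                            # batches per epoch
        temporal_coarsening_method='subsample',  # take every t_enhance-th step
        temporal_enhancing_method='constant',    # piecewise-constant low-res in time
        norm=True,                               # normalize features using computed stats
    )

    # Iterating yields batches (instances of BATCH_CLASS, i.e. BatchMom1; see below)
    # for n_batches iterations per epoch.
    for batch in batch_handler:
        ...  # pass the batch to a Sup3rCondMom training step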

Methods

cache_stats()

Save stdevs and means to cache files if the file paths are not None

check_cached_stats()

Get standard deviations and means for all data features from cache files if available.

get_handler_index()

Get random handler index based on handler weights

get_rand_handler()

Get random handler based on handler weights

get_stats()

Get standard deviations and means for all data features

load_handler_data()

Load data handler data in parallel or serial

normalize([means, stds])

Compute means and stds for each feature across all datasets and normalize each data handler dataset.

Attributes

DATA_HANDLER_CLASS

feature_mem

Get memory used by each feature in data handlers

features

Get the ordered list of feature names held in this object's data handlers

handler_weights

Get weights used to sample from different data handlers based on relative sizes

hr_exo_features

Get a list of high-resolution features that are only used for training, e.g., mid-network high-res topo injection.

hr_features_ind

Get the high-resolution feature channel indices that should be included for training.

hr_out_features

Get a list of high-resolution features that are intended to be output by the GAN.

load_workers

Get max workers for loading data handlers based on memory usage

lr_features

Get a list of low-resolution features.

norm_workers

Get max workers used for calculating stats and normalizing across features

shape

Shape of full dataset across all handlers

stats_workers

Get max workers for calculating stats based on memory usage

VAL_CLASS

alias of ValidationDataMom1

BATCH_CLASS

alias of BatchMom1

cache_stats()

Save stdevs and means to cache files if the file paths are not None

check_cached_stats()

Get standard deviations and means for all data features from cache files if available.

Returns:

  • means (dict | None) – Dictionary of means for all features, with feature names as keys and mean values as values. If None, these will be calculated. If norm is True, these will be used for data normalization

  • stds (dict | None) – Dictionary of standard deviations for all features, with feature names as keys and standard deviation values as values. If None, these will be calculated. If norm is True, these will be used for data normalization
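Example (illustrative only): a sketch of the stats caching flow implied by check_cached_stats, get_stats, and cache_stats, assuming stdevs_file and means_file were passed at construction and that check_cached_stats returns the (means, stds) pair described above. The control flow is an assumption about typical usage, not prescribed by the API.

    # Reuse cached per-feature statistics if the cache files exist,
    # otherwise compute and save them (hypothetical usage).
    means, stds = batch_handler.check_cached_stats()
    if means is None or stds is None:
        batch_handler.get_stats()    # compute means and stdevs for all features
        batch_handler.cache_stats()  # write them to means_file / stdevs_file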

property feature_mem

Get memory used by each feature in data handlers

property features

Get the ordered list of feature names held in this object’s data handlers

get_handler_index()

Get random handler index based on handler weights

get_rand_handler()

Get random handler based on handler weights

get_stats()

Get standard deviations and means for all data features

property handler_weights

Get weights used to sample from different data handlers based on relative sizes
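Example (illustrative only): the handler sampling utilities above can be exercised directly; a quick sketch, assuming batch_handler is an initialized BatchHandlerMom1.

    # Handlers are sampled in proportion to their relative sizes, so larger
    # datasets contribute more observations per epoch.
    weights = batch_handler.handler_weights      # per-handler sampling weights
    idx = batch_handler.get_handler_index()      # weighted random handler index
    handler = batch_handler.get_rand_handler()   # weighted random DataHandler draw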

property hr_exo_features

Get a list of high-resolution features that are only used for training, e.g., mid-network high-res topo injection.

property hr_features_ind

Get the high-resolution feature channel indices that should be included for training. Any high-resolution features that are only included in the data handler to be coarsened for the low-res input are removed

property hr_out_features

Get a list of high-resolution features that are intended to be output by the GAN.

load_handler_data()

Load data handler data in parallel or serial

property load_workers

Get max workers for loading data handlers based on memory usage

property lr_features

Get a list of low-resolution features. All low-resolution features are used for training.
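For illustration, a sketch (again assuming batch_handler is an initialized BatchHandlerMom1) of inspecting how the feature properties above partition the feature set:

    # Hypothetical inspection of the feature partitioning used for training.
    print(batch_handler.lr_features)      # all low-res input features
    print(batch_handler.hr_out_features)  # high-res features the GAN should output
    print(batch_handler.hr_exo_features)  # high-res training-only features (e.g. topo)
    print(batch_handler.hr_features_ind)  # channel indices kept for the high-res target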

property norm_workers

Get max workers used for calculating stats and normalizing across features

normalize(means=None, stds=None)

Compute means and stds for each feature across all datasets and normalize each data handler dataset. Checks if input means and stds are different from stored means and stds and renormalizes if they are

Parameters:
  • means (dict | None) – Dictionary of means for all features, with feature names as keys and mean values as values. If None, these will be calculated. If norm is True, these will be used for data normalization

  • stds (dict | None) – Dictionary of standard deviations for all features, with feature names as keys and standard deviation values as values. If None, these will be calculated. If norm is True, these will be used for data normalization

  • features (list | None) – Optional list of features used to index data array during normalization. If this is None self.features will be used.
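Example (illustrative only): re-normalizing with externally supplied statistics. The feature names and values are placeholders; keys must match the handler's own feature names.

    # Hypothetical per-feature stats (placeholder names and values).
    means = {'U_100m': 5.2, 'V_100m': -0.8}
    stds = {'U_100m': 3.1, 'V_100m': 2.9}
    batch_handler.normalize(means=means, stds=stds)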

property shape

Shape of full dataset across all handlers

Returns:

shape (tuple) – (spatial_1, spatial_2, temporal, features), with the spatiotemporal extent equal to the sum across all data handler dimensions
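For illustration (batch_handler again assumed initialized), the shape and features properties can be inspected together:

    # Hypothetical inspection of the combined dataset held by the handler.
    n_s1, n_s2, n_t, n_f = batch_handler.shape  # (spatial_1, spatial_2, temporal, features)
    print(batch_handler.features)               # ordered feature names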

property stats_workers

Get max workers for calculating stats based on memory usage