sup3r.preprocessing.conditional_moment_batch_handling.BatchHandlerMom2SF
- class BatchHandlerMom2SF(data_handlers, batch_size=8, s_enhance=3, t_enhance=1, means=None, stds=None, norm=True, n_batches=10, temporal_coarsening_method='subsample', temporal_enhancing_method='constant', stdevs_file=None, means_file=None, overwrite_stats=False, smoothing=None, smoothing_ignore=None, stats_workers=None, norm_workers=None, load_workers=None, max_workers=None, model_mom1=None, s_padding=None, t_padding=None, end_t_padding=False)[source]
Bases:
BatchHandlerMom1
Sup3r batch handling class for second conditional moment of subfilter velocity
- Parameters:
data_handlers (list[DataHandler]) – List of DataHandler instances
batch_size (int) – Number of observations in a batch
s_enhance (int) – Factor by which to coarsen spatial dimensions of the high resolution data to generate low res data
t_enhance (int) – Factor by which to coarsen temporal dimension of the high resolution data to generate low res data
means (np.ndarray) – dimensions (features) array of means for all features with same ordering as data features. If not None and norm is True these will be used for normalization
stds (np.ndarray) – dimensions (features) array of standard deviations for all features with same ordering as data features. If not None and norm is True these will be used for normalization
norm (bool) – Whether to normalize the data or not
n_batches (int) – Number of batches in an epoch, this sets the iteration limit for this object.
temporal_coarsening_method (str) – [subsample, average, total] Subsample will take every t_enhance-th time step, average will average over t_enhance time steps, total will sum over t_enhance time steps
temporal_enhancing_method (str) – [constant, linear] Method to enhance temporally when constructing the subfilter field. At every temporal location, the low-res temporal data is subtracted from the predicted high-res temporal data. constant assumes the low-res temporal data is constant between landmarks; linear linearly interpolates between landmarks to generate the low-res data to subtract from the high-res.
stdevs_file (str | None) – Path to stdevs data or where to save data after calling get_stats
means_file (str | None) – Path to means data or where to save data after calling get_stats
overwrite_stats (bool) – Whether to overwrite stats cache files.
smoothing (float | None) – Standard deviation to use for gaussian filtering of the coarse data. This can be tuned by matching the kinetic energy of a low resolution simulation with the kinetic energy of a coarsened and smoothed high resolution simulation. If None no smoothing is performed.
smoothing_ignore (list | None) – List of features to ignore for the smoothing filter. None will smooth all features if smoothing kwarg is not None
max_workers (int | None) – Providing a value for max workers will be used to set the value of norm_workers, stats_workers, and load_workers. If max_workers == 1 then all processes will be serialized. If None stats_workers, load_workers, and norm_workers will use their own provided values.
load_workers (int | None) – max number of workers to use for loading data handlers.
norm_workers (int | None) – max number of workers to use for normalizing data handlers.
stats_workers (int | None) – max number of workers to use for computing stats across data handlers.
model_mom1 (Sup3rCondMom | None) – model that predicts the first conditional moments. Useful to prepare data for learning second conditional moment.
s_padding (int | None) – Width of spatial padding to predict only middle part. If None, no padding is used
t_padding (int | None) – Width of temporal padding to predict only middle part. If None, no padding is used
end_t_padding (bool) – Zero-pad the end of the temporal space. Ensures that loss is calculated only when a snapshot is surrounded by temporal landmarks. False by default
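The temporal_coarsening_method and temporal_enhancing_method options described above can be sketched in plain NumPy. This is a standalone illustration with made-up data, not the library's internal implementation:

```python
import numpy as np

t_enhance = 3
hi_res = np.arange(12, dtype=float)  # 12 hypothetical high-res time steps

# temporal_coarsening_method options:
subsample = hi_res[::t_enhance]                       # every t_enhance-th step
average = hi_res.reshape(-1, t_enhance).mean(axis=1)  # mean over each window
total = hi_res.reshape(-1, t_enhance).sum(axis=1)     # sum over each window

# temporal_enhancing_method options, mapping low-res back to high-res length:
low_res = average
constant = np.repeat(low_res, t_enhance)  # piecewise constant between landmarks
landmarks = np.arange(len(hi_res))[1::t_enhance]  # window-center time indices
linear = np.interp(np.arange(len(hi_res)), landmarks, low_res)
```

Note that np.interp clamps to the endpoint values outside the landmark range, so the first and last few enhanced steps are held constant under the linear option.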
Methods

cache_stats() – Save stdevs and means to cache files if files are not None
check_cached_stats() – Get standard deviations and means for all data features from cache files if available.
get_handler_index() – Get random handler index based on handler weights
get_rand_handler() – Get random handler based on handler weights
get_stats() – Get standard deviations and means for all data features
load_handler_data() – Load data handler data in parallel or serial
normalize([means, stds]) – Compute means and stds for each feature across all datasets and normalize each data handler dataset.

Attributes

DATA_HANDLER_CLASS
feature_mem – Get memory used by each feature in data handlers
features – Get the ordered list of feature names held in this object's data handlers
handler_weights – Get weights used to sample from different data handlers based on relative sizes
hr_exo_features – Get a list of high-resolution features that are only used for training, e.g., mid-network high-res topo injection.
hr_features_ind – Get the high-resolution feature channel indices that should be included for training.
hr_out_features – Get a list of high-resolution features that are intended to be output by the GAN.
load_workers – Get max workers for loading data handler based on memory usage
lr_features – Get a list of low-resolution features.
norm_workers – Get max workers used for calculating stats and normalizing across features
shape – Shape of full dataset across all handlers
stats_workers – Get max workers for calculating stats based on memory usage
- VAL_CLASS
alias of ValidationDataMom2SF
- BATCH_CLASS
alias of BatchMom2SF
- cache_stats()
Save stdevs and means to cache files if files are not None
- check_cached_stats()
Get standard deviations and means for all data features from cache files if available.
- Returns:
means (dict | None) – Dictionary of means for all features with keys: feature names and values: mean values. If None, this will be calculated. If norm is True these will be used for data normalization
stds (dict | None) – Dictionary of standard deviation values for all features with keys: feature names and values: standard deviations. If None, this will be calculated. If norm is True these will be used for data normalization
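A minimal sketch of the file-caching pattern that stdevs_file and means_file enable, using a hypothetical JSON layout (the library's actual cache format may differ):

```python
import json
import os
import tempfile

# Hypothetical per-feature stats keyed by feature name
means = {'U_100m': 5.2, 'V_100m': -1.3}

means_file = os.path.join(tempfile.mkdtemp(), 'means.json')

# Save stats after computing them (cf. cache_stats)
with open(means_file, 'w') as f:
    json.dump(means, f)

# On a later run, load from cache instead of recomputing
# (cf. check_cached_stats)
if os.path.exists(means_file):
    with open(means_file) as f:
        cached_means = json.load(f)
```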
- property feature_mem
Get memory used by each feature in data handlers
- property features
Get the ordered list of feature names held in this object’s data handlers
- get_handler_index()
Get random handler index based on handler weights
- get_rand_handler()
Get random handler based on handler weights
- get_stats()
Get standard deviations and means for all data features
- property handler_weights
Get weights used to sample from different data handlers based on relative sizes
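Together, handler_weights and get_handler_index amount to size-proportional sampling across handlers. A standalone sketch with hypothetical handler sizes, not the library's exact implementation:

```python
import numpy as np

# Hypothetical data sizes (e.g. number of elements) for three handlers
sizes = np.array([1000.0, 3000.0, 6000.0])

# Weights proportional to handler size, summing to one
weights = sizes / sizes.sum()

# Pick a random handler index according to those weights
rng = np.random.default_rng(0)
handler_index = rng.choice(len(sizes), p=weights)
```

Larger handlers are therefore sampled more often, so each batch reflects the relative amount of data each handler holds.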
- property hr_exo_features
Get a list of high-resolution features that are only used for training e.g., mid-network high-res topo injection.
- property hr_features_ind
Get the high-resolution feature channel indices that should be included for training. Any high-resolution features that are only included in the data handler to be coarsened for the low-res input are removed
- property hr_out_features
Get a list of high-resolution features that are intended to be output by the GAN.
- load_handler_data()
Load data handler data in parallel or serial
- property load_workers
Get max workers for loading data handler based on memory usage
- property lr_features
Get a list of low-resolution features. All low-resolution features are used for training.
- property norm_workers
Get max workers used for calculating stats and normalizing across features
- normalize(means=None, stds=None)
Compute means and stds for each feature across all datasets and normalize each data handler dataset. Checks if input means and stds are different from stored means and stds and renormalizes if they are
- Parameters:
means (dict | None) – Dictionary of means for all features with keys: feature names and values: mean values. If None, this will be calculated. If norm is True these will be used for data normalization
stds (dict | None) – Dictionary of standard deviation values for all features with keys: feature names and values: standard deviations. If None, this will be calculated. If norm is True these will be used for data normalization
features (list | None) – Optional list of features used to index data array during normalization. If this is None self.features will be used.
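The normalization itself is standard per-feature standardization. A standalone sketch with made-up stats (the feature names and values here are hypothetical):

```python
import numpy as np

features = ['U_100m', 'V_100m']
means = {'U_100m': 5.2, 'V_100m': -1.3}
stds = {'U_100m': 2.0, 'V_100m': 1.5}

# (observations, features) array of raw data
data = np.array([[7.2, 0.2],
                 [3.2, -2.8]])

# (x - mean) / std applied per feature channel
normed = np.empty_like(data)
for i, name in enumerate(features):
    normed[:, i] = (data[:, i] - means[name]) / stds[name]
```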
- property shape
Shape of full dataset across all handlers
- Returns:
shape (tuple) – (spatial_1, spatial_2, temporal, features) With spatiotemporal extent equal to the sum across all data handler dimensions
- property stats_workers
Get max workers for calculating stats based on memory usage