sup3r.models.dc.Sup3rGanDC

class Sup3rGanDC(gen_layers, disc_layers, loss='MeanSquaredError', optimizer=None, learning_rate=0.0001, optimizer_disc=None, learning_rate_disc=None, history=None, meta=None, means=None, stdevs=None, default_device=None, name=None)

Bases: Sup3rGan

Data-centric model using loss across time bins to select training observations

Parameters:
  • gen_layers (list | str) – Hidden layers input argument to phygnn.base.CustomNetwork for the generative super resolving model. Can also be a str filepath to a .json config file containing the input layers argument or a .pkl for a saved pre-trained model.

  • disc_layers (list | str) – Hidden layers input argument to phygnn.base.CustomNetwork for the discriminative model (spatial or spatiotemporal discriminator). Can also be a str filepath to a .json config file containing the input layers argument or a .pkl for a saved pre-trained model.

  • loss (str | dict) – Loss function class name from sup3r.utilities.loss_metrics (prioritized) or tensorflow.keras.losses. Defaults to tf.keras.losses.MeanSquaredError. This can be provided as a dict with kwargs for loss functions with extra parameters. e.g. {‘SpatialExtremesLoss’: {‘weight’: 0.5}}

  • optimizer (tf.keras.optimizers.Optimizer | dict | None | str) – Instantiated tf.keras.optimizers object or a dict optimizer config from tf.keras.optimizers.Optimizer.get_config(). None defaults to Adam.

  • learning_rate (float, optional) – Optimizer learning rate. Not used if optimizer input arg is a pre-initialized object or if optimizer input arg is a config dict.

  • optimizer_disc (tf.keras.optimizers.Optimizer | dict | None) – Same as optimizer input, but if specified this makes a different optimizer just for the discriminator network (spatial or spatiotemporal disc).

  • learning_rate_disc (float, optional) – Same as learning_rate input, but if specified this makes a different learning_rate just for the discriminator network (spatial or spatiotemporal disc).

  • history (pd.DataFrame | str | None) – Model training history with “epoch” index, str pointing to a saved history csv file with “epoch” as first column, or None for clean history

  • meta (dict | None) – Model meta data that describes how the model was created.

  • means (dict | None) – Set of mean values for data normalization keyed by feature name. Can be used to maintain a consistent normalization scheme between transfer learning domains.

  • stdevs (dict | None) – Set of stdev values for data normalization keyed by feature name. Can be used to maintain a consistent normalization scheme between transfer learning domains.

  • default_device (str | None) – Option for default device placement of model weights. If None and a single GPU exists, that GPU will be the default device. If None and multiple GPUs exist, the CPU will be the default device (this was tested as most efficient given the custom multi-gpu strategy developed in self.run_gradient_descent()). Examples: “/gpu:0” or “/cpu:0”

  • name (str | None) – Optional name for the GAN.

Methods

calc_loss(hi_res_true, hi_res_gen[, ...])

Calculate the GAN loss function using generated and true high resolution data.

calc_loss_disc(disc_out_true, disc_out_gen)

Calculate the loss term for the discriminator model (either the spatial or temporal discriminator).

calc_loss_gen_advers(disc_out_gen)

Calculate the adversarial component of the loss term for the generator model.

calc_loss_gen_content(hi_res_true, hi_res_gen)

Calculate the content loss term for the generator model.

calc_val_loss(batch_handler, ...)

Overloading the base calc_val_loss method.

calc_val_loss_gen(batch_handler, ...)

Calculate the validation total loss across the validation samples.

calc_val_loss_gen_content(batch_handler)

Calculate the validation content loss across the validation samples.

check_batch_handler_attrs(batch_handler)

Sanitize batch handler attributes before they are passed to set_model_params, since not all batch handlers have them.

dict_to_tensorboard(entry)

Write data to tensorboard log file.

discriminate(hi_res[, norm_in])

Run the discriminator model on a hi resolution input field.

early_stop(history, column[, threshold, n_epoch])

Determine whether to stop training early based on nearly no change to validation loss for a certain number of consecutive epochs.

finish_epoch(epoch, epochs, t0, ...[, extras])

Perform finishing checks after an epoch is done training

generate(low_res[, norm_in, un_norm_out, ...])

Use the generator model to generate high res data from low res input.

get_high_res_exo_input(high_res)

Get exogenous feature data from high_res

get_loss_fun(loss)

Get the initialized loss function class from the sup3r loss library or the tensorflow losses.

get_optimizer_config(optimizer)

Get a config that defines the current model optimizer

get_optimizer_state(optimizer)

Get a set of state variables for the optimizer

get_s_enhance_from_layers()

Compute factor by which model will enhance spatial resolution from layer attributes.

get_single_grad(low_res, hi_res_true, ...[, ...])

Run gradient descent for one mini-batch of (low_res, hi_res_true), do not update weights, just return gradient details.

get_t_enhance_from_layers()

Compute factor by which model will enhance temporal resolution from layer attributes.

get_weight_update_fraction(history, ...[, ...])

Get the factor by which to multiply previous adversarial loss weight

init_optimizer(optimizer, learning_rate)

Initialize keras optimizer object.

init_weights(lr_shape, hr_shape[, device])

Initialize the generator and discriminator weights with device placement.

load(model_dir[, verbose])

Load the GAN with its sub-networks from a previously saved-to output directory.

load_network(model, name)

Load a CustomNetwork object from hidden layers config, .json file config, or .pkl file saved pre-trained model.

load_saved_params(out_dir[, verbose])

Load saved model_params (you need this and the gen+disc models to load a full model).

log_loss_details(loss_details[, level])

Log the loss details to the module logger.

norm_input(low_res)

Normalize low resolution data being input to the generator.

profile_to_tensorboard(name)

Write profile data to tensorboard log file.

run_gradient_descent(low_res, hi_res_true, ...)

Run gradient descent for one mini-batch of (low_res, hi_res_true) and update weights

save(out_dir)

Save the GAN with its sub-networks to a directory.

save_params(out_dir)

seed([s])

Set the random seed for reproducible results.

set_model_params(**kwargs)

Set parameters used for training the model

set_norm_stats(new_means, new_stdevs)

Set the normalization statistics associated with a data batch handler to model attributes.

train(batch_handler, input_resolution, n_epoch)

Train the GAN model on real low res data and real high res data

train_epoch(batch_handler, ...[, multi_gpu])

Train the GAN for one epoch.

un_norm_output(output)

Un-normalize synthetically generated output data to physical units

update_adversarial_weights(history, ...)

Update spatial / temporal adversarial loss weights based on training fraction history.

update_loss_details(loss_details, new_data, ...)

Update a dictionary of loss_details with loss information from a new batch.

update_optimizer([option])

Update optimizer by changing current configuration

Attributes

discriminator

Get the discriminator model.

discriminator_weights

Get a list of layer weights and bias terms for the discriminator model.

generator

Get the generative model.

generator_weights

Get a list of layer weights and bias terms for the generator model.

history

Model training history DataFrame (None if not yet trained)

hr_exo_features

Get list of high-resolution exogenous feature names the model uses.

hr_out_features

Get the list of high-resolution output feature names that the generative model outputs.

input_dims

Get dimension of model generator input.

input_resolution

Resolution of input data.

is_4d

Check if model expects spatial only input

is_5d

Check if model expects spatiotemporal input

lr_features

Get a list of low-resolution features input to the generative model.

means

Get the data normalization mean values.

meta

Get meta data dictionary that defines how the model was created

model_params

Model parameters, used to save model to disk

optimizer

Get the tensorflow optimizer to perform gradient descent calculations for the generative network.

optimizer_disc

Get the tensorflow optimizer to perform gradient descent calculations for the discriminator network.

output_resolution

Resolution of output data.

s_enhance

Factor by which model will enhance spatial resolution.

s_enhancements

List of spatial enhancement factors.

smoothed_features

Get the list of smoothed input feature names that the generative model was trained on.

smoothing

Value of smoothing parameter used in gaussian filtering of coarsened high res data.

stdevs

Get the data normalization standard deviation values.

t_enhance

Factor by which model will enhance temporal resolution.

t_enhancements

List of temporal enhancement factors.

total_batches

Record of total number of batches for logging.

version_record

A record of important versions that this model was built with.

weights

Get a list of all the layer weights and bias terms for the generator and discriminator networks

calc_val_loss_gen(batch_handler, weight_gen_advers)

Calculate the validation total loss across the validation samples. For example, if the sample domain has 100 steps and the validation set has 10 bins, this computes losses across steps 0 to 10, 10 to 20, etc. Use this to determine performance within bins and to update how observations are selected from these bins.

Parameters:
  • batch_handler (sup3r.preprocessing.BatchHandlerDC) – BatchHandler object to iterate through

  • weight_gen_advers (float) – Weight factor for the adversarial loss component of the generator vs. the discriminator.

Returns:

array – Array of total losses for all sample bins, with shape (n_space_bins, n_time_bins)
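To make the binning concrete, here is a minimal pure-Python sketch that splits a sequence of per-step validation losses into contiguous, equally sized bins and averages each bin. The helper name bin_mean_losses and the simple per-bin mean are assumptions for illustration; the method itself works from the batch handler's validation samples rather than from a precomputed list of step losses.

```python
def bin_mean_losses(step_losses, n_bins):
    """Average per-step losses over contiguous, equally sized bins.

    Hypothetical helper: assumes len(step_losses) is divisible by n_bins,
    e.g. 100 steps and 10 bins -> steps 0-10, 10-20, ... (end exclusive).
    """
    n_steps = len(step_losses)
    if n_steps % n_bins != 0:
        raise ValueError("number of steps must be divisible by n_bins")
    size = n_steps // n_bins
    return [
        sum(step_losses[i * size:(i + 1) * size]) / size
        for i in range(n_bins)
    ]


# 100 steps whose loss grows with the step index
losses = [0.01 * step for step in range(100)]
print(bin_mean_losses(losses, 10))
```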

calc_val_loss_gen_content(batch_handler)

Calculate the validation content loss across the validation samples. For example, if the sample domain has 100 steps and the validation set has 10 bins, this computes a list of losses across steps 0 to 10, 10 to 20, etc. Use this to determine performance within bins and to update how observations are selected from these bins.

Parameters:

batch_handler (sup3r.preprocessing.BatchHandlerDC) – BatchHandler object to iterate through

Returns:

list – List of content losses for all sample bins

calc_val_loss(batch_handler, weight_gen_advers, loss_details)

Overrides the base calc_val_loss method. This method updates the temporal weights for the batch handler based on the losses across the time bins.

Parameters:
  • batch_handler (sup3r.preprocessing.BatchHandler) – BatchHandler object to iterate through

  • weight_gen_advers (float) – Weight factor for the adversarial loss component of the generator vs. the discriminator.

  • loss_details (dict) – Namespace of the breakdown of loss components where each value is a running average at the current state in the epoch.

Returns:

dict – Updated loss_details with mean validation loss calculated using the validation samples across the time bins
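One plausible way to turn per-bin losses into temporal sampling weights is to average the (n_space_bins, n_time_bins) loss grid over the spatial axis and normalize the result to sum to one, so that time bins with higher validation loss are sampled more often. This normalization scheme and the helper name loss_to_time_weights are assumptions for illustration, not necessarily the exact update that calc_val_loss applies to the batch handler.

```python
def loss_to_time_weights(loss_grid):
    """Convert a (n_space_bins, n_time_bins) loss grid to time-bin weights.

    Hypothetical scheme: average each time bin over the space bins, then
    normalize so the weights sum to 1. Higher loss -> higher weight.
    """
    n_space = len(loss_grid)
    n_time = len(loss_grid[0])
    per_time = [
        sum(row[t] for row in loss_grid) / n_space for t in range(n_time)
    ]
    total = sum(per_time)
    if total == 0:
        # No signal to prioritize: fall back to uniform sampling
        return [1.0 / n_time] * n_time
    return [loss / total for loss in per_time]


grid = [
    [0.1, 0.2, 0.1],
    [0.1, 0.4, 0.1],
]
weights = loss_to_time_weights(grid)
print(weights)  # the middle time bin gets the largest weight
```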

calc_loss(hi_res_true, hi_res_gen, weight_gen_advers=0.001, train_gen=True, train_disc=False)

Calculate the GAN loss function using generated and true high resolution data.

Parameters:
  • hi_res_true (tf.Tensor) – Ground truth high resolution spatiotemporal data.

  • hi_res_gen (tf.Tensor) – Superresolved high resolution spatiotemporal data generated by the generative model.

  • weight_gen_advers (float) – Weight factor for the adversarial loss component of the generator vs. the discriminator.

  • train_gen (bool) – True if generator is being trained, then loss=loss_gen

  • train_disc (bool) – True if disc is being trained, then loss=loss_disc

Returns:

  • loss (tf.Tensor) – 0D tensor representing the loss value for the network being trained (either generator or one of the discriminators)

  • loss_details (dict) – Namespace of the breakdown of loss components
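The role of weight_gen_advers can be illustrated with plain floats. This sketch assumes the common GAN convention that the generator loss is the content loss plus the adversarial loss scaled by weight_gen_advers, and that the returned loss is the generator loss when train_gen is True and the discriminator loss when train_disc is True. The helper name combine_losses is hypothetical, and giving train_gen precedence when both flags are set is a choice made only for this sketch.

```python
def combine_losses(loss_gen_content, loss_gen_advers, loss_disc,
                   weight_gen_advers=0.001, train_gen=True, train_disc=False):
    """Hypothetical helper: combine GAN loss terms given as plain floats.

    Assumes loss_gen = content + weight_gen_advers * adversarial. In this
    sketch train_gen takes precedence if both flags are True.
    """
    loss_gen = loss_gen_content + weight_gen_advers * loss_gen_advers
    loss_details = {
        "loss_gen": loss_gen,
        "loss_gen_content": loss_gen_content,
        "loss_gen_advers": loss_gen_advers,
        "loss_disc": loss_disc,
    }
    loss = None
    if train_gen:
        loss = loss_gen
    elif train_disc:
        loss = loss_disc
    return loss, loss_details


loss, details = combine_losses(0.2, 0.7, 1.3, weight_gen_advers=0.001)
print(loss)  # 0.2 + 0.001 * 0.7
```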

static calc_loss_disc(disc_out_true, disc_out_gen)

Calculate the loss term for the discriminator model (either the spatial or temporal discriminator).

Parameters:
  • disc_out_true (tf.Tensor) – Raw discriminator outputs from the discriminator model predicting only on ground truth data hi_res_true (not on hi_res_gen).

  • disc_out_gen (tf.Tensor) – Raw discriminator outputs from the discriminator model predicting only on synthetic data hi_res_gen (not on hi_res_true).

Returns:

loss_disc (tf.Tensor) – 0D tensor discriminator model loss for either the spatial or temporal component of the super resolution generated output.

static calc_loss_gen_advers(disc_out_gen)

Calculate the adversarial component of the loss term for the generator model.

Parameters:

disc_out_gen (tf.Tensor) – Raw discriminator outputs from the discriminator model predicting only on hi_res_gen (not on hi_res_true).

Returns:

loss_gen_advers (tf.Tensor) – 0D tensor generator model loss for the adversarial component of the generator loss term.
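For reference, a common way to compute both the discriminator loss and the generator's adversarial loss from raw logits is the numerically stable sigmoid cross-entropy: the discriminator is pushed to label ground truth as 1 and generated data as 0, while the generator's adversarial term rewards generated data that the discriminator labels as 1. The sketch below uses plain Python lists in place of tf.Tensor, the function names are hypothetical stand-ins, and it assumes this standard formulation rather than quoting the sup3r source.

```python
import math


def sigmoid_xent(logit, label):
    """Numerically stable binary cross-entropy on a raw logit.

    Equivalent to -[label*log(sigmoid(x)) + (1-label)*log(1-sigmoid(x))].
    """
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


def mean(values):
    return sum(values) / len(values)


def loss_disc_from_logits(disc_out_true, disc_out_gen):
    # Real samples should be classified as 1, generated samples as 0
    loss_true = mean([sigmoid_xent(x, 1.0) for x in disc_out_true])
    loss_gen = mean([sigmoid_xent(x, 0.0) for x in disc_out_gen])
    return loss_true + loss_gen


def loss_gen_advers_from_logits(disc_out_gen):
    # The generator wants its output to be classified as real (label 1)
    return mean([sigmoid_xent(x, 1.0) for x in disc_out_gen])


# With zero logits the discriminator is maximally uncertain
print(loss_disc_from_logits([0.0, 0.0], [0.0, 0.0]))  # 2 * ln(2) ≈ 1.386
print(loss_gen_advers_from_logits([0.0]))             # ln(2) ≈ 0.693
```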

calc_loss_gen_content(hi_res_true, hi_res_gen)

Calculate the content loss term for the generator model.

Parameters:
  • hi_res_true (tf.Tensor) – Ground truth high resolution spatiotemporal data.

  • hi_res_gen (tf.Tensor) – Superresolved high resolution spatiotemporal data generated by the generative model.

Returns:

loss_gen_s (tf.Tensor) – 0D tensor generator model loss for the content loss comparing the hi res ground truth to the hi res synthetically generated output.

static check_batch_handler_attrs(batch_handler)

Not all batch handlers have all of the attributes that set_model_params expects, so this method sanitizes them before they are passed to set_model_params.

dict_to_tensorboard(entry)

Write data to tensorboard log file. This is usually a loss_details dictionary.

Parameters:

entry (dict) – Dictionary of values to write to tensorboard log file

discriminate(hi_res, norm_in=False)

Run the discriminator model on a hi resolution input field.

Parameters:
  • hi_res (np.ndarray) – Real or fake high res data in a 4D or 5D tensor: (n_obs, spatial_1, spatial_2, n_features) (n_obs, spatial_1, spatial_2, n_temporal, n_features)

  • norm_in (bool) – Flag to normalize hi_res input data if the self._means, self._stdevs attributes are available. The disc should always receive normalized data with mean=0 stdev=1.

Returns:

out (np.ndarray) – Discriminator output logits

property discriminator

Get the discriminator model.

Returns:

phygnn.base.CustomNetwork

property discriminator_weights

Get a list of layer weights and bias terms for the discriminator model.

Returns:

list

static early_stop(history, column, threshold=0.005, n_epoch=5)

Determine whether to stop training early based on nearly no change to validation loss for a certain number of consecutive epochs.

Parameters:
  • history (pd.DataFrame | None) – Model training history

  • column (str) – Column from the model training history to evaluate for early termination.

  • threshold (float) – The absolute relative fractional difference in validation loss between subsequent epochs below which an early termination is warranted. E.g. if val losses were 0.1 and 0.0998 the relative diff would be calculated as 0.0002 / 0.1 = 0.002 which would be less than the default threshold of 0.005 and would satisfy the condition for early termination.

  • n_epoch (int) – The number of consecutive epochs that satisfy the threshold that warrants an early stop.

Returns:

stop (bool) – Flag to stop training (True) or keep going (False).
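The stopping rule described above can be sketched in plain Python. This version takes a list of per-epoch values instead of a pandas DataFrame column and stops only when each of the last n_epoch epoch-to-epoch relative changes is below threshold; the function name and the list-based input are assumptions for illustration, and the values are assumed to be nonzero.

```python
def should_stop(values, threshold=0.005, n_epoch=5):
    """Return True if the last n_epoch relative changes are all < threshold.

    values: per-epoch history of the monitored column (e.g. a validation
    loss), oldest first. The relative change between epochs i-1 and i is
    abs(values[i] - values[i-1]) / abs(values[i-1]).
    """
    if len(values) < n_epoch + 1:
        # Not enough epochs yet to evaluate n_epoch consecutive changes
        return False
    recent = values[-(n_epoch + 1):]
    for prev, curr in zip(recent, recent[1:]):
        rel_diff = abs(curr - prev) / abs(prev)
        if rel_diff >= threshold:
            return False
    return True


# The docstring example: 0.1 -> 0.0998 is a relative change of 0.002
history = [0.5, 0.3, 0.1, 0.0998, 0.0997, 0.0996, 0.0995, 0.0994]
print(should_stop(history))  # True: the last 5 changes are all below 0.005
```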

finish_epoch(epoch, epochs, t0, loss_details, checkpoint_int, out_dir, early_stop_on, early_stop_threshold, early_stop_n_epoch, extras=None)

Perform finishing checks after an epoch is done training

Parameters:
  • epoch (int) – Epoch number that is finishing

  • epochs (list) – List of epochs being iterated through

  • t0 (float) – Starting time of training.

  • loss_details (dict) – Namespace of the breakdown of loss components

  • checkpoint_int (int | None) – Epoch interval at which to save checkpoint models.

  • out_dir (str) – Directory to save checkpoint models. Should have {epoch} in the directory name. This directory will be created if it does not already exist.

  • early_stop_on (str | None) – If not None, this should be a column in the training history to evaluate for early stopping (e.g. validation_loss_gen, validation_loss_disc). If the value in this column changes by an absolute relative fractional difference of less than early_stop_threshold for early_stop_n_epoch consecutive epochs, training will stop early.

  • early_stop_threshold (float) – The absolute relative fractional difference in validation loss between subsequent epochs below which an early termination is warranted. E.g. if val losses were 0.1 and 0.0998 the relative diff would be calculated as 0.0002 / 0.1 = 0.002 which would be less than a threshold of 0.005 (the default in train()) and would satisfy the condition for early termination.

  • early_stop_n_epoch (int) – The number of consecutive epochs that satisfy the threshold that warrants an early stop.

  • extras (dict | None) – Extra kwargs/parameters to save in the epoch history.

Returns:

stop (bool) – Flag to early stop training.

generate(low_res, norm_in=True, un_norm_out=True, exogenous_data=None)

Use the generator model to generate high res data from low res input. This is the public generate function.

Parameters:
  • low_res (np.ndarray) – Low-resolution input data, usually a 4D or 5D array of shape: (n_obs, spatial_1, spatial_2, n_features) (n_obs, spatial_1, spatial_2, n_temporal, n_features)

  • norm_in (bool) – Flag to normalize low_res input data if the self._means, self._stdevs attributes are available. The generator should always receive normalized data with mean=0 stdev=1. This also normalizes hi_res_topo.

  • un_norm_out (bool) – Flag to un-normalize synthetically generated output data to physical units

  • exogenous_data (dict | ExoData | None) – Special dictionary (class:ExoData) of exogenous feature data with entries describing whether features should be combined at input, a mid network layer, or with output. This doesn’t have to include the ‘model’ key since this data is for a single step model.

Returns:

hi_res (ndarray) – Synthetically generated high-resolution data, usually a 4D or 5D array with shape: (n_obs, spatial_1, spatial_2, n_features) (n_obs, spatial_1, spatial_2, n_temporal, n_features)

property generator

Get the generative model.

Returns:

phygnn.base.CustomNetwork

property generator_weights

Get a list of layer weights and bias terms for the generator model.

Returns:

list

get_high_res_exo_input(high_res)

Get exogenous feature data from high_res

Parameters:

high_res (tf.Tensor) – Ground truth high resolution spatiotemporal data.

Returns:

exo_data (dict) – Dictionary of exogenous feature data used as input to tf_generate. e.g. {'topography': tf.Tensor(...)}

static get_loss_fun(loss)

Get the initialized loss function class from the sup3r loss library or the tensorflow losses.

Parameters:

loss (str | dict) – Loss function class name from sup3r.utilities.loss_metrics (prioritized) or tensorflow.keras.losses. Defaults to tf.keras.losses.MeanSquaredError. This can be provided as a dict with kwargs for loss functions with extra parameters. e.g. {‘SpatialExtremesLoss’: {‘weight’: 0.5}}

Returns:

out (tf.keras.losses.Loss) – Initialized loss function class that is callable, e.g. if “MeanSquaredError” is requested, this will return an instance of tf.keras.losses.MeanSquaredError()
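The loss argument accepts either a class name or a dict mapping a class name to its kwargs. The hypothetical helper below only shows how the single-entry form from the docstring example could be split into a name and keyword arguments before a class is looked up; it does not import sup3r or TensorFlow, and the actual lookup (sup3r.utilities.loss_metrics first, then tensorflow.keras.losses) is left to get_loss_fun.

```python
def parse_loss_config(loss):
    """Split a loss config into (class_name, kwargs).

    Hypothetical helper. Accepts either "MeanSquaredError" or a one-entry
    dict such as {"SpatialExtremesLoss": {"weight": 0.5}}.
    """
    if isinstance(loss, str):
        return loss, {}
    if isinstance(loss, dict) and len(loss) == 1:
        name, kwargs = next(iter(loss.items()))
        return name, dict(kwargs or {})
    raise TypeError(f"Unsupported loss config in this sketch: {loss!r}")


print(parse_loss_config("MeanSquaredError"))
print(parse_loss_config({"SpatialExtremesLoss": {"weight": 0.5}}))
```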

static get_optimizer_config(optimizer)

Get a config that defines the current model optimizer

Parameters:

optimizer (tf.keras.optimizers.Optimizer) – TF-Keras optimizer object (e.g., Adam)

Returns:

config (dict) – Optimizer config

classmethod get_optimizer_state(optimizer)

Get a set of state variables for the optimizer

Parameters:

optimizer (tf.keras.optimizers.Optimizer) – TF-Keras optimizer object (e.g., Adam)

Returns:

state (dict) – Optimizer state variables

get_s_enhance_from_layers()

Compute factor by which model will enhance spatial resolution from layer attributes. Used in model training during high res coarsening

get_single_grad(low_res, hi_res_true, training_weights, device_name=None, **calc_loss_kwargs)

Run gradient descent for one mini-batch of (low_res, hi_res_true), do not update weights, just return gradient details.

Parameters:
  • low_res (np.ndarray) – Real low-resolution data in a 4D or 5D array: (n_observations, spatial_1, spatial_2, features) (n_observations, spatial_1, spatial_2, temporal, features)

  • hi_res_true (np.ndarray) – Real high-resolution data in a 4D or 5D array: (n_observations, spatial_1, spatial_2, features) (n_observations, spatial_1, spatial_2, temporal, features)

  • training_weights (list) – A list of layer weights that are to-be-trained based on the current loss weight values.

  • device_name (None | str) – Optional tensorflow device name for GPU placement. Note that if a GPU is available, variables will be placed on that GPU even if device_name=None.

  • calc_loss_kwargs (dict) – Kwargs to pass to the self.calc_loss() method

Returns:

  • grad (list) – a list or nested structure of Tensors (or IndexedSlices, or None, or CompositeTensor) representing the gradients for the training_weights

  • loss_details (dict) – Namespace of the breakdown of loss components

get_t_enhance_from_layers()

Compute factor by which model will enhance temporal resolution from layer attributes. Used in model training during high res coarsening

static get_weight_update_fraction(history, comparison_key, update_bounds=(0.5, 0.95), update_frac=0.0)

Get the factor by which to multiply previous adversarial loss weight

Parameters:
  • history (dict) – Dictionary with information on how often discriminators were trained during previous epoch.

  • comparison_key (str) – history key to use for update check

  • update_bounds (tuple) – Tuple specifying allowed range for history[comparison_key]. If history[comparison_key] < update_bounds[0] then the weight will be multiplied by (1 + update_frac). If history[comparison_key] > update_bounds[1] then the weight will be multiplied by 1 / (1 + update_frac).

  • update_frac (float) – Fraction by which to increase/decrease adversarial loss weight

Returns:

float – Factor by which to multiply old weight to get updated weight
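The update rule can be sketched with a single float standing in for the latest history[comparison_key] value; the function name and scalar input are hypothetical simplifications of the documented behavior. Note that with the default update_frac=0.0 the factor is always 1, so the weight only changes when a nonzero fraction is passed.

```python
def weight_update_fraction(value, update_bounds=(0.5, 0.95), update_frac=0.0):
    """Factor to multiply the previous adversarial loss weight by.

    value: the most recent history[comparison_key], e.g. how often the
    discriminator was trained during the previous epoch.
    """
    lower, upper = update_bounds
    if value < lower:
        return 1 + update_frac        # increase the weight
    if value > upper:
        return 1 / (1 + update_frac)  # decrease the weight
    return 1.0                        # within bounds: leave it unchanged


old_weight = 0.001
new_weight = old_weight * weight_update_fraction(0.3, update_frac=0.1)
print(new_weight)  # about 0.0011
```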

property history

Model training history DataFrame (None if not yet trained)

Returns:

pandas.DataFrame | None

property hr_exo_features

Get list of high-resolution exogenous feature names the model uses. If the model has N concat or add layers, this list will be the last N features in the training features list. The ordering is assumed to be the same as the order of the concat or add layers. If the training features are […, topo, sza] and the model has 2 concat or add layers, the exo features will be [topo, sza]. Topo will then be used in the first concat layer and sza will be used in the second.
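The ordering convention can be sketched as a list slice. The helper name split_exo_features and the feature names below are illustrative only.

```python
def split_exo_features(training_features, n_exo_layers):
    """Hypothetical helper: return the last n_exo_layers training features.

    The i-th returned feature is paired with the i-th concat/add layer.
    Slicing from len - n (rather than -n) keeps n_exo_layers=0 correct,
    since training_features[-0:] would return the whole list.
    """
    return training_features[len(training_features) - n_exo_layers:]


features = ["u_100m", "v_100m", "topography", "sza"]
exo = split_exo_features(features, 2)
for layer_index, feature in enumerate(exo):
    print(layer_index, feature)  # 0 topography, then 1 sza
```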

property hr_out_features

Get the list of high-resolution output feature names that the generative model outputs.

static init_optimizer(optimizer, learning_rate)

Initialize keras optimizer object.

Parameters:
  • optimizer (tf.keras.optimizers.Optimizer | dict | None | str) – Instantiated tf.keras.optimizers object or a dict optimizer config from tf.keras.optimizers.Optimizer.get_config(). None defaults to Adam.

  • learning_rate (float, optional) – Optimizer learning rate. Not used if optimizer input arg is a pre-initialized object or if optimizer input arg is a config dict.

Returns:

optimizer (tf.keras.optimizers.Optimizer) – Initialized optimizer object.

init_weights(lr_shape, hr_shape, device=None)

Initialize the generator and discriminator weights with device placement.

Parameters:
  • lr_shape (tuple) – Shape of one batch of low res input data for sup3r resolution. Note that the batch size (axis=0) must be included, but the actual batch size does not matter.

  • hr_shape (tuple) – Shape of one batch of high res input data for sup3r resolution. Note that the batch size (axis=0) must be included, but the actual batch size does not matter.

  • device (str | None) – Option to place model weights on a device. If None, self.default_device will be used.
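For a model with spatial enhancement s_enhance and temporal enhancement t_enhance, lr_shape and hr_shape are typically related as sketched below. The helper is hypothetical, and because the number of high-res features depends on the model's output and exogenous features, it is passed in explicitly rather than derived.

```python
def expected_hr_shape(lr_shape, s_enhance, t_enhance, n_hr_features):
    """Hypothetical helper: derive a high-res batch shape from a low-res one.

    4D: (n_obs, spatial_1, spatial_2, n_features)
    5D: (n_obs, spatial_1, spatial_2, n_temporal, n_features)
    """
    n_obs, s1, s2 = lr_shape[:3]
    if len(lr_shape) == 4:
        return (n_obs, s1 * s_enhance, s2 * s_enhance, n_hr_features)
    if len(lr_shape) == 5:
        n_t = lr_shape[3]
        return (n_obs, s1 * s_enhance, s2 * s_enhance, n_t * t_enhance,
                n_hr_features)
    raise ValueError("lr_shape must be 4D or 5D")


# A batch size of 1 is enough: the actual batch size does not matter here
print(expected_hr_shape((1, 10, 10, 6, 2), s_enhance=3, t_enhance=4,
                        n_hr_features=2))  # (1, 30, 30, 24, 2)
```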

property input_dims

Get dimension of model generator input. This is usually 4D for spatial models and 5D for spatiotemporal models. This gives the input to the first step if the model is multi-step. Returns 5 for linear models.

Returns:

int

property input_resolution

Resolution of input data. Given as a dictionary {'spatial': '...km', 'temporal': '...min'}. The numbers are required to be integers in the units specified. The units are not strict as long as the resolution of the exogenous data, when extracting exogenous data, is specified in the same units.

property is_4d

Check if model expects spatial only input

property is_5d

Check if model expects spatiotemporal input

classmethod load(model_dir, verbose=True)

Load the GAN with its sub-networks from a previously saved-to output directory.

Parameters:
  • model_dir (str) – Directory to load GAN model files from.

  • verbose (bool) – Flag to log information about the loaded model.

Returns:

out (BaseModel) – Returns a pretrained GAN model that was previously saved to model_dir

load_network(model, name)

Load a CustomNetwork object from hidden layers config, .json file config, or .pkl file saved pre-trained model.

Parameters:
  • model (str | dict) – Model hidden layers config, a .json with “hidden_layers” key, or a .pkl for a saved pre-trained model.

  • name (str) – Name of the model to be loaded

Returns:

model (phygnn.CustomNetwork) – CustomNetwork object initialized from the model input.

static load_saved_params(out_dir, verbose=True)

Load saved model_params (you need this and the gen+disc models to load a full model).

Parameters:
  • out_dir (str) – Directory to load model files from.

  • verbose (bool) – Flag to log information about the loaded model.

Returns:

params (dict) – Model parameters loaded from disk json file. This should be the same as self.model_params with an additional ‘history’ entry. Should be all the kwargs you need to init a model.

static log_loss_details(loss_details, level='INFO')

Log the loss details to the module logger.

Parameters:
  • loss_details (dict) – Namespace of the breakdown of loss components where each value is a running average at the current state in the epoch.

  • level (str) – Log level (e.g. INFO, DEBUG)

property lr_features

Get a list of low-resolution features input to the generative model. This includes low-resolution features that might be supplied exogenously at inference time but that were in the low-res batches during training

property means

Get the data normalization mean values.

Returns:

np.ndarray

property meta

Get meta data dictionary that defines how the model was created

property model_params

Model parameters, used to save model to disk

Returns:

dict

norm_input(low_res)

Normalize low resolution data being input to the generator.

Parameters:

low_res (np.ndarray) – Un-normalized low-resolution input data in physical units, usually a 4D or 5D array of shape: (n_obs, spatial_1, spatial_2, n_features) (n_obs, spatial_1, spatial_2, n_temporal, n_features)

Returns:

low_res (np.ndarray) – Normalized low-resolution input data, usually a 4D or 5D array of shape: (n_obs, spatial_1, spatial_2, n_features) (n_obs, spatial_1, spatial_2, n_temporal, n_features)
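Normalization of this kind is typically a per-feature z-score along the last (feature) axis. The numpy sketch below assumes means and stdevs are dicts keyed by feature name, matching the constructor's means and stdevs arguments, and that the channel order follows a feature list such as lr_features; the function names norm_features and un_norm_features are hypothetical, and un-normalizing is the exact inverse.

```python
import numpy as np


def norm_features(data, features, means, stdevs):
    """Z-score each feature channel (last axis) of a 4D or 5D array."""
    mu = np.array([means[f] for f in features], dtype=np.float32)
    sd = np.array([stdevs[f] for f in features], dtype=np.float32)
    # mu and sd broadcast against the trailing feature axis
    return (data - mu) / sd


def un_norm_features(data, features, means, stdevs):
    """Inverse of norm_features: map normalized data back to physical units."""
    mu = np.array([means[f] for f in features], dtype=np.float32)
    sd = np.array([stdevs[f] for f in features], dtype=np.float32)
    return data * sd + mu


features = ["u_100m", "v_100m"]
means = {"u_100m": 2.0, "v_100m": -1.0}
stdevs = {"u_100m": 4.0, "v_100m": 2.0}
low_res = np.random.default_rng(0).normal(size=(3, 8, 8, 4, 2))
normed = norm_features(low_res, features, means, stdevs)
# Round trip recovers the physical-unit data
assert np.allclose(un_norm_features(normed, features, means, stdevs), low_res)
```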

property optimizer

Get the tensorflow optimizer to perform gradient descent calculations for the generative network. This is functionally identical to optimizer_disc if no special optimizer model or learning rate was specified for the disc.

Returns:

tf.keras.optimizers.Optimizer

property optimizer_disc

Get the tensorflow optimizer to perform gradient descent calculations for the discriminator network.

Returns:

tf.keras.optimizers.Optimizer

property output_resolution

Resolution of output data. Given as a dictionary {‘spatial’: ‘…km’, ‘temporal’: ‘…min’}. This is computed from the input resolution and the enhancement factors.
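As a rough illustration, the output resolution can be obtained by dividing each input resolution by the matching enhancement factor. The sketch assumes values of the form "<integer><unit>" as described for input_resolution; the helper name and the formatting of non-integer results with the :g format spec are hypothetical choices.

```python
import re


def output_resolution(input_resolution, s_enhance, t_enhance):
    """Hypothetical helper: divide each resolution by its enhancement factor.

    Assumes values like '30km' or '60min': a leading integer followed by a
    unit string.
    """
    out = {}
    for key, enhance in (("spatial", s_enhance), ("temporal", t_enhance)):
        match = re.fullmatch(r"(\d+)\s*(\D+)", input_resolution[key])
        if match is None:
            raise ValueError(f"Cannot parse resolution {input_resolution[key]!r}")
        value, unit = int(match.group(1)), match.group(2)
        out[key] = f"{value / enhance:g}{unit}"
    return out


print(output_resolution({"spatial": "30km", "temporal": "60min"},
                        s_enhance=5, t_enhance=4))
# {'spatial': '6km', 'temporal': '15min'}
```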

profile_to_tensorboard(name)

Write profile data to tensorboard log file.

Parameters:

name (str) – Tag name to use for profile info

run_gradient_descent(low_res, hi_res_true, training_weights, optimizer=None, multi_gpu=False, **calc_loss_kwargs)

Run gradient descent for one mini-batch of (low_res, hi_res_true) and update weights

Parameters:
  • low_res (np.ndarray) – Real low-resolution data in a 4D or 5D array: (n_observations, spatial_1, spatial_2, features) (n_observations, spatial_1, spatial_2, temporal, features)

  • hi_res_true (np.ndarray) – Real high-resolution data in a 4D or 5D array: (n_observations, spatial_1, spatial_2, features) (n_observations, spatial_1, spatial_2, temporal, features)

  • training_weights (list) – A list of layer weights that are to-be-trained based on the current loss weight values.

  • optimizer (tf.keras.optimizers.Optimizer) – Optimizer class to use to update weights. This can be different if you’re training just the generator or one of the discriminator models. Defaults to the generator optimizer.

  • multi_gpu (bool) – Flag to break up the batch for parallel gradient descent calculations on multiple gpus. If True and multiple GPUs are present, each batch from the batch_handler will be divided up between the GPUs and resulting gradients from each GPU will be summed and then applied once per batch at the nominal learning rate that the model and optimizer were initialized with.

  • calc_loss_kwargs (dict) – Kwargs to pass to the self.calc_loss() method

Returns:

loss_details (dict) – Namespace of the breakdown of loss components
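The multi_gpu option relies on gradients being additive over sub-batches. The numpy toy below illustrates that idea with a one-parameter linear model and a summed squared-error loss, splitting the batch into chunks as a stand-in for GPUs; it does not use TensorFlow device placement and is not the sup3r implementation. Summing chunk gradients reproduces the full-batch gradient only for a loss that is summed over samples; for a mean loss with equal-sized chunks the sum is n_chunks times the full-batch gradient.

```python
import numpy as np


def grad_sum_sq(w, x, y):
    """d/dw of sum((w*x - y)**2) for a one-parameter linear model."""
    return np.sum(2.0 * (w * x - y) * x)


rng = np.random.default_rng(0)
x = rng.normal(size=16)
y = 3.0 * x + rng.normal(scale=0.1, size=16)
w = 0.5

# Split the batch across (pretend) devices and compute each gradient
n_devices = 4
chunk_grads = [
    grad_sum_sq(w, xc, yc)
    for xc, yc in zip(np.array_split(x, n_devices), np.array_split(y, n_devices))
]
total_grad = sum(chunk_grads)

# The summed chunk gradients equal the full-batch gradient
assert np.isclose(total_grad, grad_sum_sq(w, x, y))

# A single update is then applied once per batch
learning_rate = 1e-3
w -= learning_rate * total_grad
```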

property s_enhance

Factor by which model will enhance spatial resolution. Used in model training during high res coarsening and also in forward pass routine to determine shape of needed exogenous data

property s_enhancements

List of spatial enhancement factors. In the case of a single step model this is just [self.s_enhance]. This is used to determine shapes of needed exogenous data in forward pass routine

save(out_dir)

Save the GAN with its sub-networks to a directory.

Parameters:

out_dir (str) – Directory to save GAN model files. This directory will be created if it does not already exist.

save_params(out_dir)

Parameters:

out_dir (str) – Directory to save model params. This directory will be created if it does not already exist.

static seed(s=0)

Set the random seed for reproducible results.

Parameters:

s (int) – Random seed

set_model_params(**kwargs)

Set parameters used for training the model

Parameters:

kwargs (dict) – Keyword arguments including ‘input_resolution’, ‘lr_features’, ‘hr_exo_features’, ‘hr_out_features’, ‘smoothed_features’, ‘s_enhance’, ‘t_enhance’, ‘smoothing’

set_norm_stats(new_means, new_stdevs)

Set the normalization statistics associated with a data batch handler to model attributes.

Parameters:
  • new_means (dict | None) – Set of mean values for data normalization keyed by feature name. Can be used to maintain a consistent normalization scheme between transfer learning domains.

  • new_stdevs (dict | None) – Set of stdev values for data normalization keyed by feature name. Can be used to maintain a consistent normalization scheme between transfer learning domains.

property smoothed_features

Get the list of smoothed input feature names that the generative model was trained on.

property smoothing

Value of smoothing parameter used in gaussian filtering of coarsened high res data.

property stdevs

Get the data normalization standard deviation values.

Returns:

np.ndarray

property t_enhance#

Factor by which the model enhances temporal resolution. Used during training when coarsening high-res data, and in the forward pass to determine the shape of the required exogenous data.

property t_enhancements#

List of temporal enhancement factors. For a single-step model this is just [self.t_enhance]. Used in the forward pass to determine the shapes of the required exogenous data.
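A minimal sketch combining the spatial and temporal factors, assuming a 5D batch layout of (n_obs, rows, cols, time, features) and that a multi-step model's total enhancement is the product of its per-step factors. The helper `hr_shape` is hypothetical.

```python
import math


def hr_shape(lr_shape, s_enhancements, t_enhancements, n_hr_features):
    """Hypothetical helper: high-res output shape for a 5D low-res batch.

    Assumes the layout (n_obs, rows, cols, time, features) and that the total
    enhancement of a multi-step model is the product of its per-step factors.
    """
    n_obs, rows, cols, steps, _ = lr_shape
    s_total = math.prod(s_enhancements)
    t_total = math.prod(t_enhancements)
    return (n_obs, rows * s_total, cols * s_total, steps * t_total,
            n_hr_features)


# 24 hourly steps enhanced 12x in time (i.e. to 5-minute data), 3x in space
print(hr_shape((1, 10, 10, 24, 4), [3], [12], 2))  # (1, 30, 30, 288, 2)
```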

property total_batches#

Record of the total number of batches, used for logging.

train(batch_handler, input_resolution, n_epoch, weight_gen_advers=0.001, train_gen=True, train_disc=True, disc_loss_bounds=(0.45, 0.6), checkpoint_int=None, out_dir='./gan_{epoch}', early_stop_on=None, early_stop_threshold=0.005, early_stop_n_epoch=5, adaptive_update_bounds=(0.9, 0.99), adaptive_update_fraction=0.0, multi_gpu=False, tensorboard_log=True, tensorboard_profile=False)#

Train the GAN model on real low-res and real high-res data.

Parameters:
  • batch_handler (sup3r.preprocessing.BatchHandler) – BatchHandler object to iterate through

  • input_resolution (dict) – Dictionary specifying spatiotemporal input resolution. e.g. {‘temporal’: ‘60min’, ‘spatial’: ‘30km’}

  • n_epoch (int) – Number of epochs to train on

  • weight_gen_advers (float) – Weight factor for the adversarial loss component of the generator vs. the discriminator.

  • train_gen (bool) – Flag whether to train the generator for this set of epochs

  • train_disc (bool) – Flag whether to train the discriminator for this set of epochs

  • disc_loss_bounds (tuple) – Lower and upper bounds for the discriminator loss outside of which the discriminator will not train unless train_disc=True and train_gen=False.

  • checkpoint_int (int | None) – Epoch interval at which to save checkpoint models.

  • out_dir (str) – Directory to save checkpoint GAN models. Should have {epoch} in the directory name. This directory will be created if it does not already exist.

  • early_stop_on (str | None) – If not None, this should be a column in the training history to evaluate for early stopping (e.g. validation_loss_gen, validation_loss_disc). If the absolute relative difference in this value between consecutive epochs stays below early_stop_threshold for early_stop_n_epoch epochs in a row, training will stop early.

  • early_stop_threshold (float) – The absolute relative difference in validation loss between consecutive epochs below which early termination is warranted. E.g. if the validation losses were 0.1 and 0.0998, the relative difference would be 0.0002 / 0.1 = 0.002, which is less than the default threshold of 0.005 and would satisfy the condition for early termination.

  • early_stop_n_epoch (int) – The number of consecutive epochs that must satisfy the threshold before training stops early.

  • adaptive_update_bounds (tuple) – Tuple specifying the allowed range for history[comparison_key]. If history[comparison_key] < adaptive_update_bounds[0], the weight is multiplied by (1 + adaptive_update_fraction). If history[comparison_key] > adaptive_update_bounds[1], the weight is multiplied by 1 / (1 + adaptive_update_fraction).

  • adaptive_update_fraction (float) – Amount by which to increase or decrease the adversarial weights during adaptive updates. With the default of 0.0 the update factor is 1, so the weights are left unchanged.

  • multi_gpu (bool) – Flag to break up the batch for parallel gradient descent calculations on multiple GPUs. If True and multiple GPUs are present, each batch from the batch_handler will be divided up between the GPUs, and the resulting gradients from each GPU will be summed and then applied once per batch at the nominal learning rate that the model and optimizer were initialized with. If True and multiple GPUs are found, default_device should be set to /gpu:0.

  • tensorboard_log (bool) – Whether to write a log file for use with TensorBoard. Log data can be viewed by running tensorboard --logdir <logdir>, where <logdir> is the parent directory of out_dir, and opening the printed address in a browser.

  • tensorboard_profile (bool) – Whether to export profiling information to TensorBoard. This can then be viewed in the TensorBoard dashboard under the Profile tab.

TODO:

(1) Args here are getting excessive. Might be time for some refactoring.

(2) cal_val_loss should be done in a separate thread from train_epoch so they can be done concurrently. This would be especially important for batch handlers which require val data, like dc handlers.

(3) Would like an automatic way to exit the batch handler thread instead of manually calling .stop() here.
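The early-stopping rule described by early_stop_on, early_stop_threshold and early_stop_n_epoch can be sketched as follows. This is one reading of the docstring, not sup3r's code: the helper `should_stop_early` is hypothetical, and it assumes the check looks at the last early_stop_n_epoch epoch-to-epoch changes of the monitored column.

```python
def should_stop_early(history, threshold=0.005, n_epoch=5):
    """Hypothetical sketch of the early-stopping rule described for train().

    history: per-epoch values of the monitored column (e.g. the
    validation generator loss). Returns True when the absolute relative
    difference between consecutive epochs has stayed below `threshold`
    for the last `n_epoch` epoch-to-epoch changes.
    """
    if len(history) < n_epoch + 1:
        return False  # not enough epochs to judge yet
    recent = history[-(n_epoch + 1):]
    diffs = [abs(curr - prev) / abs(prev)
             for prev, curr in zip(recent[:-1], recent[1:])]
    return all(d < threshold for d in diffs)


losses = [0.5, 0.3, 0.1, 0.0998, 0.0997, 0.0996, 0.0995, 0.0994]
print(should_stop_early(losses))      # True: 5 changes all below 0.005
print(should_stop_early(losses[:5]))  # False: too few epochs so far
```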

train_epoch(batch_handler, weight_gen_advers, train_gen, train_disc, disc_loss_bounds, multi_gpu=False)#

Train the GAN for one epoch.

Parameters:
  • batch_handler (sup3r.preprocessing.BatchHandler) – BatchHandler object to iterate through

  • weight_gen_advers (float) – Weight factor for the adversarial loss component of the generator vs. the discriminator.

  • train_gen (bool) – Flag whether to train the generator for this set of epochs

  • train_disc (bool) – Flag whether to train the discriminator for this set of epochs

  • disc_loss_bounds (tuple) – Lower and upper bounds for the discriminator loss outside of which the discriminator will not train unless train_disc=True and train_gen=False.

  • multi_gpu (bool) – Flag to break up the batch for parallel gradient descent calculations on multiple GPUs. If True and multiple GPUs are present, each batch from the batch_handler will be divided up between the GPUs, and the resulting gradients from each GPU will be summed and then applied once per batch at the nominal learning rate that the model and optimizer were initialized with. If True and multiple GPUs are found, default_device should be set to /gpu:0.

Returns:

loss_details (dict) – Namespace of the breakdown of loss components

un_norm_output(output)#

Un-normalize synthetically generated output data to physical units.

Parameters:

output (tf.Tensor | np.ndarray) – Synthetically generated high-resolution data

Returns:

output (np.ndarray) – Synthetically generated high-resolution data
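A minimal sketch of the inverse transform, assuming the normalization was (x - mean) / stdev per feature on the last axis. The function `un_normalize` is a standalone illustration, not this method's implementation, and its feature names and values are hypothetical.

```python
import numpy as np


def un_normalize(output, hr_out_features, means, stdevs):
    """Hypothetical sketch: invert (x - mean) / stdev along the last axis.

    `output` may be a NumPy array or an eager tf.Tensor (np.array converts
    either); features are assumed to lie on the last axis in the order of
    `hr_out_features`, and `means`/`stdevs` are dicts keyed by feature name.
    """
    output = np.array(output, dtype=np.float32)  # copy, never mutate input
    mu = np.array([means[f] for f in hr_out_features], dtype=np.float32)
    sd = np.array([stdevs[f] for f in hr_out_features], dtype=np.float32)
    return output * sd + mu


features = ['u_100m', 'v_100m']
means = {'u_100m': 2.0, 'v_100m': -1.0}
stdevs = {'u_100m': 4.0, 'v_100m': 2.0}
normed = np.array([[1.0, 2.0], [0.0, 0.0]])
print(un_normalize(normed, features, means, stdevs))  # values [[6, 3], [2, -1]]
```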

update_adversarial_weights(history, adaptive_update_fraction, adaptive_update_bounds, weight_gen_advers, train_disc)#

Update the spatial / temporal adversarial loss weights based on the history of how often the discriminator was trained.

Parameters:
  • history (dict) – Dictionary with information on how often the discriminators were trained during the current and previous epochs.

  • adaptive_update_fraction (float) – Amount by which to increase or decrease the adversarial loss weights for adaptive updates.

  • adaptive_update_bounds (tuple) – Tuple specifying the allowed range for history[comparison_key]. If history[comparison_key] < adaptive_update_bounds[0], the weight is multiplied by (1 + adaptive_update_fraction). If history[comparison_key] > adaptive_update_bounds[1], the weight is multiplied by 1 / (1 + adaptive_update_fraction).

  • weight_gen_advers (float) – Weight factor for the adversarial loss component of the generator vs. the discriminator.

  • train_disc (bool) – Whether the discriminator was set to be trained during the previous epoch.

Returns:

weight_gen_advers (float) – Updated weight factor for the adversarial loss component of the generator vs. the discriminator.
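The multiplicative update described by adaptive_update_bounds can be sketched as a standalone function. This is a hedged illustration, not sup3r's implementation: the function name is hypothetical, and it assumes the compared history value is the fraction of batches in which the discriminator trained, and that no update happens when adaptive_update_fraction is 0 or the discriminator was not set to train.

```python
def update_adversarial_weight(disc_trained_frac, weight_gen_advers,
                              adaptive_update_fraction,
                              adaptive_update_bounds, train_disc=True):
    """Hypothetical sketch of the multiplicative adaptive weight update.

    disc_trained_frac: latest history value being compared (assumed to be
    the fraction of batches in which the discriminator was trained).
    """
    if adaptive_update_fraction <= 0 or not train_disc:
        return weight_gen_advers  # no adaptive update requested
    factor = 1 + adaptive_update_fraction
    lower, upper = adaptive_update_bounds
    if disc_trained_frac < lower:
        return weight_gen_advers * factor  # below range: increase weight
    if disc_trained_frac > upper:
        return weight_gen_advers / factor  # above range: decrease weight
    return weight_gen_advers  # within range: leave unchanged


print(update_adversarial_weight(0.95, 0.001, 0.05, (0.9, 0.99)))           # 0.001
print(round(update_adversarial_weight(0.5, 0.001, 0.05, (0.9, 0.99)), 6))  # 0.00105
print(round(update_adversarial_weight(1.0, 0.001, 0.05, (0.9, 0.99)), 6))  # 0.000952
```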

static update_loss_details(loss_details, new_data, batch_len, prefix=None)#

Update a dictionary of loss_details with loss information from a new batch.

Parameters:
  • loss_details (dict) – Namespace of the breakdown of loss components, where each value is the running average so far in the current epoch.

  • new_data (dict) – Namespace of the breakdown of loss components for a single new batch.

  • batch_len (int) – Length of the incoming batch.

  • prefix (None | str) – Option to prefix the names of the loss data when saving to the loss_details dictionary.

Returns:

loss_details (dict) – Same as input loss_details but with running averages updated.
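A batch-size-weighted running average is one way to keep each value a running average over the epoch. The sketch below assumes the number of observations seen so far is stored under a hypothetical 'n_obs' key; it is an illustration, not sup3r's implementation.

```python
def update_loss_details(loss_details, new_data, batch_len, prefix=None):
    """Sketch: batch-size-weighted running average of loss components.

    Assumption: the number of observations seen so far is tracked under a
    hypothetical 'n_obs' key so each batch is weighted by its length.
    """
    prior_n_obs = loss_details.get('n_obs', 0)
    new_n_obs = prior_n_obs + batch_len
    for key, value in new_data.items():
        key = key if prefix is None else prefix + key
        value = float(value)  # np / eager tf scalar -> Python float
        if key in loss_details:
            loss_details[key] = ((loss_details[key] * prior_n_obs
                                  + value * batch_len) / new_n_obs)
        else:
            loss_details[key] = value  # first batch that reports this key
    loss_details['n_obs'] = new_n_obs
    return loss_details


d = {}
update_loss_details(d, {'loss_gen': 1.0}, batch_len=2, prefix='train_')
update_loss_details(d, {'loss_gen': 4.0}, batch_len=1, prefix='train_')
print(d)  # {'train_loss_gen': 2.0, 'n_obs': 3}
```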

update_optimizer(option='generator', **kwargs)#

Update the optimizer by changing its current configuration.

Parameters:
  • option (str) – Which optimizer to update. Can be “generator”, “discriminator”, or “all”

  • kwargs (dict) – Keyword arguments used to update the optimizer configuration.

property version_record#

A record of important versions that this model was built with.

Returns:

dict

property weights#

Get a list of all the layer weights and bias terms for the generator and discriminator networks.