sup3r.pipeline.strategy.ForwardPassStrategy#

class ForwardPassStrategy(file_paths: str | list | Path, model_kwargs: dict, fwp_chunk_shape: tuple = (None, None, None), spatial_pad: int = 0, temporal_pad: int = 0, model_class: str = 'Sup3rGan', out_pattern: str | None = None, input_handler_name: str | None = None, input_handler_kwargs: dict | None = None, exo_handler_kwargs: dict | None = None, bias_correct_method: str | None = None, bias_correct_kwargs: dict | None = None, allowed_const: list | bool | None = None, incremental: bool = True, output_workers: int = 1, invert_uv: bool | None = None, pass_workers: int = 1, max_nodes: int = 1, head_node: bool = False)[source]#

Bases: object

Class to prepare data for forward passes through the generator.

A full file list of contiguous times is provided. The corresponding data is split into spatiotemporal chunks which can overlap in time and space. These chunks are distributed across nodes according to the max nodes input or the number of temporal chunks. This strategy stores information on these chunks, how they overlap, how they are distributed to nodes, and how to crop generator output to stitch the chunks back together.
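The chunk, overlap, crop, and stitch cycle described above can be illustrated along a single axis. This is a minimal sketch, not sup3r's implementation: `chunk_slices` and `fake_generator` are hypothetical helpers, the "generator" is plain nearest-neighbor upsampling, and padded slices are simply clipped at the domain edge, which is an assumption made to keep the example short.

```python
import numpy as np


def chunk_slices(n, chunk, pad):
    """Return (unpadded, padded) slice pairs that cover range(n)."""
    pairs = []
    for start in range(0, n, chunk):
        stop = min(start + chunk, n)
        # Overlap on both sides, clipped to the domain bounds (assumption).
        padded = slice(max(start - pad, 0), min(stop + pad, n))
        pairs.append((slice(start, stop), padded))
    return pairs


def fake_generator(x, enhance=2):
    """Hypothetical stand-in for a generator: nearest-neighbor upsampling."""
    return np.repeat(x, enhance)


n, chunk, pad, enhance = 10, 4, 1, 2
low_res = np.arange(n)

pieces = []
for core, padded in chunk_slices(n, chunk, pad):
    hi_res = fake_generator(low_res[padded], enhance)
    # Crop the high-resolution overlap so only the unpadded core remains.
    crop_start = (core.start - padded.start) * enhance
    crop_stop = crop_start + (core.stop - core.start) * enhance
    pieces.append(hi_res[crop_start:crop_stop])

# Stitching the cropped pieces reproduces a single pass over the full axis.
stitched = np.concatenate(pieces)
```

With a real generator the overlap matters because output near a chunk edge depends on neighboring input; cropping discards exactly that edge region before stitching.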

Use the following inputs to initialize data handlers on different nodes and to define the size of the data chunks that will be passed through the generator.

Parameters:
  • file_paths (list | str) – A list of low-resolution source files to extract raster data from. Each file must have the same number of timesteps. Can also pass a string with a unix-style file path which will be passed through glob.glob.

    Note: These files can also include a 2D (lat, lon) “mask” variable which is True for grid points which can be skipped in the forward pass and False otherwise. This will be used to skip running the forward pass for chunks which only include masked points, e.g. chunks covering only ocean. Chunks with even a single unmasked point will still be sent through the forward pass.

  • model_kwargs (dict) – Keyword arguments to send to model_class.load(**model_kwargs) to initialize the GAN. Typically this is just the string path to the model directory, but can be multiple models or arguments for more complex models.

  • fwp_chunk_shape (tuple) – Max shape (spatial_1, spatial_2, temporal) of an unpadded coarse chunk to use for a forward pass. The number of nodes that the ForwardPassStrategy is set to distribute to is calculated by dividing the total time index from all file_paths by the temporal part of this chunk shape. Each node will then be parallelized across parallel processes by the spatial chunk shape. If temporal_pad / spatial_pad are nonzero, the chunk sent to the generator can be bigger than this shape. If running in serial, set this equal to the shape of the full spatiotemporal data volume for best performance.

  • spatial_pad (int) – Size of spatial overlap between coarse chunks passed to forward passes for subsequent spatial stitching. This overlap will pad both sides of the fwp_chunk_shape.

  • temporal_pad (int) – Size of temporal overlap between coarse chunks passed to forward passes for subsequent temporal stitching. This overlap will pad both sides of the fwp_chunk_shape.

  • model_class (str) – Name of the sup3r model class for the GAN model to load. The default is the basic spatial / spatiotemporal Sup3rGan model. This will be loaded from sup3r.models.

  • out_pattern (str) – Output file pattern. Must include {file_id} format key. Each output file will have a unique file_id filled in and the ext determines the output type. If pattern is None then data will be returned in an array and not saved.

  • input_handler_name (str | None) – Class to use for input data. Provide a string name to match a rasterizer or handler class in sup3r.preprocessing.

  • input_handler_kwargs (dict | None) – Any kwargs for initializing the input_handler_name class.

  • exo_handler_kwargs (dict | None) – Dictionary of args to pass to ExoDataHandler for extracting exogenous features for forward passes. This should be a nested dictionary with keys for each exogenous feature. The dictionaries corresponding to the feature names should include, at minimum, the path to the exogenous data source and the files used for input to the forward passes. Can also provide a dictionary of input_handler_kwargs used for the handler which opens the exogenous data. e.g.:

    {'topography': {
        'source_file': ...,
        'input_files': ...,
        'input_handler_kwargs': {'target': ..., 'shape': ...}}}
    
  • bias_correct_method (str | None) – Optional bias correction function name that can be imported from the sup3r.bias.bias_transforms module. This will transform the source data according to some predefined bias correction transformation along with the bias_correct_kwargs. As the first argument, this method must receive a generic numpy array of data to be bias corrected.

  • bias_correct_kwargs (dict | None) – Optional namespace of kwargs to provide to bias_correct_method. If this is provided, it must be a dictionary where each key is a feature name and each value is a dictionary of kwargs to correct that feature. You can bias correct only certain input features by only including those feature names in this dict.

  • allowed_const (list | bool) – TensorFlow has a tensor memory limit of 2GB (a result of a protobuf limitation) and, when this is exceeded, can return a tensor with a constant output. sup3r will raise a MemoryError in response. If your model is allowed to output a constant output, set this to True to allow any constant output or to a list of allowed possible constant outputs. For example, a precipitation model should be allowed to output all zeros, so set this to [0]. For details on this limit: tensorflow/tensorflow#51870

  • incremental (bool) – Allow the forward pass iteration to skip spatiotemporal chunks that already have an output file (default = True) or iterate through all chunks and overwrite any pre-existing outputs (False).

  • output_workers (int | None) – Max number of workers to use for writing forward pass output.

  • invert_uv (bool | None) – Whether to convert u and v wind components to windspeed and direction for writing to output. This defaults to True for H5 output and False for NETCDF output.

  • pass_workers (int | None) – Max number of workers to use for performing forward passes on a single node. If 1 then all forward passes on chunks distributed to a single node will be run serially. pass_workers=2 is the minimum number of workers required to run the ForwardPass initialization and run_chunk() methods concurrently.

  • max_nodes (int | None) – Maximum number of nodes to distribute spatiotemporal chunks across. If None then a node will be used for each temporal chunk.

  • head_node (bool) – Whether initialization is taking place on the head node of a multi-node job launch. When this is True, ForwardPassStrategy is only partially initialized to provide the head node enough information for how to distribute jobs across nodes. Preflight tasks like bias correction will be skipped because they will be performed on the nodes that jobs are distributed to by the head node.
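As an illustration of how the parameters above fit together, the following sketch collects a set of keyword arguments in a plain dictionary. All paths and values are hypothetical placeholders, and the `model_dir` key inside `model_kwargs` is an assumption about what `model_class.load` accepts. The constructor call is left commented out so the snippet runs without sup3r installed.

```python
# Hypothetical paths and values throughout; replace them with real ones.
strategy_kwargs = {
    # Unix-style pattern, expanded with glob.glob according to the docs above.
    "file_paths": "./low_res/source_*.nc",
    # Assumed key name for the model directory passed to model_class.load().
    "model_kwargs": {"model_dir": "./models/my_gan"},
    # Unpadded coarse chunk shape: (spatial_1, spatial_2, temporal).
    "fwp_chunk_shape": (20, 20, 24),
    # Overlap added on both sides of each chunk before the forward pass.
    "spatial_pad": 2,
    "temporal_pad": 4,
    # Must contain the {file_id} format key; the extension sets the file type.
    "out_pattern": "./output/fwp_{file_id}.h5",
    # Skip chunks that already have an output file.
    "incremental": True,
    # Upper bound on the number of nodes used for distribution.
    "max_nodes": 4,
}

# With sup3r installed, the strategy would be created like this:
# from sup3r.pipeline.strategy import ForwardPassStrategy
# strategy = ForwardPassStrategy(**strategy_kwargs)
```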

Methods

chunk_finished(chunk_idx[, log])

Check if process for given chunk_index has already been run.

chunk_masked(chunk_idx[, log])

Check if the region for this chunk is masked.

get_chunk_indices(chunk_index)

Get (spatial, temporal) indices for the given chunk index

get_exo_cache_files(model)

Get list of exo cache files so we can check if they exist or not.

get_exo_kwargs(model)

Get list of exo kwargs for all exo features.

get_pad_width(chunk_index)

Get padding for the current spatiotemporal chunk

init_chunk([chunk_index])

Get ForwardPassChunk instance for the given chunk index.

init_input_handler()

Get input handler instance for given input kwargs.

load_exo_data(model)

Extract exogenous data for each exo feature and store data in dictionary with key for each exo feature

node_finished(node_idx)

Check if all out files for a given node have been saved

preflight()

Preflight logging and sanity checks

prep_chunk_data([chunk_index])

Get low res input data and exo data for given chunk index and bias correct low res data if requested.

Attributes

allowed_const

bias_correct_kwargs

bias_correct_method

exo_handler_kwargs

fwp_chunk_shape

fwp_mask

Cached spatial mask which returns whether a given spatial chunk should be skipped by the forward pass or not.

head_node

hr_lat_lon

Get high resolution lat lons

incremental

input_handler_kwargs

input_handler_name

invert_uv

max_nodes

meta

Meta data dictionary for the strategy.

model_class

node_chunks

Get array of lists such that node_chunks[i] is a list of indices for the chunks that will be sent through the generator on the ith node.

out_files

Get list of output file names for each file chunk forward pass.

out_pattern

output_workers

pass_workers

spatial_pad

temporal_pad

unmasked_chunks

List of chunk indices that are not masked from the input spatial region.

file_paths

model_kwargs

property meta#

Meta data dictionary for the strategy. Used to add info to forward pass output meta.

init_input_handler()[source]#

Get input handler instance for given input kwargs. If self.head_node is False we get all requested features. Otherwise this is part of initialization on a head node and just used to get the shape of the input domain, so we don’t need to get any features yet.

property node_chunks#

Get array of lists such that node_chunks[i] is a list of indices for the chunks that will be sent through the generator on the ith node.
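The shape of this structure can be sketched with `numpy.array_split`. This is a minimal illustration under stated assumptions, not the library's code: `split_chunks_across_nodes` is a hypothetical helper, and the rule for the node count (the smaller of `max_nodes` and the number of chunks, falling back to the number of temporal chunks when `max_nodes` is None) is inferred from the parameter descriptions.

```python
import numpy as np


def split_chunks_across_nodes(n_chunks, n_temporal_chunks, max_nodes=None):
    """Return a list where entry i holds the chunk indices for node i."""
    # Assumption: one node per temporal chunk when max_nodes is None.
    n_nodes = n_temporal_chunks if max_nodes is None else max_nodes
    # Never request more nodes than there are chunks to process.
    n_nodes = min(n_nodes, n_chunks)
    parts = np.array_split(np.arange(n_chunks), n_nodes)
    return [[int(i) for i in part] for part in parts]


# Ten chunks limited to three nodes: sizes differ by at most one.
node_chunks = split_chunks_across_nodes(10, n_temporal_chunks=5, max_nodes=3)
print(node_chunks)
```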

property unmasked_chunks#

List of chunk indices that are not masked from the input spatial region. These chunks are those that will go through the forward pass. Masked chunks will be skipped.

preflight()[source]#

Preflight logging and sanity checks

get_chunk_indices(chunk_index)[source]#

Get (spatial, temporal) indices for the given chunk index

property hr_lat_lon#

Get high resolution lat lons

property out_files#

Get list of output file names for each file chunk forward pass.

get_pad_width(chunk_index)[source]#

Get padding for the current spatiotemporal chunk

Returns:

padding (tuple) – Tuple of tuples with padding width for spatial and temporal dimensions. Each tuple includes the start and end of padding for that dimension. Ordering is spatial_1, spatial_2, temporal.
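One way to read the returned structure is sketched below. The helpers `pad_width_1d` and `pad_width` are hypothetical, and the sketch assumes that the padding width for a dimension is the part of the requested overlap that falls outside the data domain and therefore has to be filled in rather than read from neighboring data. That interpretation is an assumption, not something the documentation states.

```python
def pad_width_1d(start, stop, size, pad):
    """Return (start_pad, end_pad) for one dimension of a chunk."""
    # Overlap that cannot be taken from the domain before the chunk.
    start_pad = max(pad - start, 0)
    # Overlap that cannot be taken from the domain after the chunk.
    end_pad = max(pad - (size - stop), 0)
    return (start_pad, end_pad)


def pad_width(slices, shape, spatial_pad, temporal_pad):
    """Return padding ordered as (spatial_1, spatial_2, temporal)."""
    pads = (spatial_pad, spatial_pad, temporal_pad)
    return tuple(
        pad_width_1d(s.start, s.stop, n, p)
        for s, n, p in zip(slices, shape, pads)
    )


# A chunk at the start of spatial_1 and at the end of the time axis.
chunk = (slice(0, 10), slice(5, 15), slice(20, 24))
padding = pad_width(chunk, shape=(20, 20, 24), spatial_pad=2, temporal_pad=4)
print(padding)
```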

prep_chunk_data(chunk_index=0)[source]#

Get low res input data and exo data for given chunk index and bias correct low res data if requested.

Note

input_data.load() is called here to load chunk data into memory

init_chunk(chunk_index=0)[source]#

Get ForwardPassChunk instance for the given chunk index.

This selects the appropriate data from self.input_handler and self.exo_data and returns a structure object (ForwardPassChunk) with that data and other chunk specific attributes.

get_exo_kwargs(model)[source]#

Get list of exo kwargs for all exo features.

get_exo_cache_files(model)[source]#

Get list of exo cache files so we can check if they exist or not.

load_exo_data(model)[source]#

Extract exogenous data for each exo feature and store data in dictionary with key for each exo feature

Returns:

exo_data (ExoData) – ExoData object composed of multiple SingleExoDataStep objects. This is the exo data for the full spatiotemporal extent.

property fwp_mask#

Cached spatial mask which returns whether a given spatial chunk should be skipped by the forward pass or not. This is used to skip running the forward pass for areas with just ocean, for example.

Note: This is True for grid points which can be skipped in the forward pass and False otherwise.

node_finished(node_idx)[source]#

Check if all out files for a given node have been saved

chunk_finished(chunk_idx, log=True)[source]#

Check if process for given chunk_index has already been run. Considered finished if there is already an output file and incremental is True.
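The check can be sketched as follows, based on the description of the incremental parameter (chunks with an existing output file are skipped only when incremental is enabled). The standalone function and the file ID values are hypothetical; the real method takes a chunk index and looks up the output file itself.

```python
import os
import tempfile


def chunk_finished(out_file, incremental=True):
    """Return True if this chunk can be skipped."""
    # No output pattern means nothing is written, so nothing is finished.
    if out_file is None:
        return False
    return os.path.exists(out_file) and incremental


with tempfile.TemporaryDirectory() as tmp:
    out_pattern = os.path.join(tmp, "fwp_{file_id}.h5")
    done_file = out_pattern.format(file_id="000000")  # hypothetical ID
    todo_file = out_pattern.format(file_id="000001")  # hypothetical ID
    open(done_file, "w").close()  # pretend this chunk was already written

    skip_done = chunk_finished(done_file, incremental=True)
    skip_todo = chunk_finished(todo_file, incremental=True)
    # With incremental disabled, existing output is overwritten instead.
    skip_overwrite = chunk_finished(done_file, incremental=False)
```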

chunk_masked(chunk_idx, log=True)[source]#

Check if the region for this chunk is masked. This is used to skip running the forward pass for regions with just ocean, for example.
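The masking rule can be illustrated with a small boolean array: a chunk is skipped only if every grid point it covers is masked. This is a minimal sketch with a hypothetical helper, `chunk_is_masked`, which takes explicit slices instead of a chunk index.

```python
import numpy as np


def chunk_is_masked(mask, row_slice, col_slice):
    """Return True only if every grid point in the chunk is masked."""
    return bool(np.all(mask[row_slice, col_slice]))


# True marks grid points that can be skipped, e.g. ocean.
mask = np.zeros((4, 4), dtype=bool)
mask[:, :2] = True  # the left half of the domain is ocean

all_ocean = chunk_is_masked(mask, slice(0, 4), slice(0, 2))
# This chunk straddles the coast, so it still goes through the forward pass.
partly_land = chunk_is_masked(mask, slice(0, 4), slice(1, 3))
```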