sup3r.pipeline.strategy.ForwardPassStrategy
- class ForwardPassStrategy(file_paths: str | list | Path, model_kwargs: dict, fwp_chunk_shape: tuple = (None, None, None), spatial_pad: int = 0, temporal_pad: int = 0, model_class: str = 'Sup3rGan', out_pattern: str | None = None, input_handler_name: str | None = None, input_handler_kwargs: dict | None = None, exo_handler_kwargs: dict | None = None, bias_correct_method: str | None = None, bias_correct_kwargs: dict | None = None, allowed_const: list | bool | None = None, incremental: bool = True, output_workers: int = 1, invert_uv: bool | None = None, pass_workers: int = 1, max_nodes: int = 1, head_node: bool = False)
Bases: object
Class to prepare data for forward passes through the generator.
A full file list of contiguous times is provided. The corresponding data is split into spatiotemporal chunks, which can overlap in time and space. These chunks are distributed across nodes according to the max_nodes input or the number of temporal chunks. This strategy stores information on these chunks, how they overlap, how they are distributed to nodes, and how to crop generator output to stitch the chunks back together.
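The overlap-and-crop idea can be illustrated in one dimension. This is a minimal sketch, not sup3r's implementation: `padded_chunk_slices` is a hypothetical helper, it works in coarse (low-resolution) index space, and cropping real generator output would additionally need the crop scaled by the model's enhancement factor.

```python
def padded_chunk_slices(n, chunk_size, pad):
    """Hypothetical helper: split range(n) into chunks of at most
    `chunk_size` and pad each chunk by `pad` on both sides, clipping at the
    domain edges.

    Returns a list of (padded, crop) slice pairs: `padded` selects the input
    data for one forward pass and `crop` selects the unpadded core of that
    padded chunk, which is what gets stitched into the final output.
    """
    pairs = []
    for start in range(0, n, chunk_size):
        stop = min(start + chunk_size, n)
        p_start = max(start - pad, 0)
        p_stop = min(stop + pad, n)
        offset = start - p_start  # where the unpadded core begins in the padded chunk
        pairs.append((slice(p_start, p_stop), slice(offset, offset + stop - start)))
    return pairs


# Pad, "run" each chunk (identity here instead of a generator), crop, stitch.
data = list(range(10))
stitched = []
for padded, crop in padded_chunk_slices(len(data), chunk_size=4, pad=1):
    chunk = data[padded]
    stitched.extend(chunk[crop])
```

Because every crop removes exactly the padding that was added, concatenating the cropped cores reproduces the original sequence with no gaps or duplicates.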
Use the following inputs to initialize data handlers on different nodes and to define the size of the data chunks that will be passed through the generator.
- Parameters:
file_paths (list | str) – A list of low-resolution source files to extract raster data from. Each file must have the same number of timesteps. Can also pass a string with a Unix-style file path, which will be passed through glob.glob.
Note: These files can also include a 2D (lat, lon) “mask” variable, which is True for grid points that can be skipped in the forward pass and False otherwise. This is used to skip the forward pass for chunks that only include masked points, for example chunks covering only ocean. Chunks with even a single unmasked point are still sent through the forward pass.
model_kwargs (dict) – Keyword arguments to send to model_class.load(**model_kwargs) to initialize the GAN. Typically this just contains the string path to the model directory, but it can specify multiple models or arguments for more complex models.
fwp_chunk_shape (tuple) – Max shape (spatial_1, spatial_2, temporal) of an unpadded coarse chunk to use for a forward pass. The number of nodes that the ForwardPassStrategy is set to distribute to is calculated by dividing the total time index from all file_paths by the temporal part of this chunk shape. The work on each node is then parallelized across processes according to the spatial chunk shape. If temporal_pad / spatial_pad are non-zero, the chunk sent to the generator can be bigger than this shape. If running in serial, set this equal to the shape of the full spatiotemporal data volume for best performance.
spatial_pad (int) – Size of spatial overlap between coarse chunks passed to forward passes for subsequent spatial stitching. This overlap pads both sides of the fwp_chunk_shape.
temporal_pad (int) – Size of temporal overlap between coarse chunks passed to forward passes for subsequent temporal stitching. This overlap pads both sides of the fwp_chunk_shape.
model_class (str) – Name of the sup3r model class for the GAN model to load. The default is the basic spatial / spatiotemporal Sup3rGan model. This will be loaded from sup3r.models.
out_pattern (str | None) – Output file pattern. Must include a {file_id} format key. Each output file will have a unique file_id filled in, and the file extension determines the output type. If out_pattern is None, the data will be returned in an array and not saved.
input_handler_name (str | None) – Class to use for input data. Provide a string name to match a rasterizer or handler class in sup3r.preprocessing.
input_handler_kwargs (dict | None) – Any kwargs for initializing the input_handler_name class.
exo_handler_kwargs (dict | None) – Dictionary of args to pass to ExoDataHandler for extracting exogenous features for forward passes. This should be a nested dictionary with keys for each exogenous feature. The dictionary for each feature name should include, at minimum, the path to the exogenous data source and the files used as input to the forward passes. It can also include a dictionary of input_handler_kwargs used for the handler which opens the exogenous data, e.g.: {'topography': {'source_file': ..., 'input_files': ..., 'input_handler_kwargs': {'target': ..., 'shape': ...}}}
bias_correct_method (str | None) – Optional bias correction function name that can be imported from the sup3r.bias.bias_transforms module. This will transform the source data according to some predefined bias correction transformation along with the bias_correct_kwargs. As the first argument, this method must receive a generic numpy array of data to be bias corrected.
bias_correct_kwargs (dict | None) – Optional namespace of kwargs to provide to bias_correct_method. If this is provided, it must be a dictionary where each key is a feature name and each value is a dictionary of kwargs to correct that feature. You can bias correct only certain input features by only including those feature names in this dict.
allowed_const (list | bool | None) – TensorFlow has a tensor memory limit of 2GB (a result of a protobuf limitation), and exceeding it can produce a tensor with a constant output. sup3r will raise a MemoryError in response. If your model is allowed to output a constant output, set this to True to allow any constant output, or to a list of allowed constant outputs. For example, a precipitation model should be allowed to output all zeros, so set this to [0]. For details on this limit, see tensorflow/tensorflow#51870.
incremental (bool) – Allow the forward pass iteration to skip spatiotemporal chunks that already have an output file (default = True), or iterate through all chunks and overwrite any pre-existing outputs (False).
output_workers (int | None) – Max number of workers to use for writing forward pass output.
invert_uv (bool | None) – Whether to convert u and v wind components to windspeed and direction for writing to output. This defaults to True for H5 output and False for NETCDF output.
pass_workers (int | None) – Max number of workers to use for performing forward passes on a single node. If 1, then all forward passes on chunks distributed to a single node will be run serially. pass_workers=2 is the minimum number of workers required to run the ForwardPass initialization and run_chunk() methods concurrently.
max_nodes (int | None) – Maximum number of nodes to distribute spatiotemporal chunks across. If None, then a node will be used for each temporal chunk.
head_node (bool) – Whether initialization is taking place on the head node of a multi-node job launch. When this is True, ForwardPassStrategy is only partially initialized to provide the head node enough information for how to distribute jobs across nodes. Preflight tasks like bias correction will be skipped because they will be performed on the nodes that the head node distributes jobs to.
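The constant-output guard described for allowed_const can be sketched as follows. This is a hypothetical illustration (`check_constant_output` is not a sup3r function) operating on a flat list of values with exact equality; the actual check presumably works on generator output arrays and may differ in detail.

```python
def check_constant_output(values, allowed_const=None):
    """Hypothetical sketch: raise MemoryError if every generator output
    value is identical and that constant is not permitted by allowed_const.

    allowed_const=True permits any constant, a list permits only the listed
    constants, and None/False permits none.
    """
    if not values:
        return
    first = values[0]
    if any(v != first for v in values):
        return  # output varies, so it cannot be the constant-output failure
    if allowed_const is True:
        return
    if isinstance(allowed_const, list) and first in allowed_const:
        return
    raise MemoryError(
        f"Forward pass output is constant ({first}); the input chunk may "
        "have exceeded the 2GB TensorFlow tensor limit."
    )


# A precipitation-style model may legitimately return all zeros.
check_constant_output([0.0, 0.0, 0.0], allowed_const=[0])
```

Listing specific constants rather than passing True keeps the guard active for every other constant value, so a genuine memory-limit failure is still caught.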
Methods
chunk_finished(chunk_idx[, log]) – Check if process for given chunk_index has already been run.
chunk_masked(chunk_idx[, log]) – Check if the region for this chunk is masked.
get_chunk_indices(chunk_index) – Get (spatial, temporal) indices for the given chunk index.
get_exo_cache_files(model) – Get list of exo cache files so we can check if they exist or not.
get_exo_kwargs(model) – Get list of exo kwargs for all exo features.
init_chunk([chunk_index]) – Get ForwardPassChunk instance for the given chunk index.
init_input_handler() – Get input handler instance for given input kwargs.
load_exo_data(model) – Extract exogenous data for each exo feature and store data in a dictionary with a key for each exo feature.
node_finished(node_idx) – Check if all out files for a given node have been saved.
Preflight logging and sanity checks.
prep_chunk_data([chunk_index]) – Get low res input data and exo data for the given chunk index and bias correct low res data if requested.
Attributes
allowed_const
bias_correct_kwargs
bias_correct_method
exo_handler_kwargs
fwp_chunk_shape
fwp_mask – Cached spatial mask which returns whether a given spatial chunk should be skipped by the forward pass or not.
head_node
hr_lat_lon – Get high resolution lat lons.
incremental
input_handler_kwargs
input_handler_name
invert_uv
max_nodes
meta – Meta data dictionary for the strategy.
model_class
node_chunks – Get array of lists such that node_chunks[i] is a list of indices for the chunks that will be sent through the generator on the ith node.
out_files – Get list of output file names for each file chunk forward pass.
out_pattern
output_workers
pass_workers
spatial_pad
temporal_pad
unmasked_chunks – List of chunk indices that are not masked from the input spatial region.
file_paths
model_kwargs
- property meta
Meta data dictionary for the strategy. Used to add info to forward pass output meta.
- init_input_handler()
Get input handler instance for given input kwargs. If self.head_node is False we get all requested features. Otherwise this is part of initialization on a head node and just used to get the shape of the input domain, so we don’t need to get any features yet.
- property node_chunks
Get array of lists such that node_chunks[i] is a list of indices for the chunks that will be sent through the generator on the ith node.
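How node_chunks might be built can be pictured with two hypothetical helpers, `count_chunks` and `split_across_nodes`, which are not sup3r functions. This sketch assumes chunks are numbered 0..N-1 and split into contiguous, near-equal groups in the manner of numpy.array_split; the real strategy may distribute only unmasked chunks or order them differently.

```python
import math


def count_chunks(domain_shape, fwp_chunk_shape):
    """Return (n_spatial, n_temporal) chunk counts for a (spatial_1,
    spatial_2, temporal) domain. Ceiling division lets the last chunk along
    each axis be smaller than fwp_chunk_shape."""
    n_spatial = math.ceil(domain_shape[0] / fwp_chunk_shape[0]) * math.ceil(
        domain_shape[1] / fwp_chunk_shape[1]
    )
    n_temporal = math.ceil(domain_shape[2] / fwp_chunk_shape[2])
    return n_spatial, n_temporal


def split_across_nodes(chunk_indices, max_nodes=None):
    """Split chunk indices into contiguous, near-equal groups, one per node.
    Earlier groups get one extra index when the split is uneven."""
    if not chunk_indices:
        return []
    n_nodes = len(chunk_indices)
    if max_nodes is not None:
        n_nodes = min(max_nodes, n_nodes)
    size, extra = divmod(len(chunk_indices), n_nodes)
    groups, start = [], 0
    for i in range(n_nodes):
        stop = start + size + (1 if i < extra else 0)
        groups.append(chunk_indices[start:stop])
        start = stop
    return groups


n_spatial, n_temporal = count_chunks((100, 100, 48), (30, 30, 24))
node_chunks = split_across_nodes(list(range(n_spatial * n_temporal)), max_nodes=3)
```

With max_nodes=None this helper assigns one node per list item, so passing the temporal chunk indices reproduces the documented one-node-per-temporal-chunk default.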
- property unmasked_chunks
List of chunk indices that are not masked from the input spatial region. These chunks are those that will go through the forward pass. Masked chunks will be skipped.
- property hr_lat_lon
Get high-resolution latitudes and longitudes.
- property out_files
Get list of output file names for each file chunk forward pass.
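Filling out_pattern and the incremental skip can be sketched with the standard library. This is a hypothetical illustration: `get_out_files` and `chunks_to_run` are not sup3r functions, and the zero-padded file_id scheme below is an assumption made only for the example, not the format sup3r uses.

```python
import os
import tempfile


def get_out_files(out_pattern, file_ids):
    """Fill the {file_id} key of out_pattern for each chunk. Returns None
    entries when out_pattern is None (output kept in memory, not saved)."""
    if out_pattern is None:
        return [None] * len(file_ids)
    return [out_pattern.format(file_id=fid) for fid in file_ids]


def chunks_to_run(out_files, incremental=True):
    """Indices of chunks to run; with incremental=True, chunks whose output
    file already exists are skipped."""
    return [
        i
        for i, path in enumerate(out_files)
        if not (incremental and path is not None and os.path.exists(path))
    ]


# Hypothetical file_id scheme: zero-padded temporal and spatial chunk indices.
file_ids = [f"{t:06d}_{s:06d}" for t in range(2) for s in range(2)]

with tempfile.TemporaryDirectory() as out_dir:
    out_files = get_out_files(os.path.join(out_dir, "fwp_{file_id}.h5"), file_ids)
    open(out_files[1], "w").close()  # pretend chunk 1 already finished
    todo = chunks_to_run(out_files)  # [0, 2, 3]
    todo_all = chunks_to_run(out_files, incremental=False)  # [0, 1, 2, 3]
```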
- prep_chunk_data(chunk_index=0)
Get low res input data and exo data for the given chunk index and bias correct low res data if requested.
Note: input_data.load() is called here to load chunk data into memory.
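Per-feature bias correction, as described for bias_correct_method and bias_correct_kwargs, can be sketched as below. This is an illustration under assumptions: `scale_and_offset` stands in for a real transform from sup3r.bias.bias_transforms (whose functions take additional arguments), `bias_correct` is a hypothetical helper, plain lists stand in for numpy arrays, and the feature names are illustrative.

```python
def scale_and_offset(data, scalar, adder):
    """Hypothetical bias-correction transform; data is the first argument."""
    return [x * scalar + adder for x in data]


def bias_correct(features, method, bc_kwargs):
    """Apply `method` to every feature named in `bc_kwargs`, passing that
    feature's kwargs; features not listed are returned unchanged."""
    corrected = dict(features)  # shallow copy so the input mapping is untouched
    for name, kwargs in bc_kwargs.items():
        corrected[name] = method(corrected[name], **kwargs)
    return corrected


features = {"u_100m": [1.0, 2.0], "temperature_2m": [280.0, 290.0]}
bc_kwargs = {"u_100m": {"scalar": 2.0, "adder": -1.0}}  # only u_100m is corrected
corrected = bias_correct(features, scale_and_offset, bc_kwargs)
```

Keying the kwargs by feature name is what lets a single method correct only a subset of the input features.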
- init_chunk(chunk_index=0)
Get ForwardPassChunk instance for the given chunk index. This selects the appropriate data from self.input_handler and self.exo_data and returns a structured object (ForwardPassChunk) with that data and other chunk-specific attributes.
- get_exo_cache_files(model)
Get list of exo cache files so we can check if they exist or not.
- load_exo_data(model)
Extract exogenous data for each exo feature and store data in a dictionary with a key for each exo feature.
- Returns:
exo_data (ExoData) – ExoData object composed of multiple SingleExoDataStep objects. This is the exo data for the full spatiotemporal extent.
- property fwp_mask
Cached spatial mask which returns whether a given spatial chunk should be skipped by the forward pass or not. This is used to skip running the forward pass for areas covering only ocean, for example.
Note: This is True for grid points which can be skipped in the forward pass and False otherwise.
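The per-chunk skip rule ("skip only if every point is masked") can be sketched as below. This is a hypothetical illustration: `spatial_chunk_mask` is not a sup3r function, it uses nested lists instead of arrays, ignores spatial_pad, and assumes row-major chunk numbering.

```python
def spatial_chunk_mask(mask, chunk_shape):
    """For a 2D (lat, lon) boolean mask where True means "can be skipped",
    return one bool per spatial chunk (row-major chunk order). A chunk is
    skipped (True) only if every grid point in it is masked."""
    n_rows, n_cols = len(mask), len(mask[0])
    ch_rows, ch_cols = chunk_shape
    skip = []
    for r0 in range(0, n_rows, ch_rows):
        for c0 in range(0, n_cols, ch_cols):
            block = [
                mask[r][c]
                for r in range(r0, min(r0 + ch_rows, n_rows))
                for c in range(c0, min(c0 + ch_cols, n_cols))
            ]
            skip.append(all(block))
    return skip


# Left half land (False), right half ocean (True).
mask = [[False, False, True, True] for _ in range(4)]
skip = spatial_chunk_mask(mask, (2, 2))
unmasked = [i for i, s in enumerate(skip) if not s]
```

Using all() rather than any() is what keeps a chunk with even a single unmasked grid point in the forward pass.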