"""Slicer class for chunking forward pass input"""
import itertools as it
import logging
from dataclasses import dataclass
from typing import Union
from warnings import warn
import numpy as np
from sup3r.pipeline.utilities import (
get_chunk_slices,
)
from sup3r.preprocessing.utilities import _parse_time_slice, log_args
logger = logging.getLogger(__name__)
[docs]
@dataclass
class ForwardPassSlicer:
    """Get slices for sending data chunks through generator.

    Note that the positional order of the dataclass fields is
    ``coarse_shape, time_steps, s_enhance, t_enhance, time_slice,
    temporal_pad, spatial_pad, chunk_shape``, which differs from the order
    in which the parameters are described below.

    Parameters
    ----------
    coarse_shape : tuple
        Shape of full domain for low res data
    time_steps : int
        Number of time steps for full temporal domain of low res data. This
        is used to construct a dummy_time_index from np.arange(time_steps)
    time_slice : slice | list
        Slice to use to extract range from time_index. Can be a ``slice(start,
        stop, step)`` or list ``[start, stop, step]``
    chunk_shape : tuple
        Max shape (spatial_1, spatial_2, temporal) of an unpadded coarse
        chunk to use for a forward pass. The number of nodes that the
        ForwardPassStrategy is set to distribute to is calculated by
        dividing up the total time index from all file_paths by the
        temporal part of this chunk shape. Each node will then be
        parallelized across parallel processes by the spatial chunk shape.
        If temporal_pad / spatial_pad are non zero the chunk sent
        to the generator can be bigger than this shape. If running in
        serial set this equal to the shape of the full spatiotemporal data
        volume for best performance.
    s_enhance : int
        Spatial enhancement factor
    t_enhance : int
        Temporal enhancement factor
    spatial_pad : int
        Size of spatial overlap between coarse chunks passed to forward
        passes for subsequent spatial stitching. This overlap will pad both
        sides of the fwp_chunk_shape. Note that the first and last chunks
        in any of the spatial dimension will not be padded.
    temporal_pad : int
        Size of temporal overlap between coarse chunks passed to forward
        passes for subsequent temporal stitching. This overlap will pad
        both sides of the fwp_chunk_shape. Note that the first and last
        chunks in the temporal dimension will not be padded.
    """

    coarse_shape: Union[tuple, list]
    time_steps: int
    s_enhance: int
    t_enhance: int
    # Normalised by ``_parse_time_slice`` in ``__post_init__``.
    time_slice: Union[slice, list]
    temporal_pad: int
    spatial_pad: int
    chunk_shape: Union[tuple, list]
@log_args
def __post_init__(self):
    """Finish initialisation after the dataclass fields are assigned.

    Builds ``dummy_time_index`` from ``np.arange(time_steps)``, passes
    ``time_slice`` through ``_parse_time_slice`` and sets every private
    cache attribute to ``None``.
    """
    self.dummy_time_index = np.arange(self.time_steps)
    self.time_slice = _parse_time_slice(self.time_slice)

    # All caches start empty; the lazy properties that use them fill them
    # in on first access.
    cache_names = (
        '_chunk_lookup',
        '_extra_padding',
        '_s1_lr_slices',
        '_s2_lr_slices',
        '_s1_lr_pad_slices',
        '_s2_lr_pad_slices',
        '_s1_hr_crop_slices',
        '_s2_hr_crop_slices',
        '_s_lr_slices',
        '_s_lr_pad_slices',
        '_s_lr_crop_slices',
        '_t_lr_pad_slices',
        '_t_lr_crop_slices',
        '_s_hr_slices',
        '_s_hr_crop_slices',
        '_t_hr_crop_slices',
        '_hr_crop_slices',
    )
    for cache_name in cache_names:
        setattr(self, cache_name, None)
[docs]
def get_spatial_slices(self):
    """Return the three spatial slice lists used to chunk a forward pass.

    Returns
    -------
    s_lr_slices : list
        Unpadded low res spatial slices (value of ``self.s_lr_slices``).
    s_lr_pad_slices : list
        Padded low res spatial slices (value of ``self.s_lr_pad_slices``).
    s_hr_slices : list
        High res spatial slices matching the unpadded low res regions
        (value of ``self.s_hr_slices``).
    """
    unpadded_lr = self.s_lr_slices
    padded_lr = self.s_lr_pad_slices
    hr = self.s_hr_slices
    return (unpadded_lr, padded_lr, hr)
[docs]
def get_time_slices(self):
    """Return the unpadded and padded low res temporal slice lists.

    Returns
    -------
    t_lr_slices : list
        Low res non-padded time index slices (``self.t_lr_slices``).
    t_lr_pad_slices : list
        Low res padded time index slices (``self.t_lr_pad_slices``).
    """
    unpadded = self.t_lr_slices
    padded = self.t_lr_pad_slices
    return unpadded, padded
@property
def s_lr_slices(self):
    """Unpadded low res spatial slices for every spatial chunk.

    Returns
    -------
    _s_lr_slices : list
        Cached list of ``(s1_slice, s2_slice)`` tuples: the cartesian
        product of ``s1_lr_slices`` and ``s2_lr_slices`` with the first
        spatial dimension varying slowest.
    """
    if self._s_lr_slices is not None:
        return self._s_lr_slices

    rows = self.s1_lr_slices
    cols = self.s2_lr_slices
    self._s_lr_slices = [(row, col) for row in rows for col in cols]
    return self._s_lr_slices
@property
def s_lr_pad_slices(self):
    """Padded low res spatial slices for every spatial chunk.

    Returns
    -------
    _s_lr_pad_slices : list
        Cached list of ``(s1_pad_slice, s2_pad_slice)`` tuples: the
        cartesian product of ``s1_lr_pad_slices`` and ``s2_lr_pad_slices``
        with the first spatial dimension varying slowest.
    """
    if self._s_lr_pad_slices is not None:
        return self._s_lr_pad_slices

    padded_rows = self.s1_lr_pad_slices
    padded_cols = self.s2_lr_pad_slices
    self._s_lr_pad_slices = [
        (row, col) for row in padded_rows for col in padded_cols
    ]
    return self._s_lr_pad_slices
@property
def t_lr_pad_slices(self):
    """Padded low res temporal slices, one per time chunk.

    Computed once by calling ``get_padded_slices`` on ``t_lr_slices`` with
    ``time_steps`` as the upper bound, ``temporal_pad`` as the padding, an
    enhancement of 1 and the step of ``time_slice``.

    Returns
    -------
    _t_lr_pad_slices : list
        Cached list of padded low res temporal slices.
    """
    if self._t_lr_pad_slices is not None:
        return self._t_lr_pad_slices

    self._t_lr_pad_slices = self.get_padded_slices(
        slices=self.t_lr_slices,
        shape=self.time_steps,
        enhancement=1,
        padding=self.temporal_pad,
        step=self.time_slice.step,
    )
    return self._t_lr_pad_slices
@property
def t_lr_crop_slices(self):
    """Low res temporal slices that crop the padding off padded input.

    Returns
    -------
    _t_lr_crop_slices : list
        Cached result of ``get_cropped_slices`` applied to ``t_lr_slices``
        and ``t_lr_pad_slices`` with an enhancement of 1.
    """
    if self._t_lr_crop_slices is not None:
        return self._t_lr_crop_slices

    unpadded = self.t_lr_slices
    padded = self.t_lr_pad_slices
    self._t_lr_crop_slices = self.get_cropped_slices(unpadded, padded, 1)
    return self._t_lr_crop_slices
@property
def t_hr_crop_slices(self):
    """High res temporal slices for cropping generator output.

    Every time chunk gets the same crop: ``t_enhance * temporal_pad``
    steps removed from both ends, or no crop when ``temporal_pad`` is not
    positive.

    Returns
    -------
    _t_hr_crop_slices : list
        Cached list with one high res temporal crop slice per entry of
        ``t_lr_slices``.
    """
    pad = self.temporal_pad
    crop_start = self.t_enhance * pad if pad > 0 else None
    crop_stop = None if crop_start is None else -crop_start

    if self._t_hr_crop_slices is None:
        # get_cropped_slices() is deliberately not used here: temporal
        # padding gets weird at the beginning and end of the timeseries
        # and the temporal axis should always be evenly chunked.
        self._t_hr_crop_slices = [
            slice(crop_start, crop_stop) for _ in self.t_lr_slices
        ]
    return self._t_hr_crop_slices
@property
def s1_hr_slices(self):
    """High res slices for the first spatial dimension, obtained by
    passing ``s1_lr_slices`` and ``s_enhance`` to ``get_hr_slices``."""
    lr_slices = self.s1_lr_slices
    factor = self.s_enhance
    return self.get_hr_slices(lr_slices, factor)
@property
def s2_hr_slices(self):
    """High res slices for the second spatial dimension, obtained by
    passing ``s2_lr_slices`` and ``s_enhance`` to ``get_hr_slices``."""
    lr_slices = self.s2_lr_slices
    factor = self.s_enhance
    return self.get_hr_slices(lr_slices, factor)
@property
def s1_hr_crop_slices(self):
    """High res crop slices for the first spatial dimension.

    Each chunk is cropped by ``s_enhance * spatial_pad`` on both sides (no
    crop when ``spatial_pad`` is 0). The resulting list is then passed
    through ``check_boundary_slice`` with ``dim=0`` so the crop of the last
    chunk can be adjusted.

    The cache is only populated once the boundary check has completed, so
    an exception raised by the check (e.g. when warnings are escalated to
    errors) cannot leave unchecked slices cached.

    Returns
    -------
    _s1_hr_crop_slices : list
        Cached list with one crop slice per entry of ``s1_lr_slices``.
    """
    if self._s1_hr_crop_slices is None:
        lr_slices = self.s1_lr_slices
        hr_crop_start = self.s_enhance * self.spatial_pad or None
        hr_crop_stop = None if self.spatial_pad == 0 else -hr_crop_start
        crop_slices = [slice(hr_crop_start, hr_crop_stop)] * len(lr_slices)
        self._s1_hr_crop_slices = self.check_boundary_slice(
            unpadded_slices=lr_slices,
            cropped_slices=crop_slices,
            dim=0,
        )
    return self._s1_hr_crop_slices
@property
def s2_hr_crop_slices(self):
    """High res crop slices for the second spatial dimension.

    Each chunk is cropped by ``s_enhance * spatial_pad`` on both sides (no
    crop when ``spatial_pad`` is 0). The resulting list is then passed
    through ``check_boundary_slice`` with ``dim=1`` so the crop of the last
    chunk can be adjusted.

    The cache is only populated once the boundary check has completed, so
    an exception raised by the check (e.g. when warnings are escalated to
    errors) cannot leave unchecked slices cached.

    Returns
    -------
    _s2_hr_crop_slices : list
        Cached list with one crop slice per entry of ``s2_lr_slices``.
    """
    if self._s2_hr_crop_slices is None:
        lr_slices = self.s2_lr_slices
        hr_crop_start = self.s_enhance * self.spatial_pad or None
        hr_crop_stop = None if self.spatial_pad == 0 else -hr_crop_start
        crop_slices = [slice(hr_crop_start, hr_crop_stop)] * len(lr_slices)
        self._s2_hr_crop_slices = self.check_boundary_slice(
            unpadded_slices=lr_slices,
            cropped_slices=crop_slices,
            dim=1,
        )
    return self._s2_hr_crop_slices
@property
def s_hr_slices(self):
    """High res spatial slices for indexing the full generator output.

    Returns
    -------
    _s_hr_slices : list
        Cached list of ``(s1_hr_slice, s2_hr_slice)`` tuples: the cartesian
        product of ``s1_hr_slices`` and ``s2_hr_slices`` with the first
        spatial dimension varying slowest.
    """
    if self._s_hr_slices is not None:
        return self._s_hr_slices

    hr_rows = self.s1_hr_slices
    hr_cols = self.s2_hr_slices
    self._s_hr_slices = [(row, col) for row in hr_rows for col in hr_cols]
    return self._s_hr_slices
@property
def s_lr_crop_slices(self):
    """Low res spatial crop slices for cropping the input chunk domain.

    For each spatial dimension the crop slices come from
    ``get_cropped_slices`` (enhancement of 1) and are then passed through
    ``check_boundary_slice``. The per-dimension results are combined as a
    cartesian product with the first spatial dimension varying slowest.

    The cache is assigned only after the full list has been built, so an
    exception part-way through cannot leave an empty list cached.

    Returns
    -------
    _s_lr_crop_slices : list
        Cached list of ``(s1_crop_slice, s2_crop_slice)`` tuples.
    """
    if self._s_lr_crop_slices is None:
        s1_lr = self.s1_lr_slices
        s1_crop_slices = self.get_cropped_slices(
            s1_lr, self.s1_lr_pad_slices, 1
        )
        s1_crop_slices = self.check_boundary_slice(
            unpadded_slices=s1_lr,
            cropped_slices=s1_crop_slices,
            dim=0,
        )

        s2_lr = self.s2_lr_slices
        s2_crop_slices = self.get_cropped_slices(
            s2_lr, self.s2_lr_pad_slices, 1
        )
        s2_crop_slices = self.check_boundary_slice(
            unpadded_slices=s2_lr,
            cropped_slices=s2_crop_slices,
            dim=1,
        )

        self._s_lr_crop_slices = [
            (c1, c2) for c1 in s1_crop_slices for c2 in s2_crop_slices
        ]
    return self._s_lr_crop_slices
@property
def s_hr_crop_slices(self):
    """High res spatial crop slices for cropping generator output.

    Returns
    -------
    _s_hr_crop_slices : list
        Cached list of ``(s1_crop_slice, s2_crop_slice)`` tuples: the
        cartesian product of ``s1_hr_crop_slices`` and
        ``s2_hr_crop_slices`` with the first spatial dimension varying
        slowest.
    """
    if self._s_hr_crop_slices is not None:
        return self._s_hr_crop_slices

    crop_rows = self.s1_hr_crop_slices
    crop_cols = self.s2_hr_crop_slices
    self._s_hr_crop_slices = [
        (row, col) for row in crop_rows for col in crop_cols
    ]
    return self._s_hr_crop_slices
@property
def hr_crop_slices(self):
    """Get high res spatiotemporal cropped slices for cropping generator
    output

    The cache is assigned only after the full nested list has been built,
    so an exception raised while evaluating the spatial or temporal crop
    slices cannot leave an empty or partial list cached.

    Returns
    -------
    _hr_crop_slices : list
        List with one entry per temporal crop slice. Each entry is a list
        with one tuple per spatial chunk holding a crop slice for each
        spatial dimension, the temporal crop slice and then ``slice(None)``
        for the feature dimension. model.generate()[hr_crop_slice] gives
        the cropped generator output corresponding to
        output_array[hr_slice]
    """
    if self._hr_crop_slices is None:
        self._hr_crop_slices = [
            [(s[0], s[1], t, slice(None)) for s in self.s_hr_crop_slices]
            for t in self.t_hr_crop_slices
        ]
    return self._hr_crop_slices
@property
def s1_lr_pad_slices(self):
    """Padded low res slices for the first spatial dimension.

    Computed once by calling ``get_padded_slices`` on ``s1_lr_slices`` with
    ``coarse_shape[0]`` as the upper bound, ``spatial_pad`` as the padding
    and an enhancement of 1; the result is cached."""
    if self._s1_lr_pad_slices is not None:
        return self._s1_lr_pad_slices

    self._s1_lr_pad_slices = self.get_padded_slices(
        slices=self.s1_lr_slices,
        shape=self.coarse_shape[0],
        enhancement=1,
        padding=self.spatial_pad,
    )
    return self._s1_lr_pad_slices
@property
def s2_lr_pad_slices(self):
    """Padded low res slices for the second spatial dimension.

    Computed once by calling ``get_padded_slices`` on ``s2_lr_slices`` with
    ``coarse_shape[1]`` as the upper bound, ``spatial_pad`` as the padding
    and an enhancement of 1; the result is cached."""
    if self._s2_lr_pad_slices is not None:
        return self._s2_lr_pad_slices

    self._s2_lr_pad_slices = self.get_padded_slices(
        slices=self.s2_lr_slices,
        shape=self.coarse_shape[1],
        enhancement=1,
        padding=self.spatial_pad,
    )
    return self._s2_lr_pad_slices
@property
def s1_lr_slices(self):
    """Low res chunk slices for the first spatial dimension.

    Delegates to ``get_chunk_slices`` with the full extent
    ``coarse_shape[0]``, the chunk size ``chunk_shape[0]`` and an index
    slice covering ``0`` to ``coarse_shape[0]``. The result is recomputed
    on every access."""
    extent = self.coarse_shape[0]
    full_range = slice(0, extent)
    return get_chunk_slices(
        extent, self.chunk_shape[0], index_slice=full_range
    )
@property
def s2_lr_slices(self):
    """Low res chunk slices for the second spatial dimension.

    Delegates to ``get_chunk_slices`` with the full extent
    ``coarse_shape[1]``, the chunk size ``chunk_shape[1]`` and an index
    slice covering ``0`` to ``coarse_shape[1]``. The result is recomputed
    on every access."""
    extent = self.coarse_shape[1]
    full_range = slice(0, extent)
    return get_chunk_slices(
        extent, self.chunk_shape[1], index_slice=full_range
    )
@property
def t_lr_slices(self):
    """Low resolution temporal slices, one per time chunk.

    The time steps selected by ``time_slice`` are split with
    ``np.array_split`` into ``ceil(n_selected / chunk_shape[2])`` chunks.
    Each chunk becomes a slice from its first index to its last index plus
    one, carrying the step of ``time_slice``."""
    selected = self.dummy_time_index[self.time_slice]
    n_chunks = int(np.ceil(len(selected) / self.chunk_shape[2]))
    slices = []
    for chunk in np.array_split(selected, n_chunks):
        first, last = chunk[0], chunk[-1]
        slices.append(slice(first, last + 1, self.time_slice.step))
    return slices
[docs]
@staticmethod
def get_hr_slices(slices, enhancement, step=None):
    """Get high resolution slices for temporal or spatial slices

    Parameters
    ----------
    slices : list
        Low resolution slices to be enhanced. A ``None`` start or stop
        (open-ended slice) is passed through unchanged.
    enhancement : int
        Enhancement factor
    step : int | None
        Step size for slices. When given, the high resolution step is
        ``step * enhancement``.

    Returns
    -------
    hr_slices : list
        High resolution slices
    """
    hr_step = None if step is None else step * enhancement

    def _scale(index):
        # Open-ended bounds stay open-ended after enhancement.
        return None if index is None else index * enhancement

    return [
        slice(_scale(sli.start), _scale(sli.stop), hr_step) for sli in slices
    ]
@property
def chunk_lookup(self):
    """3D array of chunk indices with shape
    (n_spatial_1_chunks, n_spatial_2_chunks, n_time_chunks).

    Chunk indices run fastest over the second spatial dimension, then the
    first spatial dimension, then time; the array is built once and
    cached."""
    if self._chunk_lookup is not None:
        return self._chunk_lookup

    n_s1 = len(self.s1_lr_slices)
    n_s2 = len(self.s2_lr_slices)
    n_t = self.n_time_chunks
    time_major = np.arange(self.n_chunks).reshape((n_t, n_s1, n_s2))
    # Move the time axis from first to last position.
    self._chunk_lookup = time_major.transpose(1, 2, 0)
    return self._chunk_lookup
@property
def spatial_chunk_lookup(self):
    """2D array of spatial chunk indices with shape
    (n_spatial_1_chunks, n_spatial_2_chunks), numbered in row-major
    order."""
    grid_shape = (len(self.s1_lr_slices), len(self.s2_lr_slices))
    flat_indices = np.arange(self.n_spatial_chunks)
    return flat_indices.reshape(grid_shape)
@property
def n_spatial_chunks(self):
    """Number of spatial chunks, taken as the length of the first entry of
    ``hr_crop_slices``."""
    first_time_chunk = self.hr_crop_slices[0]
    return len(first_time_chunk)
@property
def n_time_chunks(self):
    """Number of temporal chunks, taken as the length of
    ``t_hr_crop_slices``."""
    temporal_crops = self.t_hr_crop_slices
    return len(temporal_crops)
@property
def n_chunks(self):
    """Total number of spatiotemporal chunks: ``n_spatial_chunks`` times
    ``n_time_chunks``."""
    n_spatial = self.n_spatial_chunks
    n_temporal = self.n_time_chunks
    return n_spatial * n_temporal
[docs]
@staticmethod
def get_padded_slices(slices, shape, enhancement, padding, step=None):
    """Pad each slice on both sides, clipped to the valid index range.

    Parameters
    ----------
    slices : list
        List of low res unpadded slices. Each needs integer ``start`` and
        ``stop`` values.
    shape : int
        Size of the dimension being indexed. A padded slice never stops
        beyond ``enhancement * shape``.
    enhancement : int
        Enhancement factor. Slice bounds, padding and ``shape`` are all
        multiplied by this factor.
    padding : int
        Padding size before enhancement. The applied padding is
        ``step * padding * enhancement``.
    step : int | None
        Step size for slices; a falsy value is treated as 1.

    Returns
    -------
    list
        Padded slices for temporal or spatial dimensions. Starts are
        clipped at 0.
    """
    stride = step or 1
    margin = stride * padding * enhancement
    upper_bound = enhancement * shape

    def _pad(chunk):
        padded_stop = np.min([upper_bound, chunk.stop * enhancement + margin])
        padded_start = np.max([0, chunk.start * enhancement - margin])
        return slice(padded_start, padded_stop, stride)

    return [_pad(chunk) for chunk in slices]
[docs]
def check_boundary_slice(self, unpadded_slices, cropped_slices, dim):
    """Adjust the crop slice of the last chunk if that chunk is too small.

    The forward pass chunk shape can divide the grid such that the last
    slice (right boundary) is shorter than the minimum number of elements
    (padding layers in the generator typically require a minimum shape of
    4). e.g. ``grid_size = (8, 8)`` with ``fwp_chunk_shape = (7, 7, ...)``
    gives a final unpadded slice with just one element, and with a padding
    of 0 or 1 the padded slice is shorter than 4. In that case extra
    padding is applied in :meth:`self._get_pad_width` and the matching
    crop slice is adjusted here.

    Parameters
    ----------
    unpadded_slices : list
        Unpadded low res slices for one spatial dimension. Only the last
        entry is inspected.
    cropped_slices : list
        Crop slices for the same dimension. The last entry is replaced in
        place when the check triggers.
    dim : int
        Index of the spatial dimension (used to look up ``coarse_shape``
        and in the warning message).

    Returns
    -------
    list
        The same ``cropped_slices`` list, possibly with its last entry
        replaced by ``slice(2 * s_enhance, -2 * s_enhance)``.
    """
    warn_msg = (
        'The final spatial slice for dimension #%s is too small '
        '(slice=slice(%s, %s), padding=%s). The start of this slice will '
        'be reduced to try to meet the minimum slice length.'
    )
    final_slice = unpadded_slices[-1]
    lr_slice_start = final_slice.start or 0
    lr_slice_stop = final_slice.stop or self.coarse_shape[dim]

    padded_length = 2 * self.spatial_pad + (lr_slice_stop - lr_slice_start)
    if padded_length >= 4:
        return cropped_slices

    # The last slice is too short: log and warn, then swap in the adjusted
    # crop for the final chunk.
    msg_args = (dim + 1, lr_slice_start, lr_slice_stop, self.spatial_pad)
    logger.warning(warn_msg, *msg_args)
    warn(warn_msg % msg_args)
    crop_width = 2 * self.s_enhance
    cropped_slices[-1] = slice(crop_width, -crop_width)
    return cropped_slices
[docs]
@staticmethod
def get_cropped_slices(unpadded_slices, padded_slices, enhancement):
    """Build slices that cut the padding off padded data.

    For each pair the crop start is
    ``enhancement * (unpadded.start - padded.start) // step`` and the crop
    stop is ``enhancement * (unpadded.stop - padded.stop) // step``, where
    ``step`` is the unpadded slice step (1 if unset). A start that is not
    positive or a stop that is not negative means nothing to crop on that
    side and becomes ``None``.

    Parameters
    ----------
    unpadded_slices : list
        List of unpadded slices
    padded_slices : list
        List of padded slices
    enhancement : int
        Enhancement factor for the data to be cropped.

    Returns
    -------
    list
        Cropped slices for temporal or spatial dimensions.
    """

    def _crop(padded, unpadded):
        stride = unpadded.step or 1

        crop_start = None
        if unpadded.start is not None:
            offset = enhancement * (unpadded.start - padded.start) // stride
            crop_start = offset if offset > 0 else None

        crop_stop = None
        if unpadded.stop is not None:
            offset = enhancement * (unpadded.stop - padded.stop) // stride
            crop_stop = offset if offset < 0 else None

        return slice(crop_start, crop_stop)

    return [_crop(ps, us) for ps, us in zip(padded_slices, unpadded_slices)]
@staticmethod
def _get_pad_width(window, max_steps, max_pad, check_boundary=False):
    """Get the extra padding needed before and after a window.

    Extra padding is the part of ``max_pad`` that would extend past either
    end of the available range ``[0, max_steps)``.

    Parameters
    ----------
    window : slice
        Slice with start and stop of window to pad. A falsy start is
        treated as 0 and a falsy stop as ``max_steps``.
    max_steps : int
        Maximum number of steps available. Padding cannot extend past this
    max_pad : int
        Maximum amount of padding to apply.
    check_boundary : bool
        Whether to check the final slice for minimum size requirement

    Returns
    -------
    tuple
        ``(start, stop)`` pad widths for the given window, always as
        built-in ``int`` values.
    """
    win_start = window.start or 0
    win_stop = window.stop or max_steps
    start = int(np.maximum(0, (max_pad - win_start)))
    stop = int(np.maximum(0, max_pad + win_stop - max_steps))

    # We add minimum padding to the last slice if the padded window is
    # too small for the generator. This can happen if 2 * spatial_pad +
    # modulo(grid_size, fwp_chunk_shape) < 4
    window_too_small = (2 * max_pad + win_stop - win_start) < 4
    if check_boundary and win_stop == max_steps and window_too_small:
        # Cast so this branch returns the same type as the normal path.
        start = stop = int(np.max([2, max_pad]))
    return (start, stop)
[docs]
def get_chunk_indices(self, chunk_index):
    """Get (spatial, temporal) indices for the given chunk index.

    The spatial index is ``chunk_index % n_spatial_chunks`` and the
    temporal index is ``chunk_index // n_spatial_chunks``."""
    temporal_idx, spatial_idx = divmod(chunk_index, self.n_spatial_chunks)
    return (spatial_idx, temporal_idx)
[docs]
def get_pad_width(self, chunk_index):
    """Get extra padding for the given spatiotemporal chunk.

    Parameters
    ----------
    chunk_index : int
        Index of the chunk; split into spatial and temporal indices with
        ``get_chunk_indices``.

    Returns
    -------
    padding : tuple
        Tuple of tuples with padding width for spatial and temporal
        dimensions. Each tuple includes the start and end of padding for
        that dimension. Ordering is spatial_1, spatial_2, temporal. The
        boundary check is only enabled for the two spatial dimensions.
    """
    s_chunk_idx, t_chunk_idx = self.get_chunk_indices(chunk_index)
    ti_slice = self.t_lr_slices[t_chunk_idx]
    lr_slice = self.s_lr_slices[s_chunk_idx]

    spatial_pads = tuple(
        self._get_pad_width(
            lr_slice[dim],
            self.coarse_shape[dim],
            self.spatial_pad,
            check_boundary=True,
        )
        for dim in (0, 1)
    )
    temporal_pad = self._get_pad_width(
        ti_slice, len(self.dummy_time_index), self.temporal_pad
    )
    return (*spatial_pads, temporal_pad)
@property
def extra_padding(self):
    """List of pad widths, one ``get_pad_width`` result per chunk index in
    ``range(n_chunks)``; built once and cached."""
    if self._extra_padding is not None:
        return self._extra_padding

    pad_widths = []
    for chunk_index in range(self.n_chunks):
        pad_widths.append(self.get_pad_width(chunk_index))
    self._extra_padding = pad_widths
    return self._extra_padding