# Source code for sup3r.preprocessing.data_handling.exo_extraction
"""Sup3r topography utilities"""
import logging
import os
import pickle
import shutil
from abc import ABC, abstractmethod
from warnings import warn
import pandas as pd
import numpy as np
from rex import Resource
from rex.utilities.solar_position import SolarPosition
from scipy.spatial import KDTree
import sup3r.preprocessing.data_handling
from sup3r.postprocessing.file_handling import OutputHandler
from sup3r.preprocessing.data_handling.h5_data_handling import DataHandlerH5
from sup3r.preprocessing.data_handling.nc_data_handling import DataHandlerNC
from sup3r.utilities.utilities import (generate_random_string, get_source_type,
nn_fill_array)
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
[docs]
class ExoExtract(ABC):
    """Class to extract high-res (4km+) data rasters for new
    spatially-enhanced datasets (e.g. GCM files after spatial enhancement)
    using nearest neighbor mapping and aggregation from NREL datasets
    (e.g. WTK or NSRDB)
    """

    def __init__(self,
                 file_paths,
                 exo_source,
                 s_enhance,
                 t_enhance,
                 t_agg_factor,
                 target=None,
                 shape=None,
                 temporal_slice=None,
                 raster_file=None,
                 max_delta=20,
                 input_handler=None,
                 cache_data=True,
                 cache_dir='./exo_cache/',
                 ti_workers=1,
                 distance_upper_bound=None,
                 res_kwargs=None):
        """Parameters
        ----------
        file_paths : str | list
            A single source h5 file to extract raster data from or a list
            of netcdf files with identical grid. The string can be a unix-style
            file path which will be passed through glob.glob. This is
            typically low-res WRF output or GCM netcdf data files that is
            source low-resolution data intended to be sup3r resolved.
        exo_source : str
            Filepath to source data file to get hi-res elevation data from
            which will be mapped to the enhanced grid of the file_paths input.
            Pixels from this exo_source will be mapped to their nearest low-res
            pixel in the file_paths input. Accordingly, exo_source should be a
            significantly higher resolution than file_paths. Warnings will be
            raised if the low-resolution pixels in file_paths do not have
            unique nearest pixels from exo_source. File format can be .h5 for
            TopoExtractH5 or .nc for TopoExtractNC
        s_enhance : int
            Factor by which the Sup3rGan model will enhance the spatial
            dimensions of low resolution data from file_paths input. For
            example, if getting topography data, file_paths has 100km data, and
            s_enhance is 4, this class will output a topography raster
            corresponding to the file_paths grid enhanced 4x to ~25km
        t_enhance : int
            Factor by which the Sup3rGan model will enhance the temporal
            dimension of low resolution data from file_paths input. For
            example, if getting sza data, file_paths has hourly data, and
            t_enhance is 4, this class will output a sza raster
            corresponding to the file_paths temporally enhanced 4x to 15 min
        t_agg_factor : int
            Factor by which to aggregate / subsample the exo_source data to the
            resolution of the file_paths input enhanced by t_enhance. For
            example, if getting sza data, file_paths have hourly data, and
            t_enhance is 4 resulting in a target resolution of 15 min and
            exo_source has a resolution of 5 min, the t_agg_factor should be 3
            so that only timesteps that are a multiple of 15min are selected
            e.g., [0, 5, 10, 15, 20, 25, 30][slice(0, None, 3)] = [0, 15, 30]
        target : tuple
            (lat, lon) lower left corner of raster. Either need target+shape or
            raster_file.
        shape : tuple
            (rows, cols) grid size. Either need target+shape or raster_file.
        temporal_slice : slice | None
            slice used to extract interval from temporal dimension for input
            data and source data
        raster_file : str | None
            File for raster_index array for the corresponding target and shape.
            If specified the raster_index will be loaded from the file if it
            exists or written to the file if it does not yet exist. If None
            raster_index will be calculated directly. Either need target+shape
            or raster_file.
        max_delta : int, optional
            Optional maximum limit on the raster shape that is retrieved at
            once. If shape is (20, 20) and max_delta=10, the full raster will
            be retrieved in four chunks of (10, 10). This helps adapt to
            non-regular grids that curve over large distances, by default 20
        input_handler : str | type | None
            data handler class to use for input data. Provide a string name to
            match a class in sup3r.preprocessing.data_handling, or the class
            itself. If None the handler is chosen from the file type of
            file_paths (DataHandlerNC for nc, DataHandlerH5 for h5).
        cache_data : bool
            Flag to cache exogeneous data in cache_dir. This can speed up
            forward passes with large temporal extents when the exo data is
            time independent.
        cache_dir : str
            Directory for storing cache data. Default is './exo_cache/'
        ti_workers : int | None
            max number of workers to use to get full time index. Useful when
            there are many input files each with a single time step. If this is
            greater than one, time indices for input files will be extracted in
            parallel and then concatenated to get the full time index. If input
            files do not all have time indices or if there are few input files
            this should be set to one.
        distance_upper_bound : float | None
            Maximum distance to map high-resolution data from exo_source to the
            low-resolution file_paths input. None (default) will calculate this
            based on the median distance between points in exo_source
        res_kwargs : dict | None
            Dictionary of kwargs passed to lowest level resource handler. e.g.
            xr.open_dataset(file_paths, **res_kwargs)

        Raises
        ------
        RuntimeError
            If input_handler is None and the file type of file_paths is
            neither "nc" nor "h5".
        KeyError
            If input_handler is a string that does not name an attribute of
            sup3r.preprocessing.data_handling.
        """
        logger.info(f'Initializing {self.__class__.__name__} utility.')
        self.ti_workers = ti_workers
        self._exo_source = exo_source
        self._s_enhance = s_enhance
        self._t_enhance = t_enhance
        self._t_agg_factor = t_agg_factor

        # Lazily populated caches for the properties below.
        self._tree = None
        self._hr_lat_lon = None
        self._source_lat_lon = None
        self._hr_time_index = None
        self._src_time_index = None
        self._distance_upper_bound = distance_upper_bound

        self.cache_data = cache_data
        self.cache_dir = cache_dir
        self.temporal_slice = temporal_slice
        self.target = target
        self.shape = shape
        self.res_kwargs = res_kwargs

        # for subclasses
        self._source_handler = None

        if input_handler is None:
            in_type = get_source_type(file_paths)
            if in_type == 'nc':
                input_handler = DataHandlerNC
            elif in_type == 'h5':
                input_handler = DataHandlerH5
            else:
                msg = (f'Did not recognize input type "{in_type}" for file '
                       f'paths: {file_paths}')
                logger.error(msg)
                raise RuntimeError(msg)
        elif isinstance(input_handler, str):
            # Keep the requested name so the error message can report it
            # after the lookup result (possibly None) replaces the variable.
            handler_name = input_handler
            input_handler = getattr(sup3r.preprocessing.data_handling,
                                    handler_name, None)
            if input_handler is None:
                msg = ('Could not find requested data handler class '
                       f'"{handler_name}" in '
                       'sup3r.preprocessing.data_handling.')
                logger.error(msg)
                raise KeyError(msg)

        self.input_handler = input_handler(
            file_paths, [],
            target=target,
            shape=shape,
            temporal_slice=temporal_slice,
            raster_file=raster_file,
            max_delta=max_delta,
            worker_kwargs={'ti_workers': ti_workers},
            res_kwargs=self.res_kwargs
        )

    @property
    @abstractmethod
    def source_data(self):
        """Get the 1D array of source data from the exo_source_h5"""

    def get_cache_file(self, feature, s_enhance, t_enhance, t_agg_factor):
        """Get cache file name. Creates cache_dir if cache_data is True.

        Parameters
        ----------
        feature : str
            Name of feature to get cache file for
        s_enhance : int
            Spatial enhancement for this exogeneous data step (cumulative for
            all model steps up to the current step).
        t_enhance : int
            Temporal enhancement for this exogeneous data step (cumulative for
            all model steps up to the current step).
        t_agg_factor : int
            Factor by which to aggregate the exo_source data to the temporal
            resolution of the file_paths input enhanced by t_enhance.

        Returns
        -------
        cache_fp : str
            Name of cache file
        """
        # Number of time steps is only known when the slice has both bounds.
        tsteps = (None if self.temporal_slice is None
                  or self.temporal_slice.start is None
                  or self.temporal_slice.stop is None
                  else self.temporal_slice.stop - self.temporal_slice.start)
        fn = f'exo_{feature}_{self.target}_{self.shape},{tsteps}'
        fn += f'_tagg{t_agg_factor}_{s_enhance}x_'
        fn += f'{t_enhance}x.pkl'
        # Strip tuple / list punctuation so the name is filesystem friendly.
        fn = fn.replace('(', '').replace(')', '')
        fn = fn.replace('[', '').replace(']', '')
        fn = fn.replace(',', 'x').replace(' ', '')
        cache_fp = os.path.join(self.cache_dir, fn)
        if self.cache_data:
            os.makedirs(self.cache_dir, exist_ok=True)
        return cache_fp

    @property
    def source_temporal_slice(self):
        """Get the temporal slice for the exo_source data corresponding to the
        input file temporal slice. The slice step is t_agg_factor.
        """
        # NOTE(review): relies on the input handler exposing hr_time_index;
        # confirm against the data handler implementation.
        start_index = self.source_time_index.get_indexer(
            [self.input_handler.hr_time_index[0]], method='nearest')[0]
        end_index = self.source_time_index.get_indexer(
            [self.input_handler.hr_time_index[-1]], method='nearest')[0]
        return slice(start_index, end_index + 1, self._t_agg_factor)

    @property
    def source_lat_lon(self):
        """Get the 2D array (n, 2) of lat, lon data from the exo_source_h5.
        The array is read from file once and cached on the instance."""
        if self._source_lat_lon is None:
            with Resource(self._exo_source) as res:
                self._source_lat_lon = res.lat_lon
        return self._source_lat_lon

    @property
    def lr_shape(self):
        """Get the low-resolution shape tuple (rows, cols, time steps)"""
        return (self.lr_lat_lon.shape[0], self.lr_lat_lon.shape[1],
                len(self.input_handler.time_index))

    @property
    def hr_shape(self):
        """Get the high-resolution shape tuple (rows, cols, time steps), i.e.
        the low-resolution shape scaled by s_enhance spatially and t_enhance
        temporally"""
        return (self._s_enhance * self.lr_lat_lon.shape[0],
                self._s_enhance * self.lr_lat_lon.shape[1],
                self._t_enhance * len(self.input_handler.time_index))

    @property
    def lr_lat_lon(self):
        """Lat lon grid for data in format (spatial_1, spatial_2, 2) Lat/Lon
        array with same ordering in last dimension. This corresponds to the raw
        meta data from the file_paths input.

        Returns
        -------
        ndarray
        """
        return self.input_handler.lat_lon

    @property
    def hr_lat_lon(self):
        """Lat lon grid for data in format (spatial_1, spatial_2, 2) Lat/Lon
        array with same ordering in last dimension. This corresponds to the
        enhanced meta data from the file_paths input * s_enhance.

        Returns
        -------
        ndarray
        """
        if self._hr_lat_lon is None:
            if self._s_enhance > 1:
                self._hr_lat_lon = OutputHandler.get_lat_lon(
                    self.lr_lat_lon, self.hr_shape[:-1])
            else:
                self._hr_lat_lon = self.lr_lat_lon
        return self._hr_lat_lon

    @property
    def source_time_index(self):
        """Get the full time index of the exo_source data. This has
        t_agg_factor times as many steps as the high-resolution time index
        when t_agg_factor > 1, otherwise it is the high-resolution time
        index."""
        if self._src_time_index is None:
            if self._t_agg_factor > 1:
                self._src_time_index = OutputHandler.get_times(
                    self.input_handler.time_index,
                    self.hr_shape[-1] * self._t_agg_factor)
            else:
                self._src_time_index = self.hr_time_index
        return self._src_time_index

    @property
    def hr_time_index(self):
        """Get the time index of the input data enhanced by t_enhance (the
        input time index itself when t_enhance is 1)"""
        if self._hr_time_index is None:
            if self._t_enhance > 1:
                self._hr_time_index = OutputHandler.get_times(
                    self.input_handler.time_index, self.hr_shape[-1])
            else:
                self._hr_time_index = self.input_handler.time_index
        return self._hr_time_index

    @property
    def distance_upper_bound(self):
        """Maximum distance (float) to map high-resolution data from exo_source
        to the low-resolution file_paths input."""
        if self._distance_upper_bound is None:
            # Use magnitudes: a signed median can be zero or negative when the
            # source coordinates are stored in descending order, which is not
            # a meaningful distance bound for the KDTree query.
            diff = np.abs(np.diff(self.source_lat_lon, axis=0))
            self._distance_upper_bound = np.max(np.median(diff, axis=0))
            logger.info('Set distance upper bound to {:.4f}'
                        .format(self._distance_upper_bound))
        return self._distance_upper_bound

    @property
    def tree(self):
        """Get the KDTree built on the target lat lon data from the file_paths
        input with s_enhance"""
        if self._tree is None:
            lat = self.hr_lat_lon[..., 0].flatten()
            lon = self.hr_lat_lon[..., 1].flatten()
            hr_meta = np.vstack((lat, lon)).T
            self._tree = KDTree(hr_meta)
        return self._tree

    @property
    def nn(self):
        """Get the nearest neighbor indices. For each source point this is the
        flat index of the nearest high-resolution target pixel. Source points
        with no target within distance_upper_bound get the index
        n_target_pixels (one past the last valid index)."""
        _, nn = self.tree.query(self.source_lat_lon, k=1,
                                distance_upper_bound=self.distance_upper_bound)
        return nn

    @property
    def data(self):
        """Get a raster of source values corresponding to the
        high-resolution grid (the file_paths input grid * s_enhance *
        t_enhance). The shape is (lats, lons, temporal, 1)
        """
        cache_fp = self.get_cache_file(feature=self.__class__.__name__,
                                       s_enhance=self._s_enhance,
                                       t_enhance=self._t_enhance,
                                       t_agg_factor=self._t_agg_factor)

        if os.path.exists(cache_fp):
            # NOTE: pickle.load can execute arbitrary code from a crafted
            # file, so cache_dir must only contain trusted files.
            with open(cache_fp, 'rb') as f:
                data = pickle.load(f)
        else:
            data = self.get_data()
            if self.cache_data:
                # Write to a uniquely named temp file and move it into place
                # so readers never see a partially written cache file.
                tmp_fp = cache_fp + f'.{generate_random_string(10)}.tmp'
                with open(tmp_fp, 'wb') as f:
                    pickle.dump(data, f)
                shutil.move(tmp_fp, cache_fp)

        # Time independent data (single time step) is repeated along the
        # temporal axis to match the high-resolution time dimension.
        if data.shape[-1] == 1 and self.hr_shape[-1] > 1:
            data = np.repeat(data, self.hr_shape[-1], axis=-1)

        return data[..., np.newaxis]

    @abstractmethod
    def get_data(self):
        """Get a raster of source values corresponding to the
        high-resolution grid (the file_paths input grid * s_enhance *
        t_enhance). The shape is (lats, lons, temporal)
        """

    @classmethod
    def get_exo_raster(cls,
                       file_paths,
                       s_enhance,
                       t_enhance,
                       t_agg_factor,
                       exo_source=None,
                       target=None,
                       shape=None,
                       temporal_slice=None,
                       raster_file=None,
                       max_delta=20,
                       input_handler=None,
                       cache_data=True,
                       cache_dir='./exo_cache/'):
        """Get the exo feature raster corresponding to the spatially enhanced
        grid from the file_paths input

        Parameters
        ----------
        file_paths : str | list
            A single source h5 file to extract raster data from or a list
            of netcdf files with identical grid. The string can be a unix-style
            file path which will be passed through glob.glob
        s_enhance : int
            Factor by which the Sup3rGan model will enhance the spatial
            dimensions of low resolution data from file_paths input. For
            example, if file_paths has 100km data and s_enhance is 4, this
            class will output a topography raster corresponding to the
            file_paths grid enhanced 4x to ~25km
        t_enhance : int
            Factor by which the Sup3rGan model will enhance the temporal
            dimension of low resolution data from file_paths input. For
            example, if getting sza data, file_paths has hourly data, and
            t_enhance is 4, this class will output a sza raster
            corresponding to the file_paths temporally enhanced 4x to 15 min
        t_agg_factor : int
            Factor by which to aggregate / subsample the exo_source data to
            the resolution of the file_paths input enhanced by t_enhance. For
            example, if getting sza data, file_paths have hourly data, and
            t_enhance is 4 resulting in a target resolution of 15 min and
            exo_source has a resolution of 5 min, the t_agg_factor should be 3
            so that only every third timestep of exo_source is selected.
        exo_source : str
            Filepath to source wtk, nsrdb, or netcdf file to get hi-res (2km or
            4km) data from which will be mapped to the enhanced grid of the
            file_paths input
        target : tuple
            (lat, lon) lower left corner of raster. Either need target+shape or
            raster_file.
        shape : tuple
            (rows, cols) grid size. Either need target+shape or raster_file.
        temporal_slice : slice | None
            slice used to extract interval from temporal dimension for input
            data and source data
        raster_file : str | None
            File for raster_index array for the corresponding target and shape.
            If specified the raster_index will be loaded from the file if it
            exists or written to the file if it does not yet exist. If None
            raster_index will be calculated directly. Either need target+shape
            or raster_file.
        max_delta : int, optional
            Optional maximum limit on the raster shape that is retrieved at
            once. If shape is (20, 20) and max_delta=10, the full raster will
            be retrieved in four chunks of (10, 10). This helps adapt to
            non-regular grids that curve over large distances, by default 20
        input_handler : str | type | None
            data handler class to use for input data. Provide a string name to
            match a class in sup3r.preprocessing.data_handling, or the class
            itself. If None the handler is chosen from the file type of
            file_paths.
        cache_data : bool
            Flag to cache exogeneous data in cache_dir. This can speed up
            forward passes with large temporal extents when the exo data is
            time independent.
        cache_dir : str
            Directory for storing cache data. Default is './exo_cache/'

        Returns
        -------
        exo_raster : np.ndarray
            Exo feature raster with shape (hr_rows, hr_cols, hr_temporal, 1)
            corresponding to the shape of the spatiotemporally enhanced data
            from file_paths * s_enhance * t_enhance. The data units correspond
            to the source units in exo_source_h5. This is usually meters when
            feature='topography'
        """
        # Everything after file_paths is passed by keyword: the constructor's
        # second positional parameter is exo_source, so passing the enhancement
        # factors positionally would bind them to the wrong parameters and
        # collide with the exo_source keyword.
        exo = cls(file_paths,
                  exo_source=exo_source,
                  s_enhance=s_enhance,
                  t_enhance=t_enhance,
                  t_agg_factor=t_agg_factor,
                  target=target,
                  shape=shape,
                  temporal_slice=temporal_slice,
                  raster_file=raster_file,
                  max_delta=max_delta,
                  input_handler=input_handler,
                  cache_data=cache_data,
                  cache_dir=cache_dir)
        return exo.data
[docs]
class TopoExtractH5(ExoExtract):
    """TopoExtract for H5 files"""

    @property
    def source_data(self):
        """Get the elevation data from the exo_source_h5 as a column array
        of shape (n, 1)"""
        with Resource(self._exo_source) as res:
            elev = res.get_meta_arr('elevation')
        return elev[:, np.newaxis]

    @property
    def source_time_index(self):
        """Time index of the source exo data"""
        if self._src_time_index is None:
            with Resource(self._exo_source) as res:
                self._src_time_index = res.time_index
        return self._src_time_index

    def get_data(self):
        """Get a raster of source values corresponding to the
        high-resolution grid (the file_paths input grid * s_enhance *
        t_enhance). The shape is (lats, lons, 1)

        Source pixels are grouped by their nearest high-resolution target
        pixel and averaged. Target pixels that received no source pixel are
        filled with nn_fill_array and a warning is issued.
        """
        # Read the source data once; the property re-reads it on each access.
        source_data = self.source_data
        assert len(source_data.shape) == 2
        assert source_data.shape[1] == 1

        df = pd.DataFrame({'topo': source_data.flatten(),
                           'gid_target': self.nn})
        n_target = np.prod(self.hr_shape[:-1])

        # The KDTree query marks source points without a neighbor inside the
        # distance upper bound with the index n_target; drop those.
        df = df[df['gid_target'] != n_target]
        df = df.sort_values('gid_target')
        df = df.groupby('gid_target').mean()

        all_gids = np.arange(n_target)
        missing = np.setdiff1d(all_gids, df.index.values)

        # Test the count, not the truthiness of the gid values: gid 0 is
        # falsy, so any() would skip the fill when only pixel 0 is missing.
        if missing.size > 0:
            msg = (f'{len(missing)} target pixels did not have unique '
                   'high-resolution source data to map from. If there are a '
                   'lot of target pixels missing source data this probably '
                   'means the source data is not high enough resolution. '
                   'Filling raster with NN.')
            logger.warning(msg)
            warn(msg)
            # Insert NaN rows for the missing target pixels, in gid order.
            df = df.reindex(all_gids)

        hr_data = df['topo'].values.reshape(self.hr_shape[:-1])
        if np.isnan(hr_data).any():
            hr_data = nn_fill_array(hr_data)

        hr_data = np.expand_dims(hr_data, axis=-1)

        logger.info('Finished mapping raster from {}'.format(self._exo_source))

        return hr_data

    def get_cache_file(self, feature, s_enhance, t_enhance, t_agg_factor):
        """Get cache file name. This uses a time independent naming convention.
        Creates cache_dir if cache_data is True.

        Parameters
        ----------
        feature : str
            Name of feature to get cache file for
        s_enhance : int
            Spatial enhancement for this exogeneous data step (cumulative for
            all model steps up to the current step).
        t_enhance : int
            Temporal enhancement for this exogeneous data step (cumulative for
            all model steps up to the current step).
        t_agg_factor : int
            Factor by which to aggregate the exo_source data to the temporal
            resolution of the file_paths input enhanced by t_enhance.

        Returns
        -------
        cache_fp : str
            Name of cache file
        """
        fn = f'exo_{feature}_{self.target}_{self.shape}'
        fn += f'_tagg{t_agg_factor}_{s_enhance}x_'
        fn += f'{t_enhance}x.pkl'
        # Strip tuple / list punctuation so the name is filesystem friendly.
        fn = fn.replace('(', '').replace(')', '')
        fn = fn.replace('[', '').replace(']', '')
        fn = fn.replace(',', 'x').replace(' ', '')
        cache_fp = os.path.join(self.cache_dir, fn)
        if self.cache_data:
            os.makedirs(self.cache_dir, exist_ok=True)
        return cache_fp
[docs]
class TopoExtractNC(TopoExtractH5):
    """TopoExtract for netCDF files"""

    @property
    def source_handler(self):
        """DataHandlerNC instance wrapping the .nc source topography file.
        Built on first access and reused afterwards."""
        if self._source_handler is not None:
            return self._source_handler

        logger.info('Getting topography for full domain from '
                    f'{self._exo_source}')
        handler = DataHandlerNC(
            self._exo_source,
            features=['topography'],
            worker_kwargs={'ti_workers': self.ti_workers},
            val_split=0.0,
        )
        self._source_handler = handler
        return self._source_handler

    @property
    def source_data(self):
        """Elevation data from the exo_source_nc, flattened and returned with
        a trailing axis of length one"""
        # Index 0 on each of the last two axes of the handler data.
        flat_elev = self.source_handler.data[..., 0, 0].flatten()
        return flat_elev[..., np.newaxis]

    @property
    def source_lat_lon(self):
        """The 2D array (n, 2) of lat, lon data from the exo_source_nc"""
        return self.source_handler.lat_lon.reshape((-1, 2))
[docs]
class SzaExtract(ExoExtract):
    """SzaExtract for H5 files"""

    @property
    def source_data(self):
        """Solar zenith angle computed with SolarPosition from the
        high-resolution time index and the flattened high-resolution lat/lon
        grid. The zenith array is returned transposed."""
        time_index = self.hr_time_index
        flat_lat_lon = self.hr_lat_lon.reshape((-1, 2))
        solar_position = SolarPosition(time_index, flat_lat_lon)
        return solar_position.zenith.T

    def get_data(self):
        """Get a raster of source values corresponding to the
        high-resolution grid (the file_paths input grid * s_enhance *
        t_enhance). The shape is (lats, lons, temporal)
        """
        sza_raster = self.source_data.reshape(self.hr_shape)
        logger.info('Finished computing SZA data')
        return sza_raster