Source code for sup3r.preprocessing.loaders.nc
"""Base loading classes. These are containers which also load data from
file_paths and include some sampling ability to interface with batcher
classes."""
import logging
from functools import cached_property
from warnings import warn
import dask.array as da
import numpy as np
from sup3r.preprocessing.names import COORD_NAMES, DIM_NAMES, Dimension
from sup3r.preprocessing.utilities import lower_names, ordered_dims
from sup3r.utilities.utilities import xr_open_mfdataset
from .base import BaseLoader
logger = logging.getLogger(__name__)
class LoaderNC(BaseLoader):
"""Base NETCDF loader. "Loads" netcdf files so that a ``.data`` attribute
provides access to the data in the files. This object provides a
``__getitem__`` method that can be used by Sampler objects to build batches
or by other objects to derive / extract specific features / regions /
time_periods."""
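    # Illustrative usage (the file path is hypothetical and the constructor
    # signature is inherited from ``BaseLoader``):
    #
    #     loader = LoaderNC('/data/era5_2020.nc')
    #     lat_lon_shape = loader.data[Dimension.LATITUDE].shape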
def BASE_LOADER(self, file_paths, **kwargs):
"""Lowest level interface to data."""
return xr_open_mfdataset(file_paths, **kwargs)
def _enforce_descending_lats(self, dset):
"""Make sure latitudes are in descending order so that min lat / lon is
at ``lat_lon[-1, 0]``."""
invert_lats = not self._is_flattened and (
dset[Dimension.LATITUDE][-1, 0] > dset[Dimension.LATITUDE][0, 0]
)
if invert_lats:
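            # reverse the rows of every 2D coordinate and any data variable
            # carrying a south_north dim, so row 0 is the northernmost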
for var in [*list(Dimension.coords_2d()), *list(dset.data_vars)]:
if Dimension.SOUTH_NORTH in dset[var].dims:
                    new_var = dset[var].isel(
                        {Dimension.SOUTH_NORTH: slice(None, None, -1)}
                    )
dset.update({var: new_var})
return dset
def _enforce_descending_levels(self, dset):
"""Make sure levels are in descending order so that max pressure is at
``level[0]``."""
        invert_levels = (
            Dimension.PRESSURE_LEVEL in dset
            and dset[Dimension.PRESSURE_LEVEL][-1]
            > dset[Dimension.PRESSURE_LEVEL][0]
        )
if invert_levels:
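            # flip each variable along the pressure-level axis, then flip the
            # level coordinate itself so the maximum pressure sits at level[0]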
for var in list(dset.data_vars):
if Dimension.PRESSURE_LEVEL in dset[var].dims:
new_var = dset[var].isel(
{Dimension.PRESSURE_LEVEL: slice(None, None, -1)}
)
dset.update(
{var: (dset[var].dims, new_var.data, dset[var].attrs)}
)
new_press = dset[Dimension.PRESSURE_LEVEL][::-1]
dset.update({Dimension.PRESSURE_LEVEL: new_press})
return dset
@cached_property
def _lat_lon_shape(self):
"""Get shape of lat lon grid only."""
return self._res[Dimension.LATITUDE].shape
@cached_property
def _is_flattened(self):
"""Check if dims include a single spatial dimension."""
check = (
len(self._lat_lon_shape) == 1
and self._res[Dimension.LATITUDE].dims
== self._res[Dimension.LONGITUDE].dims
)
return check
def _get_coords(self, res):
"""Get coordinate dictionary to use in
``xr.Dataset().assign_coords()``."""
lats = res[Dimension.LATITUDE].data.astype(np.float32)
lons = res[Dimension.LONGITUDE].data.astype(np.float32)
# remove time dimension if there's a single time step
if lats.ndim == 3:
lats = lats.squeeze()
if lons.ndim == 3:
lons = lons.squeeze()
if len(lats.shape) == 1 and not self._is_flattened:
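            # regular 1D lat / lon vectors: broadcast to 2D arrays with shape
            # (south_north, west_east) via meshgrid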
lons, lats = da.meshgrid(lons, lats)
dim_names = self._lat_lon_dims
coords = {
Dimension.LATITUDE: (dim_names, lats),
Dimension.LONGITUDE: (dim_names, lons),
}
if Dimension.TIME in res:
if Dimension.TIME in res.indexes:
times = res.indexes[Dimension.TIME]
else:
times = res[Dimension.TIME]
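            # cftime-backed time coordinates (non-standard calendars) expose
            # ``to_datetimeindex``; convert so downstream code sees a pandas
            # DatetimeIndex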
if hasattr(times, 'to_datetimeindex'):
times = times.to_datetimeindex()
coords[Dimension.TIME] = times
return coords
def _get_dims(self, res):
"""Get dimension name map using our standard mappping and the names
used for coordinate dimensions."""
rename_dims = {k: v for k, v in DIM_NAMES.items() if k in res.dims}
lat_dims = res[Dimension.LATITUDE].dims
lon_dims = res[Dimension.LONGITUDE].dims
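        # 1D lat / lon dims can be renamed directly to (south_north,
        # west_east); 2D coordinate arrays must already share dims before we
        # can map them onto the standard names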
if len(lat_dims) == 1 and len(lon_dims) == 1:
dim_names = self._lat_lon_dims
rename_dims[lat_dims[0]] = dim_names[0]
rename_dims[lon_dims[0]] = dim_names[-1]
else:
msg = (
'Latitude and Longitude dimension names are different. '
'This is weird.'
)
if lon_dims != lat_dims:
logger.warning(msg)
warn(msg)
else:
rename_dims.update(
dict(zip(ordered_dims(lat_dims), Dimension.dims_2d()))
)
return rename_dims
def _rechunk_dsets(self, res):
"""Apply given chunk values for each field in res.coords and
res.data_vars."""
for dset in [*list(res.coords), *list(res.data_vars)]:
chunks = self._parse_chunks(dims=res[dset].dims, feature=dset)
# specifying chunks to xarray.open_mfdataset doesn't automatically
# apply to coordinates so we do that here
if chunks != 'auto' or dset in Dimension.coords_2d():
res[dset] = res[dset].chunk(chunks)
return res
def _load(self):
"""Load netcdf ``xarray.Dataset()``."""
res = lower_names(self._res)
rename_coords = {
k: v for k, v in COORD_NAMES.items() if k in res and v not in res
}
self._res = res.rename(rename_coords)
if not all(coord in self._res for coord in Dimension.coords_2d()):
err = 'Could not find valid coordinates in given files: %s'
logger.error(err, self.file_paths)
raise OSError(err % (self.file_paths))
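        # standardize dim names, rebuild coordinates, enforce descending
        # lat / pressure-level ordering, apply chunking, and cast to float32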
res = self._res.swap_dims(self._get_dims(self._res))
res = res.assign_coords(self._get_coords(res))
res = self._enforce_descending_lats(res)
res = self._rechunk_dsets(res)
return self._enforce_descending_levels(res).astype(np.float32)