# Source code for sup3r.preprocessing.loaders.nc
"""Base loading classes. These are containers which also load data from
file_paths and include some sampling ability to interface with batcher
classes."""
import logging
from functools import cached_property
from warnings import warn
import dask.array as da
import numpy as np
from sup3r.preprocessing.names import COORD_NAMES, DIM_NAMES, Dimension
from sup3r.preprocessing.utilities import lower_names, ordered_dims
from sup3r.utilities.utilities import xr_open_mfdataset
from .base import BaseLoader
# Module-level logger, keyed on this module's dotted path (stdlib convention)
logger = logging.getLogger(__name__)
class LoaderNC(BaseLoader):
    """Base NETCDF loader. "Loads" netcdf files so that a ``.data`` attribute
    provides access to the data in the files. This object provides a
    ``__getitem__`` method that can be used by Sampler objects to build batches
    or by other objects to derive / extract specific features / regions /
    time_periods."""

    def BASE_LOADER(self, file_paths, **kwargs):
        """Lowest level interface to data.

        Parameters
        ----------
        file_paths : str | list
            Path(s) to netcdf file(s), forwarded to ``xr_open_mfdataset``.
        **kwargs : dict
            Keyword arguments forwarded to ``xr_open_mfdataset``.

        Returns
        -------
        xarray.Dataset
        """
        return xr_open_mfdataset(file_paths, **kwargs)

    def _enforce_descending_lats(self, dset):
        """Make sure latitudes are in descending order so that min lat / lon is
        at ``lat_lon[-1, 0]``.

        NOTE: Only applies to gridded (2D) coordinates. Flattened data has a
        single spatial dimension, so there is no south_north axis to flip.
        """
        invert_lats = not self._is_flattened and (
            dset[Dimension.LATITUDE][-1, 0] > dset[Dimension.LATITUDE][0, 0]
        )
        if invert_lats:
            # flip the 2D coords and every data variable that carries the
            # south_north dimension
            for var in [*list(Dimension.coords_2d()), *list(dset.data_vars)]:
                if Dimension.SOUTH_NORTH in dset[var].dims:
                    new_var = dset[var].isel(south_north=slice(None, None, -1))
                    dset.update({var: new_var})
        return dset

    def _enforce_descending_levels(self, dset):
        """Make sure levels are in descending order so that max pressure is at
        ``level[0]``."""
        # datasets without a pressure-level coordinate are left untouched
        invert_levels = (
            dset[Dimension.PRESSURE_LEVEL][-1]
            > dset[Dimension.PRESSURE_LEVEL][0]
            if Dimension.PRESSURE_LEVEL in dset
            else False
        )
        if invert_levels:
            for var in list(dset.data_vars):
                if Dimension.PRESSURE_LEVEL in dset[var].dims:
                    new_var = dset[var].isel(
                        {Dimension.PRESSURE_LEVEL: slice(None, None, -1)}
                    )
                    # rebuild the (dims, data, attrs) tuple so dims and attrs
                    # are retained alongside the reversed data
                    dset.update(
                        {var: (dset[var].dims, new_var.data, dset[var].attrs)}
                    )
            # reverse the coordinate itself to stay consistent with the
            # flipped variables
            new_press = dset[Dimension.PRESSURE_LEVEL][::-1]
            dset.update({Dimension.PRESSURE_LEVEL: new_press})
        return dset

    @cached_property
    def _lat_lon_shape(self):
        """Get shape of lat lon grid only."""
        return self._res[Dimension.LATITUDE].shape

    @cached_property
    def _is_flattened(self):
        """Check if dims include a single spatial dimension (latitude is 1D
        and shares its dimension name with longitude)."""
        check = (
            len(self._lat_lon_shape) == 1
            and self._res[Dimension.LATITUDE].dims
            == self._res[Dimension.LONGITUDE].dims
        )
        return check

    def _get_coords(self, res):
        """Get coordinate dictionary to use in
        ``xr.Dataset().assign_coords()``.

        Parameters
        ----------
        res : xarray.Dataset
            Dataset with (possibly non-standard) lat / lon / time coords.

        Returns
        -------
        dict
            Mapping of standard coordinate names to values.
        """
        lats = res[Dimension.LATITUDE].data.astype(np.float32)
        lons = res[Dimension.LONGITUDE].data.astype(np.float32)
        # remove time dimension if there's a single time step. NOTE:
        # ``squeeze()`` with no args drops *all* singleton dims — assumes
        # time is the only singleton here
        if lats.ndim == 3:
            lats = lats.squeeze()
        if lons.ndim == 3:
            lons = lons.squeeze()
        # expand 1D regular-grid coords into full 2D arrays, unless the data
        # is flattened (single shared spatial dim)
        if len(lats.shape) == 1 and not self._is_flattened:
            lons, lats = da.meshgrid(lons, lats)
        dim_names = self._lat_lon_dims
        coords = {
            Dimension.LATITUDE: (dim_names, lats),
            Dimension.LONGITUDE: (dim_names, lons),
        }
        if Dimension.TIME in res:
            # prefer the decoded index when xarray has built one
            if Dimension.TIME in res.indexes:
                times = res.indexes[Dimension.TIME]
            else:
                times = res[Dimension.TIME]
            # CFTimeIndex (non-standard calendars) -> pandas DatetimeIndex
            if hasattr(times, 'to_datetimeindex'):
                times = times.to_datetimeindex()
            coords[Dimension.TIME] = times
        return coords

    def _get_dims(self, res):
        """Get dimension name map using our standard mappping and the names
        used for coordinate dimensions.

        Returns
        -------
        dict
            Mapping of original dimension names to standard names, suitable
            for ``Dataset.swap_dims``.
        """
        rename_dims = {k: v for k, v in DIM_NAMES.items() if k in res.dims}
        lat_dims = res[Dimension.LATITUDE].dims
        lon_dims = res[Dimension.LONGITUDE].dims
        if len(lat_dims) == 1 and len(lon_dims) == 1:
            # regular grid: map the single lat dim and single lon dim onto
            # (south_north, west_east)
            dim_names = self._lat_lon_dims
            rename_dims[lat_dims[0]] = dim_names[0]
            rename_dims[lon_dims[0]] = dim_names[-1]
        else:
            msg = (
                'Latitude and Longitude dimension names are different. '
                'This is weird.'
            )
            if lon_dims != lat_dims:
                # curvilinear coords should share dims; warn but continue
                logger.warning(msg)
                warn(msg)
            else:
                rename_dims.update(
                    dict(zip(ordered_dims(lat_dims), Dimension.dims_2d()))
                )
        return rename_dims

    def _rechunk_dsets(self, res):
        """Apply given chunk values for each field in res.coords and
        res.data_vars."""
        for dset in [*list(res.coords), *list(res.data_vars)]:
            chunks = self._parse_chunks(dims=res[dset].dims, feature=dset)
            # specifying chunks to xarray.open_mfdataset doesn't automatically
            # apply to coordinates so we do that here
            if chunks != 'auto' or dset in Dimension.coords_2d():
                res[dset] = res[dset].chunk(chunks)
        return res

    def _load(self):
        """Load netcdf ``xarray.Dataset()``.

        Standardizes coordinate / dimension names, enforces descending
        latitudes and pressure levels, rechunks, and casts to float32.

        Raises
        ------
        OSError
            If the files do not contain recognizable lat / lon coordinates.
        """
        res = lower_names(self._res)
        # rename non-standard coord names, but never clobber an existing
        # standard name
        rename_coords = {
            k: v for k, v in COORD_NAMES.items() if k in res and v not in res
        }
        self._res = res.rename(rename_coords)
        if not all(coord in self._res for coord in Dimension.coords_2d()):
            err = 'Could not find valid coordinates in given files: %s'
            logger.error(err, self.file_paths)
            raise OSError(err % (self.file_paths))
        res = self._res.swap_dims(self._get_dims(self._res))
        res = res.assign_coords(self._get_coords(res))
        res = self._enforce_descending_lats(res)
        res = self._rechunk_dsets(res)
        return self._enforce_descending_levels(res).astype(np.float32)