"""Code for regridding data from one list of coordinates to another"""
import logging
import pickle
import pprint
from functools import cached_property
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, asdict
from datetime import datetime as dt
from typing import Optional
import numpy as np
import pandas as pd
import psutil
from sklearn.neighbors import BallTree
logger = logging.getLogger(__name__)
class _InterpolationMixin:
    """Inverse-weighted distance interpolation logic.

    This mixin class is only intended to be used with classes that
    have the following attributes:

        - self.distances: ndarray of distances from ball tree query
        - self.indices: ndarray of indices from ball tree query
        - self.min_distance: float representing the minimum distance to
                             use for inverse-weighted distances
                             calculation to avoid dividing by 0

    No other attribute of the host class is required; in particular the
    number of neighbors is read off the shape of ``self.indices``.
    """

    @cached_property
    def weights(self):
        """ndarray: Weights used for regridding. """
        return _compute_weights(self.distances, self.min_distance)

    def __call__(self, data):
        """Regrid given spatiotemporal data over entire grid.

        Parameters
        ----------
        data : :obj:`numpy.ndarray` | :obj:`dask.core.array.Array`
            Spatiotemporal data to regrid to target_meta. Data can be
            flattened in the spatial dimension to match the
            `target_meta` or be in a 2D spatial grid, e.g.:
            (spatial, temporal) or (spatial_1, spatial_2, temporal)

        Returns
        -------
        out : :obj:`numpy.ndarray` | :obj:`dask.core.array.Array`
            Flattened regridded spatiotemporal data
            (spatial, temporal)
        """
        if len(data.shape) == 3:
            # Collapse the two spatial axes into one flat spatial axis.
            data = data.reshape((data.shape[0] * data.shape[1], -1))

        msg = "Input data must be 2D (spatial, temporal)"
        assert len(data.shape) == 2, msg

        # ``self.indices`` may be a list of per-point arrays or a 2D
        # array; normalize to a single (n_points, n_neighbors) array.
        neighbor_inds = np.asarray(self.indices)

        if hasattr(data, "compute"):  # data is Dask array
            # Index with a flat 1D array and restore the neighbor axis
            # afterwards. The neighbor count comes from the indices
            # themselves so host classes need not define `k_neighbors`.
            n_points, n_neighbors = neighbor_inds.shape
            shape = (n_points, n_neighbors, data.shape[-1])
            vals = data[neighbor_inds.ravel()].reshape(shape)
        else:
            vals = data[neighbor_inds]

        # (n_points, n_neighbors, temporal) -> (temporal, n_points,
        # n_neighbors) so the einsum can contract the neighbor axis
        # against the per-point weights.
        vals = np.transpose(vals, (2, 0, 1))
        return np.einsum('ijk,jk->ij', vals, self.weights).T
# pylint: disable=attribute-defined-outside-init
@dataclass
class Regridder(_InterpolationMixin):
    """Interpolate from one grid to another using inverse weighted distances.

    This class builds ball tree and runs all queries to create full
    arrays of indices and distances for neighbor points. It computes
    an array of weights used to interpolate from the old grid to the new
    grid.

    Parameters
    ----------
    source_meta : :class:`pandas.DataFrame`
        Set of coordinates for source grid. Must contain "latitude"
        and "longitude" columns representing the coordinates
        (in degrees).
    target_meta : :class:`pandas.DataFrame`
        Set of coordinates for target grid. Must contain "latitude"
        and "longitude" columns representing the coordinates
        (in degrees).
    k_neighbors : int, optional
        Number of nearest neighbors to use for interpolation.
        By default, ``4``.
    n_chunks : int
        Number of spatial chunks to use for tree queries. The total
        number of points in the `target_meta` will be split into
        `n_chunks`, and the points in each chunk will be queried at the
        same time. By default, ``100``.
    max_workers : int, optional
        Max number of workers to use for running all tree queries needed
        to build the full set of indices and distances for each
        `target_meta` coordinate. By default, ``None``, which uses all
        available CPU cores.
    min_distance : float, optional
        Minimum distance to use for inverse-weighted distances
        calculation to avoid dividing by 0. By default, ``1e-12``.
    leaf_size : int, optional
        Leaf size for :class:`~sklearn.neighbors.BallTree` instance.
        By default, ``4``.
    """

    source_meta: pd.DataFrame
    target_meta: pd.DataFrame
    k_neighbors: Optional[int] = 4
    n_chunks: Optional[int] = 100
    max_workers: Optional[int] = None
    min_distance: Optional[float] = 1e-12
    leaf_size: Optional[int] = 4

    def __post_init__(self):
        self._tree = None
        self._distances = None
        self._indices = None
        self._weights = None
        # `asdict` deep-copies the meta DataFrames and `pformat` renders
        # them, so only pay for that when the message will be emitted.
        if logger.isEnabledFor(logging.INFO):
            fields = pprint.pformat(asdict(self), indent=2)
            logger.info("Initialized `Regridder` with:\n%s", fields)

    @property
    def distances(self):
        """Get distances for all tree queries."""
        if self._distances is None:
            self.init_queries()
        return self._distances

    @property
    def indices(self):
        """Get indices for all tree queries."""
        if self._indices is None:
            self.init_queries()
        return self._indices

    def _allocate_query_buffers(self):
        """Create empty per-target-point lists for indices and distances"""
        n_points = len(self.target_meta)
        self._indices = [None] * n_points
        self._distances = [None] * n_points

    def init_queries(self):
        """Initialize arrays for tree queries and perform all queries"""
        self._allocate_query_buffers()
        self.get_all_queries(self.max_workers)

    @property
    def tree(self):
        """Build ball tree from source_meta"""
        if self._tree is None:
            logger.info("Building ball tree for regridding.")
            ll2 = self.source_meta[["latitude", "longitude"]].values
            # The haversine metric is fed coordinates in radians.
            ll2 = np.radians(ll2)
            self._tree = BallTree(ll2, leaf_size=self.leaf_size,
                                  metric="haversine")
        return self._tree

    def get_all_queries(self, max_workers=None):
        """Query ball tree for all coordinates in the target_meta and store
        results

        Parameters
        ----------
        max_workers : int, optional
            Number of workers used to run the queries. ``1`` runs a
            single query over all points in the calling thread; any
            other value dispatches the chunked queries to a thread pool.
            By default, ``None``.
        """
        # Make sure the result buffers exist *before* any query runs.
        # Otherwise the first `self.distances` access inside `save_query`
        # lazily calls `init_queries`, which re-runs every query (and, in
        # threaded mode, lets workers reset the shared buffers).
        if self._distances is None or self._indices is None:
            self._allocate_query_buffers()

        # Build the tree once up front so worker threads do not each
        # find `_tree is None` and construct their own copy.
        _ = self.tree

        if max_workers == 1:
            logger.info("Querying all coordinates in serial.")
            self.save_query(slice(None))
        else:
            logger.info("Querying all coordinates in parallel.")
            self._parallel_queries(max_workers=max_workers)
        logger.info("Finished querying all coordinates.")

    def _parallel_queries(self, max_workers=None):
        """Get indices and distances for all points in target_meta, in
        parallel

        Parameters
        ----------
        max_workers : int, optional
            Passed to :class:`~concurrent.futures.ThreadPoolExecutor`.
            By default, ``None``.

        Raises
        ------
        RuntimeError
            If any chunk query fails; the original exception is chained.
        """
        futures = {}
        now = dt.now()

        # Split the target point positions into at most `n_chunks`
        # contiguous runs and express each run as a slice.
        slices = np.arange(len(self.target_meta))
        slices = np.array_split(slices, min(self.n_chunks, len(slices)))
        slices = [slice(s[0], s[-1] + 1) for s in slices]

        with ThreadPoolExecutor(max_workers=max_workers) as exe:
            for i, s_slice in enumerate(slices):
                future = exe.submit(self.save_query, s_slice=s_slice)
                futures[future] = i
                mem = psutil.virtual_memory()
                msg = ("Query futures submitted: {} out of {}. Current "
                       "memory usage is {:.3f} GB out of {:.3f} GB total."
                       .format(i + 1, len(slices), mem.used / 1e9,
                               mem.total / 1e9))
                logger.info(msg)

            logger.info("Submitted all query futures in %s.",
                        dt.now() - now)

            for i, future in enumerate(as_completed(futures)):
                idx = futures[future]
                mem = psutil.virtual_memory()
                msg = ("Query futures completed: {} out of {}. Current memory "
                       "usage is {:.3f} GB out of {:.3f} GB total."
                       .format(i + 1, len(futures), mem.used / 1e9,
                               mem.total / 1e9))
                logger.info(msg)
                try:
                    future.result()
                except Exception as e:
                    msg = ("Failed to query coordinate chunk with index={}"
                           .format(idx))
                    logger.exception(msg)
                    raise RuntimeError(msg) from e

    def save_query(self, s_slice):
        """Save tree query for coordinates specified by given spatial slice

        Parameters
        ----------
        s_slice : slice
            Slice specifying which spatial indices in the target grid
            should be queried and stored.
        """
        out = self.tree.query(self.get_spatial_chunk(s_slice),
                              k=self.k_neighbors)
        # out[0] holds the distances and out[1] the source indices; each
        # row is stored at the matching position of the result lists.
        self.distances[s_slice] = out[0]
        self.indices[s_slice] = out[1]

    def get_spatial_chunk(self, s_slice):
        """Get list of coordinates in target_meta specified by the given
        spatial slice

        Parameters
        ----------
        s_slice : slice
            Slice specifying which spatial indices in the target grid
            should be selected. This selects `n_points` from the target
            grid.

        Returns
        -------
        ndarray
            Array of `n_points` in `target_meta` selected by `s_slice`,
            as (latitude, longitude) pairs converted to radians.
        """
        out = self.target_meta.iloc[s_slice][["latitude", "longitude"]].values
        return np.radians(out)

    @classmethod
    def run(cls, source_meta, target_meta, source_data, k_neighbors=4,
            n_chunks=100, max_workers=None, min_distance=1e-12,
            leaf_size=4):
        """Regrid data using inverse distance weighting.

        Parameters
        ----------
        source_meta : :class:`pandas.DataFrame`
            Set of coordinates for source grid. Must contain "latitude"
            and "longitude" columns representing the coordinates
            (in degrees).
        target_meta : :class:`pandas.DataFrame`
            Set of coordinates for target grid. Must contain "latitude"
            and "longitude" columns representing the coordinates
            (in degrees).
        source_data : ndarray
            Spatiotemporal data to regrid to `target_meta` coordinate
            grid. Data can be flattened in the spatial dimension to
            match the `target_meta` or be in a 2D spatial grid, e.g.:
            (spatial, temporal) or (spatial_1, spatial_2, temporal)
        leaf_size : int, optional
            Leaf size for :class:`~sklearn.neighbors.BallTree` instance.
            By default, ``4``.
        k_neighbors : int, optional
            Number of nearest neighbors to use for interpolation.
            By default, ``4``.
        n_chunks : int
            Number of spatial chunks to use for tree queries. The total
            number of points in the `target_meta` will be split into
            `n_chunks`, and the points in each chunk will be queried at
            the same time. By default, ``100``.
        max_workers : int, optional
            Max number of workers to use for running all tree queries
            needed to build the full set of indices and distances for
            each `target_meta` coordinate. By default, ``None``, which
            uses all available CPU cores.
        min_distance : float, optional
            Minimum distance to use for inverse-weighted distances
            calculation to avoid dividing by 0. By default, ``1e-12``.

        Returns
        -------
        out : ndarray
            Flattened regridded spatiotemporal data
            (spatial, temporal)
        """
        regridder = cls(source_meta=source_meta, target_meta=target_meta,
                        leaf_size=leaf_size, k_neighbors=k_neighbors,
                        n_chunks=n_chunks, max_workers=max_workers,
                        min_distance=min_distance)
        # `get_all_queries` allocates the result buffers itself, so the
        # queries run exactly once here.
        regridder.get_all_queries(max_workers)
        return regridder(source_data)
class CachedRegridder(_InterpolationMixin):
    """Interpolate from one grid to another using cached dists and inds."""

    def __init__(self, cache_pattern, min_distance=1e-12):
        """
        Parameters
        ----------
        cache_pattern : str
            Filepath pattern for cached distances and indices to load.
            Should be of the form ``'./{array_name}.pkl'`` where
            `array_name` will internally be replaced with either
            ``'distances'`` or ``'indices'``.
        min_distance : float, optional
            Minimum distance to use for inverse-weighted distances
            calculation to avoid dividing by 0. By default, ``1e-12``.
        """
        self.distances, self.indices = self.load_cache(cache_pattern)
        self.min_distance = min_distance

    @property
    def k_neighbors(self):
        """int: Number of neighbors stored per target point.

        Inferred from the second dimension of the cached indices. The
        inherited interpolation call reads this attribute when it is
        given Dask input.
        """
        return np.shape(self.indices)[1]

    @staticmethod
    def load_cache(cache_pattern):
        """Load cached indices and distances from ball tree query.

        Parameters
        ----------
        cache_pattern : str
            Filepath pattern for cached distances and indices to load.
            Should be of the form ``'./{array_name}.pkl'`` where
            `array_name` will internally be replaced with either
            ``'distances'`` or ``'indices'``.

        Returns
        -------
        distances, indices : ndarray
            Arrays of distances and indices output by the ball tree.
        """
        distance_file = cache_pattern.format(array_name='distances')
        index_file = cache_pattern.format(array_name='indices')

        # NOTE: unpickling can execute arbitrary code. Only point
        # `cache_pattern` at cache files from a trusted source.
        with open(distance_file, 'rb') as f:
            distances = pickle.load(f)
        with open(index_file, 'rb') as f:
            indices = pickle.load(f)

        logger.info('Loaded cache files: %s, %s', distance_file, index_file)
        return distances, indices

    @classmethod
    def build_cache(cls, cache_pattern, *args, **kwargs):
        """Cache distances and indices from ball tree query.

        Parameters
        ----------
        cache_pattern : str
            Filepath pattern used to cache distances and indices.
            Should be of the form ``'./{array_name}.pkl'`` where
            `array_name` will internally be replaced with either
            ``'distances'`` or ``'indices'``.
        *args, **kwargs
            Arguments followed by keyword arguments that can be used to
            initialize :class:`Regridder`. The ``Regridder`` instance
            will generate the distance and index arrays to be cached.
        """
        distance_file = cache_pattern.format(array_name='distances')
        index_file = cache_pattern.format(array_name='indices')

        regridder = Regridder(*args, **kwargs)
        # The queries run lazily on first access. Trigger them before
        # opening any output file so that a failed query cannot leave an
        # empty or truncated cache file behind.
        distances = regridder.distances
        indices = regridder.indices

        with open(distance_file, 'wb') as f:
            pickle.dump(distances, f, protocol=4)
        with open(index_file, 'wb') as f:
            pickle.dump(indices, f, protocol=4)

        logger.info('Saved cache files: %s, %s', distance_file, index_file)
def _compute_weights(distances, min_distance):
    """Compute normalized inverse-distance weights from distance values.

    Parameters
    ----------
    distances : array-like
        2D collection of neighbor distances with one row per target
        point and one column per neighbor.
    min_distance : float
        Distances below this value are clamped to it before inversion
        to avoid dividing by 0.

    Returns
    -------
    ndarray
        ``float32`` array with the same shape as `distances` in which
        every row sums to 1. A row that had any distance below
        `min_distance` puts all of its weight on the first neighbor
        column.
    """
    dists = np.array(distances, dtype=np.float32)
    too_close = dists < min_distance
    # Clamp so the reciprocal below never divides by zero.
    dists[too_close] = min_distance

    n_close = np.count_nonzero(too_close)
    if n_close > 0:
        # %g keeps tiny thresholds readable (e.g. 1e-12, not 0.00000).
        logger.info("%d of %d neighbor distances are within %g.",
                    n_close, too_close.size, min_distance)

    weights = 1 / dists
    # A (near-)coincident neighbor would swamp the other weights, so
    # such rows are replaced by [1, 0, ..., 0]: the value of the first
    # neighbor is used as-is.
    # NOTE(review): presumably column 0 is the nearest neighbor because
    # the tree returns neighbors sorted by distance - confirm.
    weights[too_close.any(axis=1), :] = np.eye(1, dists.shape[1]).flatten()
    return weights / np.sum(weights, axis=-1)[:, None]