Source code for revrt.costs.layer_creator

"""Build friction or barrier layers from raster and vector data"""

import logging
from pathlib import Path
from warnings import warn

import numpy as np
import dask.array as da

from revrt.costs.base import BaseLayerCreator
from revrt.utilities import (
    file_full_path,
    load_data_using_layer_file_profile,
    save_data_using_layer_file_profile,
    log_mem,
)
from revrt.utilities.raster import rasterize_shape_file
from revrt.constants import DEFAULT_DTYPE, ALL, METERS_IN_MILE
from revrt.exceptions import revrtAttributeError, revrtValueError
from revrt.warn import revrtWarning

logger = logging.getLogger(__name__)
TIFF_EXTENSIONS = {".tif", ".tiff"}
SHP_EXTENSIONS = {".shp", ".gpkg"}


class LayerCreator(BaseLayerCreator):
    """Build layer based on tiff and user config"""

    def __init__(
        self,
        io_handler,
        masks,
        input_layer_dir=".",
        output_tiff_dir=".",
        dtype=DEFAULT_DTYPE,
    ):
        """
        Parameters
        ----------
        io_handler : :class:`LayeredFile`
            Layered file IO handler.
        masks : Masks
            Masks instance that can be used to retrieve multiple types
            of masks.
        input_layer_dir : path-like, optional
            Directory to search for input layers in, if not found in
            current directory. By default, ``'.'``.
        output_tiff_dir : path-like, optional
            Directory where cost layers should be saved as GeoTIFF.
            By default, ``"."``.
        dtype : np.dtype, optional
            Data type for final dataset. By default, ``float32``.
        """
        self._masks = masks
        super().__init__(
            io_handler=io_handler,
            input_layer_dir=input_layer_dir,
            output_tiff_dir=output_tiff_dir,
            dtype=dtype,
        )
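    # A minimal construction sketch, assuming a ``LayeredFile`` handler and a
    # ``Masks`` instance have already been created elsewhere (variable and
    # directory names below are hypothetical):
    #
    #     creator = LayerCreator(
    #         io_handler=layered_file,
    #         masks=masks,
    #         input_layer_dir="input_layers",
    #         output_tiff_dir="output_tiffs",
    #     )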
    def build(
        self,
        layer_name,
        build_config,
        values_are_costs_per_mile=False,
        write_to_file=True,
        description=None,
        tiff_chunks="file",
        nodata=None,
        lock=None,
        **profile_kwargs,
    ):
        """Combine multiple GeoTIFFs and vectors into a raster layer

        Parameters
        ----------
        layer_name : str
            Name of layer to use in H5 and for output tiff.
        build_config : LayerBuildComponents
            Dict of LayerBuildConfig keyed by GeoTIFF/vector filenames.
        values_are_costs_per_mile : bool, default=False
            Option to convert values into costs per cell under the
            assumption that the resulting values are costs in $/mile.
            By default, ``False``, which writes raw values to TIFF/H5.
        write_to_file : bool, default=True
            Option to write the layer to file after creation.

            .. IMPORTANT:: This will overwrite existing layers with the
               same name already in the file.

            By default, ``True``.
        description : str, optional
            Optional description to store with this layer in the H5
            file. By default, ``None``.
        tiff_chunks : int | str, default="file"
            Chunk size to use when reading the GeoTIFF file. This will
            be passed down as the ``chunks`` argument to
            :meth:`rioxarray.open_rasterio`. By default, ``"file"``.
        nodata : int | float, optional
            Optional nodata value for output rasters. This value will
            be added to the layer's attributes meta dictionary under
            the "nodata" key.
        lock : bool | `dask.distributed.Lock`, optional
            Lock to use to write data to GeoTIFF using dask. If not
            supplied, a single process is used for writing data to
            disk. By default, ``None``.
        **profile_kwargs
            Additional keyword arguments to pass into writing output
            rasters. The following attributes are ignored (they are set
            using properties of the :class:`LayeredFile`):

                - nodata
                - transform
                - crs
                - count
                - width
                - height

        """
        tiff_filename = self._process_and_write_as_tiff(
            layer_name=layer_name,
            build_config=build_config,
            values_are_costs_per_mile=values_are_costs_per_mile,
            tiff_chunks=tiff_chunks,
            nodata=nodata,
            lock=lock,
            **profile_kwargs,
        )

        if write_to_file:
            out = load_data_using_layer_file_profile(
                layer_fp=self._io_handler.fp,
                geotiff=tiff_filename,
                tiff_chunks=tiff_chunks,
                layer_dirs=[self.input_layer_dir, self.output_tiff_dir],
                band_index=0,
            )
            log_mem()
            logger.debug("Writing %r to '%s'", layer_name, self._io_handler.fp)
            self._io_handler.write_layer(
                out, layer_name, description=description, overwrite=True
            )
            log_mem()
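    # Continuing the sketch above, a hedged example of a ``build`` call. The
    # layer and file names are hypothetical, and the ``LayerBuildConfig``
    # values are assumed to have been constructed per the ``build_config``
    # documentation:
    #
    #     creator.build(
    #         "transmission_friction",
    #         build_config={
    #             "slope.tif": slope_config,        # raster input
    #             "wetlands.gpkg": wetlands_config, # vector input, rasterized
    #         },
    #         values_are_costs_per_mile=True,
    #         description="Example friction layer",
    #     )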
    def _process_and_write_as_tiff(
        self,
        layer_name,
        build_config,
        values_are_costs_per_mile=False,
        tiff_chunks="file",
        nodata=None,
        lock=None,
        **profile_kwargs,
    ):
        layer_name = layer_name.replace(".tiff", "").replace(".tif", "")
        logger.debug("Combining %s layers", layer_name)
        log_mem()
        result = da.zeros(self.shape, dtype=self._dtype, chunks=self.chunks)
        fi_layers = {}
        logger.debug("Initialized zeros")
        log_mem()
        for fname, config in build_config.items():
            if config.forced_inclusion:
                fi_layers[fname] = config
                continue

            logger.debug("Processing %s with config %s", fname, config)
            if Path(fname).suffix.lower() in TIFF_EXTENSIONS:
                temp = self._process_raster_layer(
                    fname, config, tiff_chunks=tiff_chunks
                )
                result += temp
            elif Path(fname).suffix.lower() in SHP_EXTENSIONS:
                temp = self._process_vector_layer(fname, config)
                result += temp
            else:
                msg = f"Unsupported file extension on {fname!r}"
                raise revrtValueError(msg)
            log_mem()

        result = self._process_forced_inclusions(
            result, fi_layers, tiff_chunks=tiff_chunks
        )
        logger.debug("After forced inclusions")
        log_mem()

        if values_are_costs_per_mile:
            result = result / METERS_IN_MILE * self.cell_size
        log_mem()

        result = result.astype(self._dtype)
        out_filename = self.output_tiff_dir / f"{layer_name}.tif"
        logger.debug(
            "Writing combined %s layers to %s", layer_name, out_filename
        )
        log_mem()
        save_data_using_layer_file_profile(
            layer_fp=self._io_handler.fp,
            data=result,
            geotiff=out_filename,
            nodata=nodata,
            lock=lock,
            **profile_kwargs,
        )
        return out_filename

    def _process_raster_layer(self, fname, config, tiff_chunks="file"):
        """Create the desired layer from the input file"""
        _check_tiff_layer_config(config, fname)
        data = load_data_using_layer_file_profile(
            layer_fp=self._io_handler.fp,
            geotiff=fname,
            tiff_chunks=tiff_chunks,
            layer_dirs=[self.input_layer_dir, self.output_tiff_dir],
            band_index=0,
        )
        return self._process_raster_data(data, config)

    def _process_raster_data(self, data, config):
        """Create the desired layer from the data array"""
        if config.global_value is not None:
            return self._process_global_raster_value(config)
        if config.bins is not None:
            return self._process_raster_bins(config, data)
        if config.pass_through:
            return self._pass_through_raster(config, data)
        return self._process_raster_map(config, data)

    def _process_global_raster_value(self, config):
        """Create the desired layer from the global value"""
        temp = da.full(
            self.shape,
            fill_value=config.global_value,
            dtype=self._dtype,
            chunks=self.chunks,
        )
        return self._apply_mask(config, temp)

    def _process_raster_bins(self, config, data):
        """Create the desired layer from the input file using bins"""
        _validate_bin_range(config.bins)
        _validate_bin_continuity(config.bins)

        processed = da.zeros(self.shape, dtype=self._dtype, chunks=self.chunks)
        if config.extent != ALL:
            mask = self._get_mask(config.extent)

        for i, interval in enumerate(config.bins):
            logger.debug(
                "Calculating layer values for bin %d/%d: %r",
                i + 1,
                len(config.bins),
                interval,
            )
            temp = da.where(
                np.logical_and(data >= interval.min, data < interval.max),
                interval.value,
                0,
            )
            if config.extent == ALL:
                processed += temp
                continue
            processed = da.where(mask, processed + temp, processed)

        return processed

    def _pass_through_raster(self, config, data):
        """Process raster by passing it through without modification"""
        return self._apply_mask(config, data)

    def _process_raster_map(self, config, data):
        """Create the desired layer from the input file using a map"""
        temp = da.zeros(self.shape, dtype=self._dtype, chunks=self.chunks)
        for key, val in config.map.items():
            temp = da.where(data == key, val, temp)
        return self._apply_mask(config, temp)

    def _process_vector_layer(self, fname, config):
        """Rasterize a vector layer"""
        if config.rasterize is None:
            msg = (
                f"{fname!r} is a vector but the config is missing "
                f'key "rasterize": {config}'
            )
            raise revrtValueError(msg)

        kwargs = {
            k: v for k, v in self._io_handler.profile.items() if k != "crs"
        }
        if config.rasterize.reproject:
            kwargs["dest_crs"] = self._io_handler.profile["crs"]

        fname = file_full_path(fname, self.input_layer_dir)
        temp = rasterize_shape_file(
            fname,
            buffer_dist=config.rasterize.buffer,
            burn_value=config.rasterize.value,
            all_touched=config.rasterize.all_touched,
            dtype=self._dtype,
            **kwargs,
        )
        return self._apply_mask(config, temp)

    def _apply_mask(self, config, data):
        """Apply the mask to the data based on the config extent"""
        if config.extent == ALL:
            return data
        mask = self._get_mask(config.extent)
        return da.where(mask, data, 0)

    def _process_forced_inclusions(self, data, fi_layers, tiff_chunks="file"):
        """Use forced inclusion (FI) layers to remove barriers/friction

        Any value > 0 in the FI layers will result in a 0 in the
        corresponding cell in the returned raster.
        """
        fi = da.zeros(self.shape, dtype=self._dtype, chunks=self.chunks)

        for fname, config in fi_layers.items():
            if Path(fname).suffix.lower() not in TIFF_EXTENSIONS:
                msg = (
                    f"Forced inclusion file {fname!r} does not end with "
                    ".tif. GeoTIFFs are the only format allowed for forced "
                    "inclusions."
                )
                raise revrtValueError(msg)

            global_value_given = config.global_value is not None
            map_given = config.map is not None
            range_given = config.bins is not None
            rasterize_given = config.rasterize is not None
            bad_input_given = (
                global_value_given
                or map_given
                or range_given
                or rasterize_given
            )
            if bad_input_given:
                msg = (
                    "`global_value`, `map`, `bins`, and `rasterize` are "
                    "not allowed if `forced_inclusion` is True, but one "
                    f"was found in config: {fname!r}: {config}"
                )
                raise revrtValueError(msg)

            # Past guard clauses, process FI
            if config.extent != ALL:
                mask = self._get_mask(config.extent)

            temp = load_data_using_layer_file_profile(
                layer_fp=self._io_handler.fp,
                geotiff=fname,
                tiff_chunks=tiff_chunks,
                layer_dirs=[self.input_layer_dir, self.output_tiff_dir],
                band_index=0,
            )
            if config.extent == ALL:
                fi += temp
            else:
                fi = da.where(mask, fi + temp, fi)

        return da.where(fi > 0, 0, data)

    def _get_mask(self, extent):
        """Get mask by requested extent"""
        if extent == ALL:
            msg = f"Mask for extent of {extent!r} is unnecessary"
            raise revrtAttributeError(msg)

        if extent == "wet":
            mask = self._masks.wet_mask
        elif extent == "wet+":
            mask = self._masks.wet_plus_mask
        elif extent == "dry":
            mask = self._masks.dry_mask
        elif extent == "dry+":
            mask = self._masks.dry_plus_mask
        elif extent == "landfall":
            mask = self._masks.landfall_mask
        else:
            msg = f"Unknown mask type: {extent!r}"
            raise revrtAttributeError(msg)

        return mask
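# A minimal sketch of the forced-inclusion behaviour implemented by
# ``_process_forced_inclusions``: any cell that is > 0 in the combined
# forced-inclusion raster zeroes out the corresponding cell in the output.
# The arrays below are illustrative only:
#
#     data = da.from_array(np.array([[4.0, 9.0], [2.0, 7.0]]), chunks=2)
#     fi = da.from_array(np.array([[0.0, 1.0], [0.0, 0.0]]), chunks=2)
#     da.where(fi > 0, 0, data).compute()  # -> [[4., 0.], [2., 7.]]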
def _check_tiff_layer_config(config, fname):
    """Check if a LayerBuildConfig is valid for a GeoTIFF"""
    if config.rasterize is not None:
        msg = (
            f"'rasterize' is only for vectors. Found in {fname!r} config: "
            f"{config}"
        )
        raise revrtValueError(msg)

    mutex_entries = [config.map, config.bins, config.global_value]
    num_entries = sum(entry is not None for entry in mutex_entries)
    num_entries += int(config.pass_through)
    if num_entries > 1:
        msg = (
            "Keys 'global_value', 'map', 'bins', and 'pass_through' are "
            "mutually exclusive but more than one was found in "
            f"{fname!r} raster config: {config}"
        )
        raise revrtValueError(msg)
    if num_entries < 1:
        msg = (
            "One of 'global_value', 'map', 'bins', or 'pass_through' "
            "must be specified for a raster, but none were found in "
            f"{fname!r} config: {config}"
        )
        raise revrtValueError(msg)


def _validate_bin_range(bins):
    """Check for correctness in bin range"""
    for input_bin in bins:
        if input_bin.min > input_bin.max:
            msg = f"Min is greater than max for bin config {input_bin}."
            raise revrtAttributeError(msg)

        if input_bin.min == float("-inf") and input_bin.max == float("inf"):
            msg = (
                "Bin covers all possible values, did you forget to set "
                f"min or max? {input_bin}"
            )
            warn(msg, revrtWarning)


def _validate_bin_continuity(bins):
    """Warn user of potential gaps in bin range continuity"""
    sorted_bins = sorted(bins, key=lambda x: x.min)
    last_max = float("-inf")
    for i, input_bin in enumerate(sorted_bins):
        if input_bin.min < last_max:
            last_bin = sorted_bins[i - 1] if i > 0 else "-infinity"
            msg = (
                f"Overlapping bins detected between bin {last_bin} and "
                f"{input_bin}"
            )
            warn(msg, revrtWarning)
        if input_bin.min > last_max:
            last_bin = sorted_bins[i - 1] if i > 0 else "-infinity"
            msg = f"Gap detected between bin {last_bin} and {input_bin}"
            warn(msg, revrtWarning)
        if i + 1 == len(sorted_bins) and input_bin.max < float("inf"):
            msg = f"Gap detected between bin {input_bin} and infinity"
            warn(msg, revrtWarning)
        last_max = input_bin.max
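# A minimal sketch of the binning behaviour used by ``_process_raster_bins``,
# under an assumed two-bin config: values in [0, 10) map to 1.0, values in
# [10, inf) map to 5.0, and everything else stays 0:
#
#     data = da.from_array(np.array([[2.0, 12.0], [-1.0, 7.0]]), chunks=2)
#     out = da.zeros(data.shape, dtype="float32", chunks=2)
#     for lo, hi, value in [(0, 10, 1.0), (10, float("inf"), 5.0)]:
#         out += da.where(np.logical_and(data >= lo, data < hi), value, 0)
#     out.compute()  # -> [[1., 5.], [0., 1.]]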