# -*- coding: utf-8 -*-
"""
Class to handle geotiff input files.
Created on Thu Jun 20 09:43:34 2019
@author: gbuster
"""
import os
import logging
import rasterio
import numpy as np
import pandas as pd
from affine import Affine
from pyproj import Transformer
from rex.utilities.parse_keys import parse_keys
from reVX.utilities.exceptions import GeoTiffKeyError
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
class Geotiff:
    """GeoTIFF handler object.

    Wraps a ``rasterio`` dataset and exposes numpy-like slicing for band
    data, pixel-center latitude/longitude arrays and a flattened meta
    data table.
    """

    def __init__(self, fpath, chunks=(128, 128)):
        """
        Parameters
        ----------
        fpath : str
            Path to .tiff file.
        chunks : tuple
            GeoTIFF chunk (tile) shape/size. Forwarded as a keyword
            argument to ``rasterio.open``.
        """
        self._fpath = fpath
        self._iarr = None
        self._src = rasterio.open(self._fpath, chunks=chunks)

        # Keep a plain-dict copy of the profile with the transform reduced
        # to its first six elements and the CRS stored as a proj4 string.
        self._profile = dict(self._src.profile)
        self._profile["transform"] = self._profile["transform"][:6]
        self._profile["crs"] = self._profile["crs"].to_proj4()

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        # Returning None (falsy) lets any exception raised inside the
        # ``with`` body propagate; no explicit re-raise is required.
        self.close()

    def __len__(self):
        """Total number of pixels in the GeoTiff."""
        return self.n_rows * self.n_cols

    def __getitem__(self, keys):
        """Retrieve data from the GeoTIFF object.

        Example, get meta data and layer-0 data for rows 0 through 128 and
        columns 128 through 256.

            meta = geotiff['meta', 0:128, 128:256]
            data = geotiff[0, 0:128, 128:256]

        Parameters
        ----------
        keys : tuple
            Slicing args similar to a numpy array slice. See examples
            above. The first entry selects what to retrieve: ``'meta'``,
            a string starting with ``'lat'`` or ``'lon'`` (all matched
            case-insensitively), or an integer band index (0-based).

        Returns
        -------
        out : pd.DataFrame | np.ndarray
            Meta data table, latitude/longitude array, or flattened band
            data, depending on the requested key.

        Raises
        ------
        GeoTiffKeyError
            If a string key is not one of the recognised names.
        """
        ds, ds_slice = parse_keys(keys)

        if isinstance(ds, str):
            name = ds.lower()
            if name == 'meta':
                return self._get_meta(*ds_slice)
            if name.startswith('lat'):
                return self._get_lat_lon(*ds_slice)[0]
            if name.startswith('lon'):
                return self._get_lat_lon(*ds_slice)[1]

            # An unknown string would otherwise reach ``ds + 1`` in
            # _get_data and fail with an unrelated TypeError.
            raise GeoTiffKeyError('Unrecognized GeoTiff key "{}". Use '
                                  '"meta", "lat*", "lon*" or an integer '
                                  'band index.'.format(ds))

        return self._get_data(ds, *ds_slice)

    @property
    def profile(self):
        """
        GeoTiff geospatial profile

        Returns
        -------
        _profile : dict
            Dictionary of geo-spatial attributes needed to create GeoTiff
        """
        return self._profile

    @property
    def dtype(self):
        """
        GeoTiff array dtype

        Returns
        -------
        dtype : str
            Dtype of data in GeoTiff
        """
        return self.profile["dtype"]

    @property
    def iarr(self):
        """Get an array of 1D index values for the flattened geotiff extent.

        The array is built on first access and cached.

        Returns
        -------
        iarr : np.ndarray
            Uint array with same shape as geotiff extent, representing the 1D
            index values if the geotiff extent was flattened
            (with default flatten order 'C'). uint32 is used unless the
            raster has more pixels than uint32 can index, in which case
            uint64 is used.
        """
        if self._iarr is None:
            n_pixels = len(self)
            # Largest index is n_pixels - 1, so uint32 suffices up to 2**32
            # pixels; beyond that it cannot represent every index.
            if n_pixels <= np.iinfo(np.uint32).max + 1:
                idx_dtype = np.uint32
            else:
                idx_dtype = np.uint64

            self._iarr = np.arange(n_pixels, dtype=idx_dtype)
            self._iarr = self._iarr.reshape(self.shape)

        return self._iarr

    @property
    def tiff_shape(self):
        """
        Tiff array shape (bands, y, x)

        Returns
        -------
        shape : tuple
            (bands, y, x)
        """
        return (self.bands, *self.shape)

    @property
    def shape(self):
        """Get the Geotiff shape tuple (n_rows, n_cols).

        Returns
        -------
        shape : tuple
            2-entry tuple representing the full GeoTiff shape.
        """
        return self._src.shape

    @property
    def n_rows(self):
        """Get the number of Geotiff rows.

        Returns
        -------
        n_rows : int
            Number of row entries in the full geotiff.
        """
        return self.shape[0]

    @property
    def n_cols(self):
        """Get the number of Geotiff columns.

        Returns
        -------
        n_cols : int
            Number of column entries in the full geotiff.
        """
        return self.shape[1]

    @property
    def bands(self):
        """
        Get number of GeoTiff bands

        Returns
        -------
        bands : int
        """
        return self._src.count

    @property
    def lat_lon(self):
        """
        Get latitude and longitude coordinate arrays

        Returns
        -------
        tuple
            (latitude, longitude) arrays for the full extent.
        """
        return self._get_lat_lon(slice(None), slice(None))

    @property
    def latitude(self):
        """
        Get latitude coordinates array

        Returns
        -------
        ndarray
        """
        return self['lat']

    @property
    def longitude(self):
        """
        Get longitude coordinates array

        Returns
        -------
        ndarray
        """
        return self['lon']

    @property
    def meta(self):
        """
        Lat lon to y, x coordinate mapping

        Returns
        -------
        pd.DataFrame
        """
        return self['meta']

    @property
    def values(self):
        """
        Full DataArray in [bands, y, x] dimensions

        Returns
        -------
        ndarray
        """
        return self._src.read()

    @staticmethod
    def _unpack_slices(*yx_slice):
        """Split positional slicing args into a row and a column selector.

        Parameters
        ----------
        *yx_slice : tuple
            Zero, one or two slicing args in (row, column) order. Missing
            entries default to the full extent (``slice(None)``).

        Returns
        -------
        y_slice : slice
            Row slice.
        x_slice : slice
            Col slice.

        Raises
        ------
        GeoTiffKeyError
            If more than two slicing args are given.
        """
        n_args = len(yx_slice)
        if n_args > 2:
            raise GeoTiffKeyError('Cannot do 3D slicing on GeoTiff meta.')

        y_slice = yx_slice[0] if n_args >= 1 else slice(None)
        x_slice = yx_slice[1] if n_args == 2 else slice(None)

        return y_slice, x_slice

    @staticmethod
    def _get_meta_inds(x_slice, y_slice, shape=None):
        """Get the row and column indices associated with lat/lon slices.

        Parameters
        ----------
        x_slice : slice
            Column slice corresponding to the extracted lon values.
        y_slice : slice
            Row slice corresponding to the extracted lat values.
        shape : tuple, optional
            Full raster shape (n_rows, n_cols). When given, the slices are
            resolved with numpy indexing semantics, so open-ended slices,
            steps and negative indices yield the same elements numpy
            would select from an array of that shape. When ``None``, both
            slices must have an explicit ``stop`` and any step is ignored.

        Returns
        -------
        row_ind : np.ndarray
            1D array of the row indices corresponding to the lat/lon arrays
            once mesh-gridded and flattened
        col_ind : np.ndarray
            1D array of the col indices corresponding to the lat/lon arrays
            once mesh-gridded and flattened

        Raises
        ------
        GeoTiffKeyError
            If ``shape`` is not given and a slice has no ``stop``.
        """
        if shape is not None:
            n_rows, n_cols = shape
            # Index a full-length range so the result matches exactly what
            # numpy selects when the same slices are applied to the raster.
            row_ind = np.arange(n_rows, dtype=np.uint32)[y_slice]
            col_ind = np.arange(n_cols, dtype=np.uint32)[x_slice]
        else:
            if x_slice.stop is None or y_slice.stop is None:
                raise GeoTiffKeyError('Cannot resolve an open-ended slice '
                                      'without the GeoTiff shape.')

            x_start = 0 if x_slice.start is None else x_slice.start
            y_start = 0 if y_slice.start is None else y_slice.start
            col_ind = np.arange(x_start, x_slice.stop).astype(np.uint32)
            row_ind = np.arange(y_start, y_slice.stop).astype(np.uint32)

        col_ind, row_ind = np.meshgrid(col_ind, row_ind)

        return row_ind.flatten(), col_ind.flatten()

    def _get_meta(self, *ds_slice):
        """Get the geotiff meta dataframe in standard WGS84 projection.

        Parameters
        ----------
        *ds_slice : tuple
            Slicing args for meta data.

        Returns
        -------
        meta : pd.DataFrame
            Flattened meta data with columns ``latitude``, ``longitude``,
            ``row_ind`` and ``col_ind``.
        """
        y_slice, x_slice = self._unpack_slices(*ds_slice)

        # Pass the raster shape so open-ended and stepped slices resolve to
        # the same pixels that the lat/lon arrays below are built from.
        row_ind, col_ind = self._get_meta_inds(x_slice, y_slice,
                                               shape=self.shape)

        lat, lon = self._get_lat_lon(y_slice, x_slice)
        lon = lon.flatten()
        lat = lat.flatten()

        meta = pd.DataFrame({'latitude': lat.astype(np.float32),
                             'longitude': lon.astype(np.float32),
                             'row_ind': row_ind, 'col_ind': col_ind})

        return meta

    # pylint: disable=all
    def _get_lat_lon(self, *ds_slice):
        """
        Get the geotiff latitude and longitude coordinates

        Pixel-center coordinates are computed from the dataset transform
        and then converted from the dataset CRS to EPSG:4326.

        Parameters
        ----------
        *ds_slice : tuple
            Slicing args for latitude and longitude arrays

        Returns
        -------
        lat : ndarray
            Latitude coordinates (float32)
        lon : ndarray
            Longitude coordinates (float32)
        """
        y_slice, x_slice = self._unpack_slices(*ds_slice)

        if isinstance(y_slice, slice) and isinstance(x_slice, slice):
            # Slice the 1D index vectors first so that only the requested
            # window is mesh-gridded, not the full raster.
            cols, rows = np.meshgrid(np.arange(self.n_cols)[x_slice],
                                     np.arange(self.n_rows)[y_slice])
        else:
            # Non-slice selectors keep the original full-grid indexing so
            # the output shape is unchanged.
            cols, rows = np.meshgrid(np.arange(self.n_cols),
                                     np.arange(self.n_rows))
            cols = cols[y_slice, x_slice]
            rows = rows[y_slice, x_slice]

        # Shift by half a pixel so coordinates refer to pixel centers.
        pixel_center_translation = Affine.translation(0.5, 0.5)
        adjusted_transform = self._src.transform * pixel_center_translation
        lon, lat = adjusted_transform * [cols, rows]

        transformer = Transformer.from_crs(self._src.profile["crs"],
                                           'epsg:4326', always_xy=True)
        lon, lat = transformer.transform(lon, lat)

        return lat.astype(np.float32), lon.astype(np.float32)

    def _get_data(self, ds, *ds_slice):
        """Get the flattened geotiff layer data.

        Parameters
        ----------
        ds : int
            Layer to get data from (0-based; converted to the 1-based band
            index used when reading).
        *ds_slice : tuple
            Slicing args for data

        Returns
        -------
        data : np.ndarray
            1D array of flattened data corresponding to meta data.
        """
        y_slice, x_slice = self._unpack_slices(*ds_slice)

        # Fill in open-ended stops with the raster extent before building
        # the read window.
        if x_slice.stop is None:
            x_slice = slice(x_slice.start, self.shape[1], x_slice.step)

        if y_slice.stop is None:
            y_slice = slice(y_slice.start, self.shape[0], y_slice.step)

        window = rasterio.windows.Window.from_slices(y_slice, x_slice)

        return self._src.read(ds + 1, window=window).flatten()

    def close(self):
        """Close the rasterio source object"""
        self._src.close()

    @staticmethod
    def write(out_fp, profile, values, dtype=None):
        """Write values to GeoTIFF file with given profile.

        The input ``profile`` is not modified; a shallow copy is updated
        with the output dtype and, if missing, a nodata value.

        Parameters
        ----------
        out_fp : str
            Path to GeoTIFF output file to save data to.
        profile : dict
            GeoTIFF profile (attributes).
        values : ndarray
            GeoTIFF data to save. Arrays with fewer than 3 dimensions get
            a leading axis added.
        dtype : str, optional
            Type of data being stored. If ``None``, the data dtype is
            inferred from the `values` input itself.
        """
        out_dir = os.path.dirname(out_fp)
        if out_dir and not os.path.exists(out_dir):
            logger.debug("Creating %s", out_dir)
            # exist_ok guards against another process creating the
            # directory between the check above and this call.
            os.makedirs(out_dir, exist_ok=True)

        if values.ndim < 3:
            values = np.expand_dims(values, 0)

        dtype = dtype or values.dtype.name

        # Work on a copy so the caller's profile (possibly the live profile
        # of an open Geotiff handler) is left untouched.
        profile = dict(profile)
        profile['dtype'] = dtype

        if "nodata" not in profile:
            # Default nodata to the largest value the dtype can represent.
            if np.issubdtype(dtype, np.integer):
                dtype_max = np.iinfo(dtype).max
            else:
                dtype_max = np.finfo(dtype).max

            profile['nodata'] = dtype_max

        with rasterio.open(out_fp, 'w', **profile) as f:
            f.write(values)

        logger.debug('%s created', out_fp)