Source code for satpy.resample.kdtree

"""Resamplers based on the kdtree algorithm."""

import os
import warnings
from logging import getLogger

import dask.array as da
import numpy as np
import xarray as xr
import zarr
from pyresample.resampler import BaseResampler as PRBaseResampler

from satpy.resample.base import _update_resampled_coords
from satpy.utils import get_legacy_chunk_size

# Module-level logger, named after this module.
LOG = getLogger(__name__)

# Chunk size used when wrapping cached NumPy index arrays in dask arrays.
CHUNK_SIZE = get_legacy_chunk_size()

# Dimension names attached to each nearest-neighbour index array when it is
# stored in an xarray Dataset for caching.  The key order matters: the arrays
# are loaded and saved by iterating over this mapping.
NN_COORDINATES = {
    "valid_input_index": ("y1", "x1"),
    "valid_output_index": ("y2", "x2"),
    "index_array": ("y2", "x2", "z2"),
}

# Dimension names for the bilinear resampling arrays.
BIL_COORDINATES = {
    "bilinear_s": ("x1",),
    "bilinear_t": ("x1",),
    "slices_x": ("x1", "n"),
    "slices_y": ("x1", "n"),
    "mask_slices": ("x1", "n"),
    "out_coords_x": ("x2",),
    "out_coords_y": ("y2",),
}

class KDTreeResampler(PRBaseResampler):
    """Resample using a KDTree-based nearest neighbor algorithm.

    This resampler implements on-disk caching when the `cache_dir` argument
    is provided to the `resample` method. This should provide significant
    performance improvements on consecutive resampling of geostationary data.
    It is not recommended to provide `cache_dir` when the `mask` keyword
    argument is provided to `precompute` which occurs by default for
    `SwathDefinition` source areas.

    Args:
        cache_dir (str): Long term storage directory for intermediate
                         results.
        mask (bool): Force resampled data's invalid pixel mask to be used
                     when searching for nearest neighbor pixels. By
                     default this is True for SwathDefinition source
                     areas and False for all other area definition types.
        radius_of_influence (float): Search radius cut off distance in meters
        epsilon (float): Allowed uncertainty in meters. Increasing uncertainty
                         reduces execution time.

    """

    def __init__(self, source_geo_def, target_geo_def):
        """Init KDTreeResampler."""
        super().__init__(source_geo_def, target_geo_def)
        # The pyresample resampler is created lazily by ``precompute``.
        self.resampler = None
        # In-memory copies of the index arrays, keyed by the mask's ``name``
        # attribute (``None`` when the mask has no name or no mask is used).
        self._index_caches = {}

    def precompute(self, mask=None, radius_of_influence=None, epsilon=0,
                   cache_dir=None, **kwargs):
        """Create a KDTree structure and store it for later use.

        The neighbour info is read from the in-memory or on-disk cache when
        possible; otherwise it is computed and, if `cache_dir` is set, saved.
        Any additional keyword arguments are ignored.

        Note: The `mask` keyword should be provided if geolocation may be valid
        where data points are invalid.

        """
        from pyresample.kd_tree import XArrayResamplerNN

        del kwargs  # extra keyword arguments are accepted but not used
        if mask is not None and cache_dir is not None:
            LOG.warning("Mask and cache_dir both provided to nearest "
                        "resampler. Cached parameters are affected by "
                        "masked pixels. Will not cache results.")
            cache_dir = None

        if radius_of_influence is None and not hasattr(self.source_geo_def, "geocentric_resolution"):
            radius_of_influence = self._adjust_radius_of_influence(radius_of_influence)

        # Arguments for the pyresample resampler; the same set is passed on
        # to the cache load/save helpers, which use it for the cache filename.
        resampler_kwargs = dict(source_geo_def=self.source_geo_def,
                                target_geo_def=self.target_geo_def,
                                radius_of_influence=radius_of_influence,
                                neighbours=1,
                                epsilon=epsilon)

        if self.resampler is None:
            # FIXME: We need to move all of this caching logic to pyresample
            self.resampler = XArrayResamplerNN(**resampler_kwargs)

        try:
            self.load_neighbour_info(cache_dir, mask=mask, **resampler_kwargs)
            LOG.debug("Read pre-computed kd-tree parameters")
        except IOError:
            # Nothing usable in the caches: compute and (maybe) store.
            LOG.debug("Computing kd-tree parameters")
            self.resampler.get_neighbour_info(mask=mask)
            self.save_neighbour_info(cache_dir, mask=mask, **resampler_kwargs)

    def _adjust_radius_of_influence(self, radius_of_influence):
        """Estimate a default radius of influence from the source geometry.

        A warning recommending a pyresample upgrade is always emitted. The
        value passed in is not used. The result is three times
        ``source_geo_def.lons.resolution`` when that attribute chain exists,
        otherwise three times the larger absolute pixel size, otherwise 1000.
        A ``TypeError`` raised while using the ``lons`` resolution yields
        10000.
        """
        warnings.warn(
            "Upgrade 'pyresample' for a more accurate default 'radius_of_influence'.",
            stacklevel=3
        )
        try:
            radius_of_influence = self.source_geo_def.lons.resolution * 3
        except AttributeError:
            try:
                radius_of_influence = max(abs(self.source_geo_def.pixel_size_x),
                                          abs(self.source_geo_def.pixel_size_y)) * 3
            except AttributeError:
                radius_of_influence = 1000
        except TypeError:
            # Only covers the ``lons.resolution * 3`` expression above.
            radius_of_influence = 10000
        return radius_of_influence

    def _apply_cached_index(self, val, idx_name, persist=False):
        """Reassign resampler index attributes.

        NumPy arrays are wrapped in dask arrays; dask arrays are persisted
        when `persist` is true. The resulting value is set on the resampler
        as attribute `idx_name` and returned.
        """
        if isinstance(val, np.ndarray):
            val = da.from_array(val, chunks=CHUNK_SIZE)
        elif persist and isinstance(val, da.Array):
            val = val.persist()
        setattr(self.resampler, idx_name, val)
        return val

    def load_neighbour_info(self, cache_dir, mask=None, **kwargs):
        """Read index arrays from either the in-memory or disk cache.

        The in-memory cache takes precedence over `cache_dir`.

        Raises:
            IOError: If there is no in-memory entry and no `cache_dir`, or
                the on-disk cache cannot be read.

        """
        mask_name = getattr(mask, "name", None)
        if mask_name in self._index_caches:
            in_memory = self._index_caches[mask_name]
            cached = {idx_name: self._apply_cached_index(in_memory[idx_name], idx_name)
                      for idx_name in NN_COORDINATES}
        elif cache_dir:
            cached = {}
            for idx_name in NN_COORDINATES:
                from_disk = self._load_neighbour_info_from_cache(
                    cache_dir, idx_name, mask_name, **kwargs)
                cached[idx_name] = self._apply_cached_index(from_disk, idx_name)
        else:
            raise IOError("No cached neighbour info available")
        self._index_caches[mask_name] = cached

    def _load_neighbour_info_from_cache(self, cache_dir, idx_name, mask_name, **kwargs):
        """Read one index array from the Zarr cache in `cache_dir`.

        Returns:
            The array stored under `idx_name` as a NumPy array
            (boolean for ``"valid_input_index"``).

        Raises:
            IOError: If a ``ValueError`` occurs while locating or reading
                the cache.

        """
        try:
            filename = self._create_cache_filename(
                cache_dir, prefix="nn_lut-", mask=mask_name, **kwargs)
            fid = zarr.open(filename, mode="r")
            cache = np.array(fid[idx_name])
            if idx_name == "valid_input_index":
                # valid input index array needs to be boolean
                cache = cache.astype(bool)
        except ValueError as err:
            # Callers treat IOError as "cache miss"; keep the original cause.
            raise IOError(f"Could not read '{idx_name}' from the neighbour info cache") from err
        return cache

    def save_neighbour_info(self, cache_dir, mask=None, **kwargs):
        """Cache resampler's index arrays if there is a cache dir."""
        if not cache_dir:
            return

        mask_name = getattr(mask, "name", None)
        cache = self._read_resampler_attrs()
        filename = self._create_cache_filename(
            cache_dir, prefix="nn_lut-", mask=mask_name, **kwargs)
        LOG.info("Saving kd_tree neighbour info to %s", filename)
        zarr_out = xr.Dataset()
        for idx_name, coord in NN_COORDINATES.items():
            # update the cache in place with persisted dask arrays
            cache[idx_name] = self._apply_cached_index(cache[idx_name],
                                                       idx_name,
                                                       persist=True)
            zarr_out[idx_name] = (coord, cache[idx_name])

        # Write indices to Zarr file
        zarr_out.to_zarr(filename)

        self._index_caches[mask_name] = cache
        # Delete the kdtree, it's not needed anymore
        self.resampler.delayed_kdtree = None

    def _read_resampler_attrs(self):
        """Read certain attributes from the resampler for caching."""
        return {attr_name: getattr(self.resampler, attr_name)
                for attr_name in NN_COORDINATES}

    def compute(self, data, weight_funcs=None, fill_value=np.nan,
                with_uncert=False, **kwargs):
        """Resample data.

        `weight_funcs`, `with_uncert` and any extra keyword arguments are
        not used by this method.
        """
        del kwargs
        LOG.debug("Resampling %s", data.name)
        res = self.resampler.get_sample_from_neighbour_info(data, fill_value)
        return _update_resampled_coords(data, res, self.target_geo_def)
class BilinearResampler(PRBaseResampler):
    """Resample using bilinear interpolation.

    This resampler implements on-disk caching when the `cache_dir` argument
    is provided to the `resample` method. This should provide significant
    performance improvements on consecutive resampling of geostationary data.

    Args:
        cache_dir (str): Long term storage directory for intermediate
                         results.
        radius_of_influence (float): Search radius cut off distance in meters
        epsilon (float): Allowed uncertainty in meters. Increasing uncertainty
                         reduces execution time.
        reduce_data (bool): Reduce the input data to (roughly) match the
                            target area.

    """

    def __init__(self, source_geo_def, target_geo_def):
        """Init BilinearResampler."""
        super().__init__(source_geo_def, target_geo_def)
        # The pyresample resampler is created lazily by ``precompute``.
        self.resampler = None

    def precompute(self, mask=None, radius_of_influence=50000, epsilon=0,
                   reduce_data=True, cache_dir=False, **kwargs):
        """Create bilinear coefficients and store them for later use.

        Does nothing if a resampler has already been created. `mask` and any
        extra keyword arguments are ignored; `reduce_data` is not used by
        this method.
        """
        try:
            from pyresample.bilinear import XArrayBilinearResampler
        except ImportError:
            # Fall back to the alternative class name if the first is missing.
            from pyresample.bilinear import XArrayResamplerBilinear as XArrayBilinearResampler

        del kwargs
        del mask

        if self.resampler is not None:
            return

        # Arguments for the pyresample resampler; the same set is passed on
        # to the cache load/save helpers, which use it for the cache filename.
        resampler_kwargs = dict(source_geo_def=self.source_geo_def,
                                target_geo_def=self.target_geo_def,
                                radius_of_influence=radius_of_influence,
                                neighbours=32,
                                epsilon=epsilon)

        self.resampler = XArrayBilinearResampler(**resampler_kwargs)
        try:
            self.load_bil_info(cache_dir, **resampler_kwargs)
            LOG.debug("Loaded bilinear parameters")
        except IOError:
            # No usable cache: compute and (maybe) store.
            LOG.debug("Computing bilinear parameters")
            self.resampler.get_bil_info()
            LOG.debug("Saving bilinear parameters.")
            self.save_bil_info(cache_dir, **resampler_kwargs)

    def load_bil_info(self, cache_dir, **kwargs):
        """Load bilinear resampling info from cache directory.

        Raises:
            IOError: If `cache_dir` is falsy, or the resampler raises
                ``AttributeError`` while loading (a warning asking for a
                Pyresample upgrade is emitted in that case).

        """
        if not cache_dir:
            raise IOError("No cache directory given for bilinear resampling info")

        filename = self._create_cache_filename(cache_dir,
                                               prefix="bil_lut-",
                                               **kwargs)
        try:
            self.resampler.load_resampling_info(filename)
        except AttributeError as err:
            warnings.warn(
                "Bilinear resampler can't handle caching, "
                "please upgrade Pyresample to 0.17.0 or newer.",
                stacklevel=2
            )
            # Callers treat IOError as "cache miss"; keep the original cause.
            raise IOError("Bilinear resampler can't load cached resampling info") from err

    def save_bil_info(self, cache_dir, **kwargs):
        """Save bilinear resampling info to cache directory.

        Does nothing when `cache_dir` is falsy. An ``AttributeError`` from
        the resampler while saving is reported as a warning only.
        """
        if not cache_dir:
            return

        filename = self._create_cache_filename(cache_dir,
                                               prefix="bil_lut-",
                                               **kwargs)
        # There are some old caches, move them out of the way
        if os.path.exists(filename):
            _move_existing_caches(cache_dir, filename)
        LOG.info("Saving BIL neighbour info to %s", filename)
        try:
            self.resampler.save_resampling_info(filename)
        except AttributeError:
            warnings.warn(
                "Bilinear resampler can't handle caching, "
                "please upgrade Pyresample to 0.17.0 or newer.",
                stacklevel=2
            )

    def compute(self, data, fill_value=None, **kwargs):
        """Resample the given data using bilinear interpolation.

        When `fill_value` is None, the ``_FillValue`` entry of ``data.attrs``
        is used (which may itself be None). Extra keyword arguments are
        ignored.
        """
        del kwargs

        if fill_value is None:
            fill_value = data.attrs.get("_FillValue")
        target_shape = self.target_geo_def.shape

        res = self.resampler.get_sample_from_bil_info(data,
                                                      fill_value=fill_value,
                                                      output_shape=target_shape)

        return _update_resampled_coords(data, res, self.target_geo_def)
[docs] def _move_existing_caches(cache_dir, filename): """Move existing cache files out of the way.""" import os import shutil old_cache_dir = os.path.join(cache_dir, "moved_by_satpy") try: os.makedirs(old_cache_dir) except FileExistsError: pass try: shutil.move(filename, old_cache_dir) except shutil.Error: os.remove(os.path.join(old_cache_dir, os.path.basename(filename))) shutil.move(filename, old_cache_dir) LOG.warning("Old cache file was moved to %s", old_cache_dir)
def get_resampler_classes():
    """Return the kdtree-based resampler classes keyed by resampler name.

    Both ``"kd_tree"`` and ``"nearest"`` map to the nearest-neighbour
    resampler; ``"bilinear"`` maps to the bilinear one.
    """
    nearest_cls = KDTreeResampler
    return dict(
        kd_tree=nearest_cls,
        nearest=nearest_cls,
        bilinear=BilinearResampler,
    )