Source code for satpy.resample.native

"""Native resampler."""

import warnings
from math import lcm  # type: ignore

import dask.array as da
import numpy as np
import xarray as xr
from pyresample.resampler import BaseResampler as PRBaseResampler

from satpy.resample.base import _update_resampled_coords
from satpy.utils import PerformanceWarning, get_legacy_chunk_size

CHUNK_SIZE = get_legacy_chunk_size()


class NativeResampler(PRBaseResampler):
    """Expand or reduce input datasets to be the same shape.

    If data is higher resolution (more pixels) than the destination area
    then data is averaged to match the destination resolution.

    If data is lower resolution (less pixels) than the destination area
    then data is repeated to match the destination resolution.

    This resampler does not perform any caching or masking due to the
    simplicity of the operations.

    """

    def resample(self, data, cache_dir=None, mask_area=False, **kwargs):
        """Run NativeResampler.

        All arguments are forwarded unchanged to the parent class'
        ``resample``; the only difference is that ``mask_area`` defaults to
        ``False`` here.

        """
        # use 'mask_area' with a default of False. It wouldn't do anything.
        return super().resample(data, cache_dir=cache_dir, mask_area=mask_area, **kwargs)

    @staticmethod
    def _reduction_size(repeat):
        """Return the aggregation size ``1 / repeat`` as a float.

        The reciprocal of a floating point ratio is not always exact: for a
        49x reduction ``1. / (100 / 4900.)`` evaluates to
        ``49.00000000000001``, which would later be rejected as a
        non-integer factor. Values within rounding error of a whole number
        are therefore snapped to that whole number; anything else is
        returned untouched so genuinely fractional factors are still
        reported as errors downstream.

        """
        size = 1. / repeat
        nearest = float(np.rint(size))
        # Relative tolerance far above float64 rounding error (~1e-16) but
        # far below the spacing of real fractional shape ratios.
        if abs(size - nearest) <= 1e-12 * abs(nearest):
            return nearest
        return size

    @classmethod
    def _expand_reduce(cls, d_arr, repeats):
        """Replicate or aggregate ``d_arr`` according to ``repeats``.

        Args:
            d_arr: Array to resize; converted to a dask array if needed.
            repeats: Mapping of axis index to the output/input size ratio
                of that axis.

        Returns:
            The unchanged array when every ratio is 1, a replicated array
            when every ratio is >= 1, or an aggregated array when every
            ratio is <= 1.

        Raises:
            ValueError: If some axes need expanding while others need
                reducing.

        """
        d_arr = _ensure_dask_array(d_arr)
        if all(x == 1 for x in repeats.values()):
            return d_arr
        if all(x >= 1 for x in repeats.values()):
            return _replicate(d_arr, repeats)
        if all(x <= 1 for x in repeats.values()):
            # reduce: the aggregation window is the inverse of the ratio
            y_size = cls._reduction_size(repeats[0])
            x_size = cls._reduction_size(repeats[1])
            return _aggregate(d_arr, y_size, x_size)
        raise ValueError("Must either expand or reduce in both "
                         "directions")

    def compute(self, data, expand=True, **kwargs):
        """Resample data with NativeResampler.

        Args:
            data: DataArray-like object providing ``data``, ``dims`` and
                ``shape``.
            expand: Only used when ``self.target_geo_def`` is a list or
                tuple of areas: the area with the largest shape is chosen
                when True, the one with the smallest shape otherwise.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The result of ``_update_resampled_coords`` for the resized data
            and the selected target area.

        """
        target_geo_def = self.target_geo_def
        if isinstance(target_geo_def, (list, tuple)):
            # find the highest/lowest area among the provided
            test_func = max if expand else min
            target_geo_def = test_func(target_geo_def, key=lambda x: x.shape)

        repeats = _get_repeats(target_geo_def, data)
        # _expand_reduce converts a numpy backed array to a dask array
        d_arr = self._expand_reduce(data.data, repeats)
        new_data = xr.DataArray(d_arr, dims=data.dims)
        return _update_resampled_coords(data, new_data, target_geo_def)
def _ensure_dask_array(d_arr):
    """Return ``d_arr`` as a dask array.

    Dask arrays are passed through untouched; anything else is wrapped with
    ``da.from_array`` using the module level ``CHUNK_SIZE``.

    """
    if isinstance(d_arr, da.Array):
        return d_arr
    return da.from_array(d_arr, chunks=CHUNK_SIZE)
def _get_repeats(target_geo_def, data):
    """Compute the output/input size ratio for every axis of ``data``.

    Args:
        target_geo_def: Object whose ``shape`` gives the (rows, columns) of
            the destination.
        data: Array-like object with ``shape``, ``ndim`` and ``dims``.

    Returns:
        dict mapping axis index to a float ratio. The y and x axes get
        ``target size / input size``; every other axis gets ``1.``.

    """
    y_axis, x_axis = _get_axes(data)
    target_shape = target_geo_def.shape
    source_shape = data.shape
    y_ratio = target_shape[0] / float(source_shape[y_axis])
    x_ratio = target_shape[1] / float(source_shape[x_axis])

    spatial_axes = [y_axis, x_axis]
    repeats = {}
    for axis_idx in range(data.ndim):
        if axis_idx not in spatial_axes:
            # non-spatial axes are never resized
            repeats[axis_idx] = 1.
    repeats[y_axis] = y_ratio
    repeats[x_axis] = x_ratio
    return repeats
def _get_axes(data):
    """Return the ``(y_axis, x_axis)`` indexes of ``data``.

    When both ``"y"`` and ``"x"`` are present in ``data.dims`` their
    positions are returned. Otherwise the last two axes are assumed to be
    rows and columns, which is only accepted for 2D or 3D data.

    Raises:
        ValueError: If the dimensions are not named and ``data`` is neither
            2D nor 3D.

    """
    dims = data.dims
    if "x" in dims and "y" in dims:
        return dims.index("y"), dims.index("x")

    if data.ndim not in [2, 3]:
        raise ValueError("Can only handle 2D or 3D arrays without dimensions.")
    # assume rows is the second to last axis
    return data.ndim - 2, data.ndim - 1
def _aggregate(d, y_size, x_size):
    """Average ``y_size`` by ``x_size`` windows of a 2D dask array.

    Args:
        d: 2D dask array to reduce.
        y_size: Window height as a float holding a whole number.
        x_size: Window width as a float holding a whole number.

    Returns:
        Dask array whose chunks are the input chunks divided by the window
        size along each axis, declared with the dtype of ``d``.

    Raises:
        ValueError: If ``d`` is not 2D or a window size is not a whole
            number.

    """
    if d.ndim != 2:
        # we can't guarantee what blocks we are getting and how
        # it should be reshaped to do the averaging.
        raise ValueError("Can't aggregrate (reduce) data arrays with "
                         "more than 2 dimensions.")
    if not x_size.is_integer() or not y_size.is_integer():
        raise ValueError("Aggregation factors are not integers")

    y_step, x_step = int(y_size), int(x_size)
    # every chunk must hold a whole number of windows
    d = _rechunk_if_nonfactor_chunks(d, y_step, x_step)
    row_chunks = tuple(int(size / y_step) for size in d.chunks[0])
    col_chunks = tuple(int(size / x_step) for size in d.chunks[1])
    return da.core.map_blocks(
        _mean,
        d,
        y_step,
        x_step,
        meta=np.array((), dtype=d.dtype),
        dtype=d.dtype,
        chunks=(row_chunks, col_chunks),
    )
def _mean(data, y_size, x_size):
    """Return the NaN-ignoring mean of each ``y_size`` by ``x_size`` tile.

    The 2D input is reshaped to ``(tile rows, y_size, tile cols, x_size)``
    and averaged over the two within-tile axes.

    """
    n_rows, n_cols = data.shape
    tiled_shape = (
        int(n_rows / y_size),
        int(y_size),
        int(n_cols / x_size),
        int(x_size),
    )
    tiles = data.reshape(tiled_shape)
    # axes 1 and 3 run over the elements inside each tile
    return np.nanmean(tiles, axis=(1, 3))
def _replicate(d_arr, repeats):
    """Repeat data pixels by the per-axis factors specified.

    Args:
        d_arr: Dask array to enlarge.
        repeats: Mapping of axis index to the whole-number (float) factor
            for that axis.

    Returns:
        Dask array built by mapping ``_repeat_by_factor`` over the blocks of
        ``d_arr`` with the enlarged chunk sizes and the input dtype.

    """
    enlarged_chunks = _get_replicated_chunk_sizes(d_arr, repeats)
    return d_arr.map_blocks(
        _repeat_by_factor,
        meta=np.array((), dtype=d_arr.dtype),
        dtype=d_arr.dtype,
        chunks=enlarged_chunks,
    )
def _get_replicated_chunk_sizes(d_arr, repeats):
    """Scale every chunk size of ``d_arr`` by the factor of its axis.

    Args:
        d_arr: Object whose ``chunks`` holds one tuple of chunk sizes per
            axis.
        repeats: Mapping of axis index to a float factor.

    Returns:
        Tuple with one tuple of scaled chunk sizes per axis.

    Raises:
        ValueError: If a factor is not a whole number.

    """
    def _scale_axis(axis, axis_chunks):
        factor = repeats[axis]
        if not factor.is_integer():
            raise ValueError("Expand factor must be a whole number")
        return tuple(size * int(factor) for size in axis_chunks)

    return tuple(
        _scale_axis(axis, axis_chunks)
        for axis, axis_chunks in enumerate(d_arr.chunks)
    )
def _repeat_by_factor(data, block_info=None):
    """Repeat the elements of one block so it fills its output chunk.

    Intended to be mapped over array blocks with ``map_blocks``; the
    per-axis repeat factor is derived from the output chunk shape found in
    ``block_info[None]["chunk-shape"]`` divided by the block's own shape.

    Args:
        data: Numpy array holding a single block.
        block_info: Block metadata; when ``None`` the block is returned
            unchanged.

    Returns:
        The block with every element repeated along each axis.

    """
    if block_info is None:
        return data
    out_shape = block_info[None]["chunk-shape"]
    out_data = data
    for axis, axis_size in enumerate(out_shape):
        in_size = data.shape[axis]
        if in_size == 0:
            # An empty axis has nothing to repeat and would otherwise
            # trigger a division by zero when deriving the factor.
            continue
        out_data = np.repeat(out_data, int(axis_size // in_size), axis=axis)
    return out_data
def _rechunk_if_nonfactor_chunks(dask_arr, y_size, x_size):
    """Rechunk ``dask_arr`` so every chunk is divisible by the factors.

    Args:
        dask_arr: Dask array whose first two axes are being aggregated.
        y_size: Integer aggregation factor for axis 0.
        x_size: Integer aggregation factor for axis 1.

    Returns:
        ``dask_arr`` itself when all chunks are already divisible by the
        factors, otherwise a rechunked version of it.

    Raises:
        ValueError: If the array shape is not divisible by a factor.

    Warns:
        PerformanceWarning: Once, if rechunking was necessary.

    """
    new_chunks = list(dask_arr.chunks)
    need_rechunk = False
    for dim_idx, agg_size in enumerate([y_size, x_size]):
        if dask_arr.shape[dim_idx] % agg_size != 0:
            raise ValueError("Aggregation requires arrays with shapes divisible by the factor.")
        # Accumulate across both dimensions: a misaligned y axis must not be
        # forgotten because the x axis happens to be aligned. The check also
        # records the replacement chunk size in new_chunks.
        if _check_chunking(new_chunks, dask_arr, dim_idx, agg_size):
            need_rechunk = True

    if need_rechunk:
        warnings.warn(
            "Array chunk size is not divisible by aggregation factor. "
            "Re-chunking to continue native resampling.",
            PerformanceWarning,
            stacklevel=5
        )
        # rechunk once with the sizes collected for all dimensions
        dask_arr = dask_arr.rechunk(tuple(new_chunks))
    return dask_arr
def _check_chunking(new_chunks, dask_arr, dim_idx, agg_size):
    """Report whether any chunk along ``dim_idx`` is not divisible.

    For every chunk size that is not a multiple of ``agg_size`` the entry
    ``new_chunks[dim_idx]`` is overwritten in place with the least common
    multiple of that chunk size and ``agg_size``, so the last such chunk
    determines the final value.

    Returns:
        True if at least one chunk was not divisible, False otherwise.

    """
    misaligned = False
    for size in dask_arr.chunks[dim_idx]:
        if size % agg_size == 0:
            continue
        new_chunks[dim_idx] = lcm(size, agg_size)
        misaligned = True
    return misaligned
def get_resampler_classes():
    """Return the resampler classes of this module keyed by resampler name."""
    return dict(native=NativeResampler)