# Source code for satpy.resample.bucket

"""Bucket resamplers."""

from logging import getLogger

import dask.array as da
import numpy as np
import xarray as xr
from pyresample.resampler import BaseResampler as PRBaseResampler

from satpy.resample.base import _update_resampled_coords
from satpy.utils import get_legacy_chunk_size

# Module-level logger, named after this module.
LOG = getLogger(__name__)

# Chunk size handed to ``get_lonlats(chunks=...)`` when the source
# longitudes/latitudes are requested in ``BucketResamplerBase.precompute``.
# The value is whatever ``satpy.utils.get_legacy_chunk_size()`` returns at
# import time.
CHUNK_SIZE = get_legacy_chunk_size()


class BucketResamplerBase(PRBaseResampler):
    """Common machinery shared by all bucket resamplers.

    The base class builds the underlying ``pyresample`` bucket resampler and
    wraps the resampled array back into an :class:`xarray.DataArray`.  The
    actual statistic (average, sum, count, fractions) is provided by the
    sub-classes through :meth:`compute`; the base implementation of
    :meth:`compute` only raises :class:`NotImplementedError`.
    """

    def __init__(self, source_geo_def, target_geo_def):
        """Initialize bucket resampler.

        Args:
            source_geo_def: Geometry definition of the input data.
            target_geo_def: Geometry definition of the output grid.

        """
        super(BucketResamplerBase, self).__init__(source_geo_def, target_geo_def)
        # Created lazily by ``precompute``.
        self.resampler = None

    def precompute(self, **kwargs):
        """Create X and Y indices and store them for later use.

        The source longitudes and latitudes are requested with
        ``chunks=CHUNK_SIZE`` and handed, together with the target geometry,
        to ``pyresample.bucket.BucketResampler``.  The resulting object is
        stored in ``self.resampler``.  ``kwargs`` are accepted for interface
        compatibility and are not used here.
        """
        # Imported here rather than at module level, as in the original code.
        from pyresample import bucket

        LOG.debug("Initializing bucket resampler.")
        source_lons, source_lats = self.source_geo_def.get_lonlats(chunks=CHUNK_SIZE)
        self.resampler = bucket.BucketResampler(self.target_geo_def, source_lons, source_lats)

    def compute(self, data, **kwargs):
        """Call the resampling.

        Raises:
            NotImplementedError: Always; sub-classes provide the statistic.

        """
        raise NotImplementedError("Use the sub-classes")

    def resample(self, data, **kwargs):  # noqa: D417
        """Resample `data` by calling `precompute` and `compute` methods.

        Args:
            data (xarray.DataArray): Data to be resampled
            **kwargs: Forwarded unchanged to both :meth:`precompute` and
                :meth:`compute`.

        Returns (xarray.DataArray): Data resampled to the target area

        """
        self.precompute(**kwargs)
        # Work on a copy so the attributes of the input are never modified.
        attrs = data.attrs.copy()
        data_arr = data.data
        dims = _get_dims(data)
        LOG.debug("Resampling %s", data.attrs.get("_satpy_id", "unknown"))
        result = self.compute(data_arr, **kwargs)
        coords, result, dims = _check_coords_results_dims(result, data, dims)
        self._adjust_attrs(attrs)
        result = xr.DataArray(result, dims=dims, coords=coords, attrs=attrs)
        return _update_resampled_coords(data, result, self.target_geo_def)

    def _adjust_attrs(self, attrs):
        """Update ``attrs`` in place for resamplers that change the quantity.

        Fractions and counts no longer carry the physical quantity of the
        input, so ``units`` and ``calibration`` are blanked and
        ``standard_name`` is replaced.  Other resamplers leave ``attrs``
        untouched.

        The resampler kind is decided with ``isinstance`` instead of looking
        for the class name in ``str(self)``; this keeps working for
        sub-classes and for classes that define their own ``__str__`` or
        ``__repr__``.
        """
        # Both classes are defined later in this module; the names are only
        # looked up when the method runs.
        if isinstance(self, BucketFraction):
            standard_name = "area_fraction"
        elif isinstance(self, BucketCount):
            standard_name = "number_of_observations"
        else:
            return
        attrs["units"] = ""
        attrs["calibration"] = ""
        attrs["standard_name"] = standard_name
def _get_dims(data):
    """Choose the dimension names for the resampled version of ``data``.

    Args:
        data: Object exposing ``ndim`` and ``dims`` (an
            ``xarray.DataArray`` when called from ``resample``).

    Returns:
        ``("y", "x")`` for 1D or 2D input, ``("bands", "y", "x")`` for 3D
        input whose first dimension is ``"bands"``, otherwise ``data.dims``
        unchanged.

    """
    n_dims = data.ndim
    if n_dims in (1, 2):
        # One- and two-dimensional input both end up as a 2D output.
        return ("y", "x")
    if n_dims == 3 and data.dims[0] == "bands":
        return ("bands", "y", "x")
    return data.dims
def _check_coords_results_dims(result, data, dims):
    """Derive coordinates and normalise the computed result and its dims.

    Args:
        result: Output of a resampler's ``compute``; either an array or, for
            category fractions, a dict mapping each category to an array.
        data: The input data; only ``data.coords`` is read.
        dims: Dimension names proposed for the output.

    Returns:
        tuple: ``(coords, result, dims)`` where ``coords`` holds the
        ``"bands"`` coordinate of ``data`` (when present) and/or the sorted
        ``"categories"`` (for dict results), ``result`` is a single array and
        ``dims`` matches it.

    """
    coords = {}
    if "bands" in data.coords:
        coords["bands"] = data.coords["bands"]
    # Fractions are returned in a dict.  This is checked independently of
    # the "bands" coordinate: previously a dict result for input that
    # carried a "bands" coordinate skipped the stacking and then failed on
    # ``result.ndim`` below.
    if isinstance(result, dict):
        categories = sorted(result)
        coords["categories"] = categories
        dims = ("categories", "y", "x")
        result = da.stack([result[cat] for cat in categories])
    # Drop length-one axes when the result has more axes than dimension
    # names (e.g. 2D input stacked into a single layer).
    if result.ndim > len(dims):
        result = da.squeeze(result)
    return coords, result, dims
class BucketAvg(BucketResamplerBase):
    """Bucket resampler producing per-bin averages.

    For every bin of the target area the average of all input values that
    are closest to that bin and inside the target area is calculated.

    Parameters
    ----------
    fill_value : (float) default: `np.nan`
        Value marking missing/invalid data, both in the input and in the
        binned and averaged output.
    skipna : (bool) default: `True`
        If True, missing values (NaN or `fill_value`) are ignored when
        averaging (similarly to Numpy's `nanmean`); buckets holding only
        missing values are set to fill_value.
        If False, a bucket is set to fill_value as soon as it contains one
        or more missing values (similarly to Numpy's `mean`).
        In both cases, empty buckets are set to `fill_value`.

    """

    def compute(self, data, fill_value=np.nan, skipna=True, **kwargs):  # noqa: D417
        """Call the resampling.

        Three-dimensional input is averaged one slice along the first axis
        at a time; any other input is averaged as a whole.  Remaining
        keyword arguments are passed on to ``get_average``.

        Args:
            data (numpy.ndarray | dask.array.Array): Data to be resampled
            fill_value (float | int): fill_value. Defaults to numpy.nan
            skipna (bool): Skip NA's. Default `True`

        Returns:
            dask.array.Array

        """
        if data.ndim == 3:
            layers = [data[idx, :, :] for idx in range(data.shape[0])]
        else:
            layers = [data]
        averaged = [
            self.resampler.get_average(layer, fill_value=fill_value, skipna=skipna, **kwargs)
            for layer in layers
        ]
        return da.stack(averaged)
class BucketSum(BucketResamplerBase):
    """Bucket resampler producing per-bin sums (accumulation).

    For every bin of the target area the cumulative sum of all input values
    that are closest to that bin and inside the target area is calculated.

    Parameters
    ----------
    fill_value : (float) default: `np.nan`
        Fill value for missing data
    skipna : (bool) default: `True`
        If True, NaN values are ignored when summing (similarly to Numpy's
        `nansum`); buckets holding only NaN are set to zero.
        If False, a bucket is set to NaN as soon as it contains one or more
        NaN values (similarly to Numpy's `sum`).
        In both cases, empty buckets are set to 0.

    """

    def compute(self, data, skipna=True, **kwargs):
        """Call the resampling.

        Three-dimensional input is summed one slice along the first axis at
        a time; any other input is summed as a whole.  The per-slice results
        are stacked along a new leading axis.  Keyword arguments other than
        ``skipna`` are passed on to ``get_sum`` unchanged.
        """
        def _bucket_sum(layer):
            return self.resampler.get_sum(layer, skipna=skipna, **kwargs)

        if data.ndim != 3:
            return da.stack([_bucket_sum(data)])
        return da.stack([_bucket_sum(data[band, :, :]) for band in range(data.shape[0])])
class BucketCount(BucketResamplerBase):
    """Bucket resampler producing per-bin hit counts.

    For every bin of the target area the number of occurrences of input
    data closest to that bin and inside the target area is calculated.
    """

    def compute(self, data, **kwargs):
        """Call the resampling.

        ``data`` is only inspected for its shape: one count array is
        requested per slice along the first axis of three-dimensional input,
        and a single one otherwise.  ``kwargs`` are ignored.  The count
        arrays are stacked along a new leading axis.
        """
        n_layers = data.shape[0] if data.ndim == 3 else 1
        counts = [self.resampler.get_count() for _ in range(n_layers)]
        return da.stack(counts)
class BucketFraction(BucketResamplerBase):
    """Bucket resampler producing per-category fractions.

    The fraction of occurrences of the input data is calculated separately
    for each category.
    """

    def compute(self, data, fill_value=np.nan, categories=None, **kwargs):
        """Call the resampling.

        Args:
            data: Data to be resampled; at most two-dimensional.
            fill_value: Passed on to ``get_fractions``. Defaults to
                numpy.nan.
            categories: Passed on to ``get_fractions``. Defaults to `None`.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            Whatever ``get_fractions`` returns (handled as a dict of
            per-category arrays by the caller).

        Raises:
            ValueError: If ``data`` has more than two dimensions.

        """
        if data.ndim > 2:
            raise ValueError("BucketFraction not implemented for 3D datasets")
        return self.resampler.get_fractions(data, categories=categories, fill_value=fill_value)
def get_resampler_classes():
    """Return the bucket resampler classes keyed by resampler name."""
    named_resamplers = (
        ("bucket_avg", BucketAvg),
        ("bucket_sum", BucketSum),
        ("bucket_count", BucketCount),
        ("bucket_fraction", BucketFraction),
    )
    return dict(named_resamplers)