#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (c) 2017-2019 Satpy developers
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Utilities for various satpy tests."""
import datetime as dt
import os
from contextlib import contextmanager
from typing import Any
from unittest import mock
import dask.array as da
import numpy as np
from pyresample import create_area_def
from pyresample.geometry import BaseDefinition, SwathDefinition
from xarray import DataArray
from satpy import Scene
from satpy.composites import GenericCompositor, IncompatibleAreas
from satpy.dataset import DataID, DataQuery
from satpy.dataset.dataid import default_id_keys_config, minimal_default_keys_config
from satpy.modifiers import ModifierBase
from satpy.readers.file_handlers import BaseFileHandler
FAKE_FILEHANDLER_START = dt.datetime(2020, 1, 1, 0, 0, 0)
FAKE_FILEHANDLER_END = dt.datetime(2020, 1, 1, 1, 0, 0)
RANDOM_GEN = np.random.default_rng()
def make_dataid(**items):
"""Make a DataID with default keys."""
return DataID(default_id_keys_config, **items)
def make_cid(**items):
"""Make a DataID with a minimal set of keys to id composites."""
return DataID(minimal_default_keys_config, **items)
def make_dsq(**items):
"""Make a dataset query."""
return DataQuery(**items)
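# Example (a minimal sketch): how the three helpers above are typically used in
# tests. The dataset name "ds1" and the function name are illustrative only.
def _example_make_id_helpers():
    data_id = make_dataid(name="ds1", resolution=250, calibration="reflectance")
    query = make_dsq(name="ds1", resolution=250)
    # DataID behaves like a dict of identifying keys; DataQuery carries the
    # keys a caller asks for when loading
    assert data_id["name"] == "ds1"
    assert query.get("resolution") == 250
    return data_id, query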
def spy_decorator(method_to_decorate):
"""Fancy decorator to wrap an object while still calling it.
See https://stackoverflow.com/a/41599695/433202
"""
tmp_mock = mock.MagicMock()
def wrapper(self, *args, **kwargs):
tmp_mock(*args, **kwargs)
return method_to_decorate(self, *args, **kwargs)
wrapper.mock = tmp_mock
return wrapper
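# Example (a minimal sketch): spy on a real method with ``spy_decorator`` so the
# original still runs while calls are recorded. The ``_Reader`` class and its
# ``load`` method are hypothetical stand-ins for the object under test.
def _example_spy_decorator_usage():
    class _Reader:
        def load(self, name):
            return name.upper()

    spied = spy_decorator(_Reader.load)
    with mock.patch.object(_Reader, "load", spied):
        # the original method still executes...
        assert _Reader().load("ds1") == "DS1"
    # ...and the attached MagicMock recorded the call
    spied.mock.assert_called_once_with("ds1")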
def convert_file_content_to_data_array(file_content, attrs=tuple(),
dims=("z", "y", "x")):
"""Help old reader tests that still use numpy arrays.
A lot of old reader tests still use numpy arrays and depend on the
"var_name/attr/attr_name" convention established before Satpy used xarray
and dask. While these conventions are still used and should be supported,
readers need to use xarray DataArrays instead.
If possible, new tests should be based on pure DataArray objects instead
of the "var_name/attr/attr_name" style syntax provided by the utility
file handlers.
Args:
file_content (dict): Dictionary of string file keys to fake file data.
attrs (iterable): Series of attributes to copy to DataArray object from
file content dictionary. Defaults to no attributes.
        dims (iterable): Dimension names to use for the resulting DataArrays.
            Dimension names are assigned in reverse order, so an array's last
            ``ndim`` names are used: with the default ``("z", "y", "x")``,
            2D arrays get ``("y", "x")`` and 3D arrays get ``("z", "y", "x")``.
            As a special case, 1D arrays use the second to last name
            (``"y"`` for the default).
"""
for key, val in file_content.items():
da_attrs = {}
for a in attrs:
if key + "/attr/" + a in file_content:
da_attrs[a] = file_content[key + "/attr/" + a]
        if isinstance(val, np.ndarray):
            val = da.from_array(val, chunks=4096)
        if isinstance(val, da.Array):
            # only convert array content; scalar values such as the
            # "var_name/attr/attr_name" entries are left untouched
            if val.ndim == 1:
                da_dims = dims[-2]
            elif val.ndim > 1:
                da_dims = tuple(dims[-val.ndim:])
            else:
                da_dims = None
            file_content[key] = DataArray(val, dims=da_dims, attrs=da_attrs)
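# Example (a minimal sketch): converting an old-style file content dictionary.
# The variable name "brightness_temperature" is illustrative only.
def _example_convert_file_content():
    file_content = {
        "brightness_temperature": np.zeros((5, 10)),
        "brightness_temperature/attr/units": "K",
    }
    convert_file_content_to_data_array(file_content, attrs=("units",))
    data_arr = file_content["brightness_temperature"]
    # the 2D array became a DataArray with the last two dimension names
    assert data_arr.dims == ("y", "x")
    assert data_arr.attrs["units"] == "K"
    return data_arr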
def _filter_datasets(all_ds, names_or_ids):
"""Help filtering DataIDs by name or DataQuery."""
# DataID will match a str to the name
# need to separate them out
str_filter = [ds_name for ds_name in names_or_ids if isinstance(ds_name, str)]
id_filter = [ds_id for ds_id in names_or_ids if not isinstance(ds_id, str)]
for ds_id in all_ds:
if ds_id in id_filter or ds_id["name"] in str_filter:
yield ds_id
def _swath_def_of_data_arrays(rows, cols):
return SwathDefinition(
DataArray(da.zeros((rows, cols)), dims=("y", "x")),
DataArray(da.zeros((rows, cols)), dims=("y", "x")),
)
class FakeModifier(ModifierBase):
"""Act as a modifier that performs different modifications."""
def _handle_res_change(self, datasets, info):
# assume this is used on the 500m version of ds5
info["resolution"] = 250
rep_data_arr = datasets[0]
y_size = rep_data_arr.sizes["y"]
x_size = rep_data_arr.sizes["x"]
data = da.zeros((y_size * 2, x_size * 2))
if isinstance(rep_data_arr.attrs["area"], SwathDefinition):
area = _swath_def_of_data_arrays(y_size * 2, x_size * 2)
info["area"] = area
else:
raise NotImplementedError("'res_change' modifier can't handle "
"AreaDefinition changes yet.")
return data
def __call__(self, datasets, optional_datasets=None, **kwargs):
"""Modify provided data depending on the modifier name and input data."""
if self.attrs["optional_prerequisites"]:
for opt_dep in self.attrs["optional_prerequisites"]:
opt_dep_name = opt_dep if isinstance(opt_dep, str) else opt_dep.get("name", "")
if "NOPE" in opt_dep_name or "fail" in opt_dep_name:
continue
assert optional_datasets is not None
assert len(optional_datasets)
resolution = datasets[0].attrs.get("resolution")
mod_name = self.attrs["modifiers"][-1]
data = datasets[0].data
i = datasets[0].attrs.copy()
if mod_name == "res_change" and resolution is not None:
data = self._handle_res_change(datasets, i)
elif "incomp_areas" in mod_name:
raise IncompatibleAreas(
"Test modifier 'incomp_areas' always raises IncompatibleAreas")
self.apply_modifier_info(datasets[0].attrs, i)
return DataArray(data,
dims=datasets[0].dims,
# coords=datasets[0].coords,
attrs=i)
class FakeCompositor(GenericCompositor):
"""Act as a compositor that produces fake RGB data."""
def __call__(self, projectables, nonprojectables=None, **kwargs):
"""Produce test compositor data depending on modifiers and input data provided."""
if projectables:
projectables = self.match_data_arrays(projectables)
if nonprojectables:
self.match_data_arrays(nonprojectables)
info = self.attrs.copy()
if self.attrs["name"] in ("comp14", "comp26"):
# used as a test when composites update the dataset id with
# information from prereqs
info["resolution"] = 555
if self.attrs["name"] in ("comp24", "comp25"):
# other composites that copy the resolution from inputs
info["resolution"] = projectables[0].attrs.get("resolution")
if len(projectables) != len(self.attrs["prerequisites"]):
raise ValueError("Not enough prerequisite datasets passed")
info.update(kwargs)
if projectables:
info["area"] = projectables[0].attrs["area"]
dim_sizes = projectables[0].sizes
else:
# static_image
dim_sizes = {"y": 4, "x": 5}
return DataArray(data=da.zeros((dim_sizes["y"], dim_sizes["x"], 3)),
attrs=info,
dims=["y", "x", "bands"],
coords={"bands": ["R", "G", "B"]})
class FakeFileHandler(BaseFileHandler):
"""Fake file handler to be used by test readers."""
def __init__(self, filename, filename_info, filetype_info, **kwargs):
"""Initialize file handler and accept all keyword arguments."""
self.kwargs = kwargs
super().__init__(filename, filename_info, filetype_info)
@property
def start_time(self):
"""Get static start time datetime object."""
return FAKE_FILEHANDLER_START
@property
def end_time(self):
"""Get static end time datetime object."""
return FAKE_FILEHANDLER_END
@property
def sensor_names(self):
"""Get sensor name from filetype configuration."""
sensor = self.filetype_info.get("sensor", "fake_sensor")
return {sensor}
def get_dataset(self, data_id: DataID, ds_info: dict):
"""Get fake DataArray for testing."""
if data_id["name"] == "ds9_fail_load":
raise KeyError("Can't load '{}' because it is supposed to "
"fail.".format(data_id["name"]))
attrs = data_id.to_dict()
attrs.update(ds_info)
attrs["sensor"] = self.filetype_info.get("sensor", "fake_sensor")
attrs["platform_name"] = "fake_platform"
attrs["start_time"] = self.start_time
attrs["end_time"] = self.end_time
res = attrs.get("resolution", 250)
rows = cols = {
250: 20,
500: 10,
1000: 5,
}.get(res, 5)
return DataArray(data=da.zeros((rows, cols)),
attrs=attrs,
dims=["y", "x"])
def available_datasets(self, configured_datasets=None):
"""Report YAML datasets available unless 'not_available' is specified during creation."""
not_available_names = self.kwargs.get("not_available", [])
for is_avail, ds_info in (configured_datasets or []):
if is_avail is not None:
# some other file handler said it has this dataset
# we don't know any more information than the previous
# file handler so let's yield early
yield is_avail, ds_info
continue
ft_matches = self.file_type_matches(ds_info["file_type"])
if not ft_matches:
yield None, ds_info
continue
# mimic what happens when a reader "knows" about one variable
# but the files loaded don't have that variable
is_avail = ds_info["name"] not in not_available_names
yield is_avail, ds_info
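# Example (a minimal sketch): how ``available_datasets`` reports availability.
# The file name, file types and dataset names below are illustrative only.
def _example_available_datasets():
    fh = FakeFileHandler("fake.nc", {}, {"file_type": "fake_file"},
                         not_available=["ds2"])
    configured = [
        (None, {"name": "ds1", "file_type": "fake_file"}),
        (None, {"name": "ds2", "file_type": "fake_file"}),
        (None, {"name": "ds3", "file_type": "other_file"}),
    ]
    availability = [avail for avail, _ in fh.available_datasets(configured)]
    # ds1 is available, ds2 was excluded via ``not_available``, and ds3 belongs
    # to a different file type so this handler can't say anything about it
    assert availability == [True, False, None]
    return availability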
class CustomScheduler(object):
"""Scheduler raising an exception if data are computed too many times."""
def __init__(self, max_computes=1):
"""Set starting and maximum compute counts."""
self.max_computes = max_computes
self.total_computes = 0
def __call__(self, dsk, keys, **kwargs):
"""Compute dask task and keep track of number of times we do so."""
import dask
self.total_computes += 1
if self.total_computes > self.max_computes:
raise RuntimeError("Too many dask computations were scheduled: "
"{}".format(self.total_computes))
return dask.get(dsk, keys, **kwargs)
@contextmanager
def assert_maximum_dask_computes(max_computes=1):
"""Context manager to make sure dask computations are not executed more than ``max_computes`` times."""
import dask
with dask.config.set(scheduler=CustomScheduler(max_computes=max_computes)) as new_config:
yield new_config
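# Example (a minimal sketch): guard a block of test code against unwanted dask
# computations; a second ``.compute()`` inside the block would raise RuntimeError.
def _example_assert_maximum_dask_computes():
    arr = da.zeros((5, 5), chunks=-1)
    with assert_maximum_dask_computes(max_computes=1):
        return arr.compute()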
def make_fake_scene(content_dict, daskify=False, area=True,
common_attrs=None):
"""Create a fake Scene.
Create a fake Scene object from fake data. Data are provided in
the ``content_dict`` argument. In ``content_dict``, keys should be
strings or DataID, and values may be either numpy.ndarray
or xarray.DataArray, in either case with exactly two dimensions.
The function will convert each of the numpy.ndarray objects into
an xarray.DataArray and assign those as datasets to a Scene object.
A fake AreaDefinition will be assigned for each array, unless disabled
by passing ``area=False``. When areas are automatically generated,
arrays with the same shape will get the same area.
This function is exclusively intended for testing purposes.
    If regular ndarrays are passed and the keyword argument ``daskify`` is
    True, the data are wrapped in dask arrays before the DataArrays are
    created. If False (the default), plain numpy-backed DataArrays are
    created. The flag has no effect on values that are already
    xarray.DataArray objects.
Args:
content_dict (Mapping): Mapping where keys correspond to objects
accepted by ``Scene.__setitem__``, i.e. strings or DataID,
and values may be either ``numpy.ndarray`` or
``xarray.DataArray``.
daskify (bool): optional, to use dask when converting
``numpy.ndarray`` to ``xarray.DataArray``. No effect when the
values in ``content_dict`` are already ``xarray.DataArray``.
area (bool or BaseDefinition): Can be ``True``, ``False``, or an
instance of ``pyresample.geometry.BaseDefinition`` such as
``AreaDefinition`` or ``SwathDefinition``. If ``True``, which is
the default, automatically generate areas with the name "test-area".
If ``False``, values will not have assigned areas. If an instance
of ``pyresample.geometry.BaseDefinition``, those instances will be
used for all generated fake datasets. Warning: Passing an area as
a string (``area="germ"``) is not supported.
common_attrs (Mapping): optional, additional attributes that will
be added to every dataset in the scene.
Returns:
Scene object with datasets corresponding to content_dict.
"""
if common_attrs is None:
common_attrs = {}
sc = Scene()
for (did, arr) in content_dict.items():
extra_attrs = common_attrs.copy()
if area:
extra_attrs["area"] = _get_fake_scene_area(arr, area)
sc[did] = _get_did_for_fake_scene(area, arr, extra_attrs, daskify)
return sc
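# Example (a minimal sketch): build a two-dataset Scene from plain numpy arrays.
# The dataset names and platform name are illustrative only.
def _example_make_fake_scene():
    scn = make_fake_scene(
        {"ds1": np.arange(25.).reshape((5, 5)),
         "ds2": np.arange(25.).reshape((5, 5))},
        daskify=True,
        common_attrs={"platform_name": "fake_platform"})
    assert scn["ds1"].attrs["platform_name"] == "fake_platform"
    # arrays of the same shape get equal auto-generated areas
    assert scn["ds1"].attrs["area"] == scn["ds2"].attrs["area"]
    return scn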
def _get_fake_scene_area(arr, area):
"""Get area for fake scene. Helper for make_fake_scene."""
if isinstance(area, BaseDefinition):
return area
return create_area_def(
"test-area",
{"proj": "eqc", "lat_ts": 0, "lat_0": 0, "lon_0": 0,
"x_0": 0, "y_0": 0, "ellps": "sphere", "units": "m",
"no_defs": None, "type": "crs"},
units="m",
shape=arr.shape,
resolution=1000,
center=(0, 0))
def _get_did_for_fake_scene(area, arr, extra_attrs, daskify):
"""Add instance to fake scene. Helper for make_fake_scene."""
from satpy.resample import add_crs_xy_coords
if isinstance(arr, DataArray):
new = arr.copy() # don't change attributes of input
new.attrs.update(extra_attrs)
else:
if daskify:
arr = da.from_array(arr)
new = DataArray(
arr,
dims=("y", "x"),
attrs=extra_attrs)
if area:
new = add_crs_xy_coords(new, extra_attrs["area"])
return new
def assert_attrs_equal(attrs, attrs_exp, tolerance=0):
"""Test that attributes are equal.
Walks dictionary recursively. Numerical attributes are compared with
the given relative tolerance.
"""
keys_diff = set(attrs).difference(set(attrs_exp))
assert not keys_diff, "Different set of keys: {}".format(keys_diff)
for key in attrs_exp:
err_msg = "Attribute {} does not match expectation".format(key)
if isinstance(attrs[key], dict):
assert_attrs_equal(attrs[key], attrs_exp[key], tolerance)
else:
try:
np.testing.assert_allclose(
attrs[key],
attrs_exp[key],
rtol=tolerance,
err_msg=err_msg
)
except TypeError:
assert attrs[key] == attrs_exp[key], err_msg
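# Example (a minimal sketch): nested attribute dictionaries compare recursively
# and numbers are compared with the given relative tolerance.
def _example_assert_attrs_equal():
    attrs = {"platform_name": "fake_platform", "calibration": {"gain": 1.0001}}
    attrs_exp = {"platform_name": "fake_platform", "calibration": {"gain": 1.0}}
    assert_attrs_equal(attrs, attrs_exp, tolerance=0.01)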
def assert_dict_array_equality(d1, d2):
"""Check that dicts containing arrays are equal."""
assert set(d1.keys()) == set(d2.keys())
for key, val1 in d1.items():
val2 = d2[key]
compare_func = _compare_numpy_array if isinstance(val1, np.ndarray) else _compare_nonarray
compare_func(val1, val2)
def _compare_numpy_array(val1: np.ndarray, val2: np.ndarray) -> None:
np.testing.assert_array_equal(val1, val2)
assert val1.dtype == val2.dtype
def _compare_nonarray(val1: Any, val2: Any) -> None:
assert val1 == val2
if isinstance(val1, (np.floating, np.integer, np.bool_)):
assert isinstance(val2, np.generic)
assert val1.dtype == val2.dtype
def xfail_skyfield_unstable_numpy2():
"""Determine if skyfield-based tests should be xfail in the unstable numpy 2.x environment."""
try:
import skyfield
# known numpy incompatibility:
from skyfield import timelib # noqa
except ImportError:
skyfield = None
is_unstable_ci = os.environ.get("UNSTABLE", "0") in ("1", "true")
is_np2 = np.__version__.startswith("2.")
return skyfield is None and is_np2 and is_unstable_ci
def xfail_h5py_unstable_numpy2():
"""Determine if h5py-based tests should be xfail in the unstable numpy 2.x environment."""
from packaging import version
try:
import h5py
is_broken_h5py = version.parse(h5py.__version__) <= version.parse("3.10.0")
except ImportError:
is_broken_h5py = True
is_unstable_ci = os.environ.get("UNSTABLE", "0") in ("1", "true")
is_np2 = np.__version__.startswith("2.")
return is_broken_h5py and is_np2 and is_unstable_ci
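# Example (a minimal sketch): the two helpers above are meant to feed
# ``pytest.mark.xfail`` conditions; the reason strings are illustrative only.
def _example_unstable_numpy2_xfail_markers():
    import pytest
    return (
        pytest.mark.xfail(xfail_skyfield_unstable_numpy2(),
                          reason="skyfield does not yet support numpy 2"),
        pytest.mark.xfail(xfail_h5py_unstable_numpy2(),
                          reason="h5py does not yet support numpy 2"),
    )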
def skip_numba_unstable_if_missing():
"""Determine if numba-based tests should be skipped during unstable CI tests.
If numba fails to import it could be because numba is not compatible with
a newer version of numpy. This is very likely to happen in the
    unstable/experimental CI environment. This function returns ``True``,
    meaning numba-based tests should be skipped, when ``numba`` could not
    be imported *and* we're in the unstable environment. We detect that
    environment by checking that the ``UNSTABLE`` environment variable is
    set to ``"1"`` or ``"true"``.
"""
try:
import numba
except ImportError:
numba = None
return numba is None and os.environ.get("UNSTABLE", "0") in ("1", "true")
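# Example (a minimal sketch): use the helper above as a ``pytest.mark.skipif``
# condition; the reason string is illustrative only.
def _example_skip_numba_marker():
    import pytest
    return pytest.mark.skipif(
        skip_numba_unstable_if_missing(),
        reason="numba is not available in the unstable numpy environment")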