"""Module for LSA criteria aggregation."""
from __future__ import annotations
import xarray as xr
__all__ = ["aggregate"]
def _agg_weights(ds: xr.Dataset, variables: list[str], weights: list[int | float] | None = None) -> xr.DataArray:
    """
    Build the per-variable weight vector used by the aggregation methods.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset being aggregated. Accepted for signature compatibility; the
        body does not read it.
    variables : list[str]
        Names that label the entries of the weight vector.
    weights : list[int | float], optional
        One weight per entry of ``variables``. When None, every variable
        receives a weight of 1.0.

    Returns
    -------
    xr.DataArray
        One-dimensional array along the ``"variable"`` dimension whose
        coordinate holds ``variables``.

    Raises
    ------
    ValueError
        If ``weights`` and ``variables`` differ in length.
    """
    # Fall back to uniform weights when the caller supplied none.
    values = [1.0] * len(variables) if weights is None else weights
    if len(variables) != len(values):
        raise ValueError("Length of 'weights' must match length of 'variables'.")
    return xr.DataArray(values, dims=["variable"], coords={"variable": variables})
def _add_agg_attrs(
da: xr.DataArray | xr.Dataset, method: str, variables: list[str], weights: list[int | float] | None = None
) -> xr.DataArray | xr.Dataset:
"""Add aggregation method attributes to the DataArray or Dataset."""
if method in ["wmean", "wgmean"] and weights is not None:
desc_vars = ", ".join([f"{v} ({w})" for v, w in zip(variables, weights, strict=False)])
else:
desc_vars = ", ".join(variables)
attrs = {
"median": {
"method": "Median",
"description": f"Median of the variables: {desc_vars}.",
},
"mean": {
"method": "Mean",
"description": f"Arithmetic mean of the variables: {desc_vars}.",
},
"wmean": {"method": "Weighted Mean", "description": f"Weighted mean of the variables: {desc_vars}."},
"gmean": {"method": "Geometric Mean", "description": f"Geometric mean of the variables: {desc_vars}."},
"wgmean": {
"method": "Weighted Geometric Mean",
"description": f"Weighted geometric mean of the variables: {desc_vars}.",
},
"limiting_factor": {
"method": "Limiting Factor",
"description": f"Value of the limiting factor among the variables: {desc_vars}.",
},
"limiting_variable": {"method": "Limiting Factor", "description": f"Limiting variable among: {desc_vars}."},
}
names = {
"median": "median",
"mean": "mean",
"wmean": "weighted_mean",
"gmean": "geometric_mean",
"wgmean": "weighted_geometric_mean",
}
if method == "limfactor":
da["limiting_factor"].attrs.update(attrs.get("limiting_factor", {}))
da["limiting_variable"].attrs.update(attrs.get("limiting_variable", {}))
else:
da.attrs.update(attrs.get(method, {}))
da.name = names.get(method, method)
return da
# Public entry point of this module (listed in ``__all__``).
def aggregate(
    ds: xr.Dataset,
    method: str = "mean",
    variables: list[str] | None = None,
    weights: list[int | float] | None = None,
) -> xr.DataArray | xr.Dataset:
    """
    Aggregate variables of an xarray.Dataset using specified methods.

    Parameters
    ----------
    ds : xr.Dataset
        Input dataset containing the variables to aggregate.
    method : str, optional
        Aggregation method to use. Options include 'mean', 'median', 'wmean' (weighted mean),
        'gmean' (geometric mean), 'wgmean' (weighted geometric mean), and 'limfactor' (limiting factor).
        Default is 'mean'.
    variables : list[str], optional
        List of variable names to aggregate. If None, all variables in the dataset are used.
    weights : list[int | float], optional
        Weights for the variables when using weighted methods ('wmean', 'wgmean').
        If None, equal weights are assumed. Ignored by all other methods.

    Returns
    -------
    xr.DataArray | xr.Dataset
        Aggregated data. All methods except 'limfactor' return an xarray.DataArray reduced over
        the stacked ``"variable"`` dimension. 'limfactor' returns an xarray.Dataset with
        ``limiting_factor`` (the minimum across variables) and ``limiting_variable`` (a boolean
        array along ``"variable"`` marking which variable(s) equal that minimum).

    Raises
    ------
    ValueError
        If ``method`` is not supported, ``variables`` is neither a list nor None, ``variables``
        is an empty list, the dataset has no data variables to aggregate, ``weights`` is neither
        a list nor None for a weighted method, or ``weights`` and ``variables`` differ in length.
    """
    # Validate every argument before touching ``ds`` so that bad input fails
    # fast with a clear message instead of after (or inside) dataset work.
    if method not in ("mean", "median", "wmean", "gmean", "wgmean", "limfactor"):
        raise ValueError(
            f"Invalid method '{method}'. "
            "Supported methods are: 'median', 'mean', 'wmean', 'gmean', 'wgmean', 'limfactor'."
        )
    if variables is not None and not isinstance(variables, list):
        raise ValueError("'variables' must be a list of variable names or None.")
    if variables is not None and not variables:
        raise ValueError("'variables' must contain at least one variable name.")

    if method not in ("wmean", "wgmean"):
        weights = None  # Ignore weights for non-weighted methods
    elif weights is not None and not isinstance(weights, list):
        raise ValueError("'weights' must be a list of numbers or None.")

    if variables is None:
        variables = list(ds.data_vars)
        if not variables:
            # Stacking zero variables cannot produce a meaningful result.
            raise ValueError("Dataset contains no data variables to aggregate.")
    else:
        ds = ds[variables]

    _weights = _agg_weights(ds, variables, weights)
    # Stack the data variables along a new "variable" dimension to reduce over.
    da = ds.to_dataarray()

    if method == "median":
        out = da.median(dim="variable").rename("median")
    elif method in ("mean", "wmean"):
        # With weights=None, _agg_weights yields equal weights (plain mean).
        out = da.weighted(_weights).mean(dim="variable")
    elif method in ("gmean", "wgmean"):
        # (prod_i x_i ** w_i) ** (1 / sum_i w_i); min_count asks for a valid
        # value from every variable before a product is reported.
        out = (da**_weights).prod(dim="variable", min_count=len(variables)) ** (1 / _weights.sum())
    else:  # method == "limfactor" (guaranteed by the validation above)
        limval = da.min(dim="variable").rename("limiting_factor")
        # One boolean layer per variable: True where it equals the minimum.
        limvar = (
            xr.concat([ds[v] == limval for v in ds.data_vars], dim="variable")
            .assign_coords(variable=list(ds.data_vars))
            .rename("limiting_variable")
        )
        out = xr.merge([limval, limvar]).assign_attrs({"method": "Limiting Factor"})

    return _add_agg_attrs(out, method, variables, weights)