Calculating zonal means

import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

In this example, we compute the zonal-mean 2 m air temperature (tas) on the ICON native grid. We retrieve the data from the EERIE control simulation.

ds = xr.open_dataset(
    "https://eerie.cloud.dkrz.de/datasets/icon-esm-er.eerie-control-1950.v20240618.atmos.native.2d_daily_mean/kerchunk",
    engine="zarr",
    chunks={},
).squeeze()
ds
<xarray.Dataset> Size: 16TB
Dimensions:             (time: 36525, ncells: 5242880)
Coordinates:
  * time                (time) datetime64[ns] 292kB 1991-01-01T23:59:59 ... 2...
    cell_sea_land_mask  (ncells) int32 21MB dask.array<chunksize=(5242880,), meta=np.ndarray>
    height              float64 8B 2.0
    height_2            float64 8B 10.0
    height_3            float64 8B 90.0
    lat                 (ncells) float64 42MB dask.array<chunksize=(5242880,), meta=np.ndarray>
    lon                 (ncells) float64 42MB dask.array<chunksize=(5242880,), meta=np.ndarray>
Dimensions without coordinates: ncells
Data variables: (12/21)
    clt                 (time, ncells) float32 766GB dask.array<chunksize=(1, 5242880), meta=np.ndarray>
    evspsbl             (time, ncells) float32 766GB dask.array<chunksize=(1, 5242880), meta=np.ndarray>
    hfls                (time, ncells) float32 766GB dask.array<chunksize=(1, 5242880), meta=np.ndarray>
    hfss                (time, ncells) float32 766GB dask.array<chunksize=(1, 5242880), meta=np.ndarray>
    hur                 (time, ncells) float32 766GB dask.array<chunksize=(1, 5242880), meta=np.ndarray>
    hus2m               (time, ncells) float32 766GB dask.array<chunksize=(1, 5242880), meta=np.ndarray>
    ...                  ...
    rsus                (time, ncells) float32 766GB dask.array<chunksize=(1, 5242880), meta=np.ndarray>
    sfcwind             (time, ncells) float32 766GB dask.array<chunksize=(1, 5242880), meta=np.ndarray>
    tas                 (time, ncells) float32 766GB dask.array<chunksize=(1, 5242880), meta=np.ndarray>
    ts                  (time, ncells) float32 766GB dask.array<chunksize=(1, 5242880), meta=np.ndarray>
    uas                 (time, ncells) float32 766GB dask.array<chunksize=(1, 5242880), meta=np.ndarray>
    vas                 (time, ncells) float32 766GB dask.array<chunksize=(1, 5242880), meta=np.ndarray>
Attributes: (12/31)
    Conventions:           CF-1.7 CMIP-6.2
    activity_id:           EERIE
    data_specs_version:    01.00.32
    forcing_index:         1
    initialization_index:  1
    license:               EERIE model data produced by MPI-M is licensed und...
    ...                    ...
    parent_activity_id:    EERIE
    sub_experiment_id:     none
    experiment:            coupled control with fixed 1950's forcing (HighRes...
    source:                ICON-ESM-ER (2023): \naerosol: none, prescribed MA...
    institution:           Max Planck Institute for Meteorology, Hamburg 2014...
    sub_experiment:        none

The general idea

The general idea behind the zonal average is to calculate the average of all values that fall into a given latitude bin. Therefore, in a first step, we count how many cells fall into each of a set of equidistant latitude bins.

hist_opts = dict(bins=128, range=(-90, 90))
cells_per_bin, lat_bins = np.histogram(ds.lat, **hist_opts)

Now comes the trick! In the next step, we repeat the histogram but assign a weight to each cell. By default, np.histogram weights each data point with one, i.e., it counts the values in each bin. Here, we weight each data point with its 2 m temperature. The histogram then returns the sum of all temperatures in each latitude bin.

Tip

The np.histogram function is more efficient when passing a range and a number of bins. This is because, when constructing the bins internally, the function can assume equidistant bin sizes. This is not the case when passing a sequence of bins directly.

varsum_per_bin, _ = np.histogram(
    ds.lat, weights=ds.tas.isel(time=1), **hist_opts
)

The zonal mean can now be computed by dividing the summed temperature in each bin by the number of cells in that bin.

zonal_mean = varsum_per_bin / cells_per_bin
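
The two-histogram approach can be tried offline on synthetic data. The following is a minimal sketch, not part of the original example: the random latitudes and the idealized temperature profile are assumptions made purely for illustration, and the explicit per-bin mean built with np.digitize serves only as a cross-check.

```python
import numpy as np

rng = np.random.default_rng(0)
ncells = 5_000

# Synthetic stand-ins for ds.lat and ds.tas (assumed values, for illustration only).
lat = rng.uniform(-90, 90, ncells)
tas = 300.0 - 40.0 * np.sin(np.deg2rad(lat)) ** 2 + rng.normal(0, 1, ncells)

hist_opts = dict(bins=18, range=(-90, 90))

# Step 1: count the cells in each latitude bin.
cells_per_bin, lat_bins = np.histogram(lat, **hist_opts)

# Step 2: sum the temperatures in each latitude bin by using them as weights.
varsum_per_bin, _ = np.histogram(lat, weights=tas, **hist_opts)

# Step 3: the mean is the sum divided by the count.
zonal_mean = varsum_per_bin / cells_per_bin

# Cross-check: assign each cell to a bin and average explicitly.
bin_index = np.digitize(lat, lat_bins[1:-1])
reference = np.array(
    [tas[bin_index == i].mean() for i in range(hist_opts["bins"])]
)
print(np.allclose(zonal_mean, reference))
```

Note that a bin without any cells would lead to a division by zero; with the bin width chosen here, every bin is populated.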

We can check our result by plotting the zonal mean as a function of the latitude bins. While doing so, we scale the x-axis by the sine of the latitude so that the width of each latitude bin on the plot is proportional to the area it covers.

fig, ax = plt.subplots()
ax.plot(0.5 * (lat_bins[1:] + lat_bins[:-1]), zonal_mean)
ax.set_ylabel("tas / K")
ax.set_ylim(270, 305)

# Scale the x-axis to account for differences in area with latitude.
ax.set_xscale(
    "function",
    functions=(
        lambda d: np.sin(np.deg2rad(d)),  # forward: degrees -> sin(latitude)
        lambda d: np.rad2deg(np.arcsin(d)),  # inverse: sin(latitude) -> degrees
    ),
)
ax.set_xlim(-80, 80)
ax.set_xlabel("latitude")
Text(0.5, 0, 'latitude')
../_images/af57c1dbcba29cc1fff0764bb21bae4da16a64a477f9792c8e2e6573a3db471e.png
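
Matplotlib's "function" scale expects a forward transform and its inverse. The following sketch checks that a sine-of-latitude transform and its inverse round-trip correctly and that equal steps on the transformed axis correspond to equal areas on the sphere. The function names forward and inverse are hypothetical and chosen for illustration.

```python
import numpy as np


def forward(lat_deg):
    """Map latitude in degrees to sin(latitude)."""
    return np.sin(np.deg2rad(lat_deg))


def inverse(sin_lat):
    """Map sin(latitude) back to latitude in degrees."""
    return np.rad2deg(np.arcsin(sin_lat))


# The two functions must undo each other.
lat = np.linspace(-80, 80, 9)
roundtrip = inverse(forward(lat))
print(np.allclose(roundtrip, lat))

# The area of a latitude band on the unit sphere is 2 * pi * (sin(lat2) - sin(lat1)),
# so bands of equal width in sin(latitude) cover equal areas.
edges_deg = inverse(np.linspace(-1, 1, 5))
band_area = 2 * np.pi * np.diff(forward(edges_deg))
print(band_area)
```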

Note

The concept of using histograms to compute zonal means can also be generalized to other dimensions, e.g., to compute a meridional mean or distributions in temperature or humidity space.
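
As an illustration of binning in temperature space, the following sketch computes the mean humidity in equidistant temperature bins. All data are synthetic: the exponential relation between hus and tas and the noise level are assumptions made only to obtain a plausible-looking signal.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 20_000

# Synthetic temperature (K) and specific humidity (kg/kg); values are illustrative only.
tas = rng.uniform(250.0, 310.0, n)
hus = 0.001 * np.exp(0.06 * (tas - 273.15)) + rng.normal(0, 1e-4, n)

# Bin by temperature instead of latitude.
opts = dict(bins=12, range=(250.0, 310.0))
count, t_edges = np.histogram(tas, **opts)
hus_sum, _ = np.histogram(tas, weights=hus, **opts)

# Mean humidity in each temperature bin.
hus_mean = hus_sum / count
t_centers = 0.5 * (t_edges[1:] + t_edges[:-1])
```

The only change compared to the zonal mean is the quantity passed as the first argument to np.histogram.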

Multi-dimensional input

Until now, we have used data at a single time step to illustrate the general idea of calculating zonal means with histograms. However, most real-world data has several other dimensions, such as time or height. A straightforward way to calculate zonal means for this kind of data is to unravel it, i.e., to flatten all dimensions into one. This approach, however, loses all information about the discarded axes; data at different heights or times will be mixed.

Fortunately, there are alternatives that allow us to calculate zonal means while maintaining the dimensional structure of our dataset. We achieve this by using xr.apply_ufunc, which lifts a function (in this case _compute_varsum) from (numpy) arrays to (xarray) DataArrays. This lifting into the world of DataArrays involves describing the dimensions, shapes, and data types that the function cares about. Afterwards, xarray applies the usual looping and broadcasting rules over the dimensions the function does not care about. Any necessary looping may then be carried out in parallel (e.g., using dask).

def calc_zonal_mean(variable, **hist_opts):
    """Compute a zonal-mean (along `lat`) for multi-dimensional input."""
    counts_per_bin, bin_edges = np.histogram(variable.lat, **hist_opts)

    def _compute_varsum(var, **kwargs):
        """Helper function to compute histogram for a single timestep."""
        varsum_per_bin, _ = np.histogram(variable.lat, weights=var, **kwargs)
        return varsum_per_bin

    # For more information see:
    # https://docs.xarray.dev/en/stable/generated/xarray.apply_ufunc.html
    varsum = xr.apply_ufunc(
        _compute_varsum,  # function to map
        variable,  # variables to loop over
        kwargs=hist_opts,  # keyword arguments passed to the function
        input_core_dims=[["ncells"]],  # dimensions that should not be kept
        # Description of the output dataset
        dask="parallelized",
        vectorize=True,
        output_core_dims=[("lat_bins",)],
        dask_gufunc_kwargs={
            "output_sizes": {"lat_bins": hist_opts["bins"]},
        },
        output_dtypes=["f8"],
    )

    return varsum / counts_per_bin, bin_edges
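
The same apply_ufunc pattern can be tested on a small in-memory array before running it on the full dataset. The sketch below is an illustration under stated assumptions: the data are synthetic and numpy-backed, so the dask-related arguments are left out, and the helper name _varsum is hypothetical.

```python
import numpy as np
import xarray as xr

rng = np.random.default_rng(0)
ntime, ncells, nbins = 3, 1000, 6

# Synthetic (time, ncells) field with a latitude coordinate along ncells.
lat = rng.uniform(-90, 90, ncells)
data = xr.DataArray(
    rng.normal(280.0, 10.0, (ntime, ncells)),
    dims=("time", "ncells"),
    coords={"lat": ("ncells", lat)},
)

hist_opts = dict(bins=nbins, range=(-90, 90))
counts, edges = np.histogram(lat, **hist_opts)


def _varsum(values):
    """Sum the values in each latitude bin for a single 1D slice."""
    total, _ = np.histogram(lat, weights=values, **hist_opts)
    return total


varsum = xr.apply_ufunc(
    _varsum,
    data,
    input_core_dims=[["ncells"]],  # consumed by the function
    output_core_dims=[["lat_bins"]],  # created by the function
    vectorize=True,  # loop over all remaining dimensions (here: time)
)
zonal = varsum / counts

# Cross-check against an explicit loop over time.
expected = np.stack(
    [np.histogram(lat, weights=data.values[t], **hist_opts)[0] / counts
     for t in range(ntime)]
)
print(zonal.dims, np.allclose(zonal.values, expected))
```

The time dimension is preserved, while ncells is replaced by the new lat_bins dimension.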

Using calc_zonal_mean, we can calculate a separate zonal mean for each time step. Here, we select every 30th of the first 365 daily time steps.

zonal_means, lat_bins = calc_zonal_mean(
    ds.tas.isel(time=slice(0, 365, 30)), **hist_opts
)

We can now either plot the zonal means for individual timesteps or the whole dataset.

Warning

When the input is backed by dask, as it is here, xr.apply_ufunc() returns a lazy data array, meaning that the data is only loaded when requested. Computing a histogram for the entire time period thus requires loading all the data.

fig, ax = plt.subplots()
for zonal_mean in zonal_means:
    ax.plot(0.5 * (lat_bins[1:] + lat_bins[:-1]), zonal_mean)

ax.plot(0.5 * (lat_bins[1:] + lat_bins[:-1]), zonal_means.mean("time"), lw=3, c="k")
ax.set_ylabel("tas / K")
ax.set_ylim(270, 305)

# Scale the x-axis to account for differences in area with latitude.
ax.set_xscale(
    "function",
    functions=(
        lambda d: np.sin(np.deg2rad(d)),  # forward: degrees -> sin(latitude)
        lambda d: np.rad2deg(np.arcsin(d)),  # inverse: sin(latitude) -> degrees
    ),
)
ax.set_xlim(-80, 80)
ax.set_xlabel("latitude")
Text(0.5, 0, 'latitude')
../_images/7feaeabb6b0491f8ce23bd518c5cd1253be07e6d383b220415d6c28d9a84636c.png