Calculating zonal means

[1]:
import intake
import matplotlib.pyplot as plt
import numpy as np

from gridlocator import merge_grid

In this example, we compute the zonal-mean near-surface air temperature. We retrieve the 2 m air temperature (tas) of the NextGEMS simulation dpp0067 using intake-esm.

[2]:
catalog_file = "/work/ka1081/Catalogs/dyamond-nextgems.json"
col = intake.open_esm_datastore(catalog_file)
cat = col.search(
    variable_id="tas",
    project="NextGEMS",
    simulation_id="dpp0067",
)
cat_dict = cat.to_dataset_dict(cdf_kwargs={"chunks": {"time": 1}})

ds = merge_grid(list(cat_dict.values())[0])  # Include the grid information!

--> The keys in the returned dictionary of datasets are constructed as follows:
        'project.institution_id.source_id.experiment_id.simulation_id.realm.frequency.time_reduction.grid_label.level_type'

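The general idea

The general idea behind the zonal average is to average all values that fall into a given latitude bin, i.e., to divide the sum of the values in a bin by the number of cells in that bin. Every cell contributes with the same weight, which approximates an area-weighted mean as long as the grid cells have similar areas. In a first step, we therefore count how many cells fall into each of a set of equidistant latitude bins (clat is given in radians).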
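To make the counting step concrete without the simulation data, the sketch below replaces ds.clat with hypothetical synthetic cell-center latitudes in radians, drawn uniformly in sin(latitude) so that they mimic cells of equal area. np.histogram returns one count per bin and the bins + 1 equidistant bin edges.

```python
import numpy as np

# Hypothetical stand-in for ds.clat: cell-centre latitudes in radians.
# Drawing sin(lat) uniformly mimics equal-area cells on the sphere.
rng = np.random.default_rng(0)
clat = np.arcsin(rng.uniform(-1.0, 1.0, size=10_000))

# Same options as in the notebook: 128 equidistant bins from -pi/2 to pi/2.
hist_opts = dict(bins=128, range=(-np.pi / 2, np.pi / 2))
cells_per_bin, lat_bins = np.histogram(clat, **hist_opts)

print(cells_per_bin.shape, lat_bins.shape)  # (128,) (129,)
print(cells_per_bin.sum())  # every cell falls into exactly one bin
# Equatorial bins hold many more cells than polar bins of the same width.
print(cells_per_bin[64], cells_per_bin[0])
```

Because the bins are equidistant in latitude but not in area, the counts differ strongly between the equator and the poles; dividing by these counts later is what turns per-bin sums into means.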

[3]:
hist_opts = dict(bins=128, range=(-np.pi / 2, np.pi / 2))
cells_per_bin, lat_bins = np.histogram(ds.clat, **hist_opts)

Now comes the trick! Next, we repeat the histogram but assign a weight to each cell. By default, np.histogram weights each data point with one, i.e., it counts the values in each bin. Here, we instead weight each data point with its 2 m temperature, so the histogram returns the sum of all temperatures in each latitude bin.
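The following sketch, using hypothetical toy data, checks that a weighted histogram is nothing more than a per-bin sum: it compares np.histogram(..., weights=...) with an explicit loop over the bins. The loop uses the same bin convention as np.histogram: bins are half-open [left, right), except the last one, which also includes its right edge.

```python
import numpy as np

rng = np.random.default_rng(1)
# Hypothetical toy data: latitudes (radians) and a temperature-like field (K).
clat = np.arcsin(rng.uniform(-1.0, 1.0, size=1_000))
tas = 300.0 - 40.0 * np.sin(clat) ** 2

hist_opts = dict(bins=8, range=(-np.pi / 2, np.pi / 2))

# Weighted histogram: each cell adds its temperature instead of 1.
varsum_per_bin, edges = np.histogram(clat, weights=tas, **hist_opts)

# Reference: explicit loop summing the temperatures in each bin.
reference = np.zeros(len(edges) - 1)
for i, (left, right) in enumerate(zip(edges[:-1], edges[1:])):
    is_last_bin = i == len(edges) - 2
    upper_ok = (clat <= right) if is_last_bin else (clat < right)
    reference[i] = tas[(clat >= left) & upper_ok].sum()

print(np.allclose(varsum_per_bin, reference))  # True
```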

Tip

The np.histogram function is more efficient when passing a number of bins and a range than when passing a sequence of bin edges. In the first case, the function knows that the bins are equidistant and can compute the bin of each value directly; for an arbitrary sequence of bin edges, it has to search for the matching bin instead.
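A small check of the tip, on hypothetical synthetic latitudes: both ways of specifying the same 128 equidistant bins give identical counts, so using bins plus range is a safe drop-in for an explicit np.linspace of edges. The timings printed at the end depend on the machine and are only indicative.

```python
import timeit

import numpy as np

rng = np.random.default_rng(2)
clat = np.arcsin(rng.uniform(-1.0, 1.0, size=100_000))

# Number of bins plus range: numpy knows the bins are equidistant.
counts_fast, edges_fast = np.histogram(clat, bins=128, range=(-np.pi / 2, np.pi / 2))

# The same bins passed as an explicit sequence of edges.
edges = np.linspace(-np.pi / 2, np.pi / 2, 129)
counts_edges, _ = np.histogram(clat, bins=edges)

print(np.array_equal(counts_fast, counts_edges))  # True

# Rough timing comparison (machine dependent).
t_fast = timeit.timeit(
    lambda: np.histogram(clat, bins=128, range=(-np.pi / 2, np.pi / 2)), number=20
)
t_edges = timeit.timeit(lambda: np.histogram(clat, bins=edges), number=20)
print(f"bins + range: {t_fast:.3f} s, explicit edges: {t_edges:.3f} s")
```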

[4]:
varsum_per_bin, _ = np.histogram(
    ds.clat, weights=ds.tas.isel(time=1, height_2=0), **hist_opts
)

The zonal mean is now obtained by dividing the summed temperature in each bin by the number of cells in that bin.
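With 128 bins on a global grid every bin should contain cells, but with finer bins or a regional subset some bins may be empty; a plain division then yields NaN together with a RuntimeWarning. The sketch below (the helper name zonal_mean_from_histograms is hypothetical) combines both histogram steps and uses np.divide with where= so that empty bins become NaN without a warning.

```python
import numpy as np


def zonal_mean_from_histograms(clat, values, **hist_opts):
    """Zonal mean as (sum of values per bin) / (number of cells per bin).

    Bins without any cells are returned as NaN.
    """
    cells_per_bin, edges = np.histogram(clat, **hist_opts)
    varsum_per_bin, _ = np.histogram(clat, weights=values, **hist_opts)
    zonal_mean = np.full(cells_per_bin.shape, np.nan)
    # Only divide where the bin contains at least one cell.
    np.divide(varsum_per_bin, cells_per_bin, out=zonal_mean, where=cells_per_bin > 0)
    return zonal_mean, edges


# Hypothetical toy data: 3 cells, two of which share a latitude bin.
clat = np.array([-0.5, 0.1, 0.2])
tas = np.array([280.0, 300.0, 302.0])
zm, edges = zonal_mean_from_histograms(clat, tas, bins=4, range=(-np.pi / 2, np.pi / 2))
print(zm)  # expected values: NaN, 280, 301, NaN
```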

[5]:
zonal_mean = varsum_per_bin / cells_per_bin

We can check our result by plotting the zonal mean as a function of latitude. To do so, we scale the x-axis with the sine of latitude, so that the width of each latitude bin on the plot is proportional to its area.
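Why the sine: on a sphere, the fraction of the surface area between latitudes φ₁ and φ₂ is (sin φ₂ − sin φ₁)/2. On a sin-scaled axis the width of each bin is therefore proportional to its area. The quick check below uses the notebook's 128 bin edges.

```python
import numpy as np

lat_bins = np.linspace(-np.pi / 2, np.pi / 2, 129)  # same edges as hist_opts

# Fraction of the sphere's surface covered by each latitude band.
area_fraction = 0.5 * (np.sin(lat_bins[1:]) - np.sin(lat_bins[:-1]))

print(area_fraction.sum())  # the bands cover the whole sphere (1 up to rounding)
# Ratio of the largest (equatorial) to the smallest (polar) band, roughly 81.
print(area_fraction.max() / area_fraction.min())
```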

[6]:
fig, ax = plt.subplots()
ax.plot(0.5 * (lat_bins[1:] + lat_bins[:-1]), zonal_mean)
ax.set_ylabel("tas / K")
ax.set_ylim(270, 305)

# Scale the x-axis to account for differences in area with latitude.
ax.set_xscale("function", functions=(lambda d: np.sin(d), lambda d: np.arcsin(d)))
ax.set_xlim(np.deg2rad(-80), np.deg2rad(80))
ax.set_xlabel("latitude")
[Figure: zonal-mean tas (K) versus latitude; x-axis scaled by sin(latitude)]

Multi-dimensional input

So far, we used data at a single time step to illustrate the general idea of calculating zonal means with histograms. However, most real-world data have additional dimensions such as time or height. A straightforward way to calculate zonal means for such data is to flatten it, i.e., to collapse all dimensions into one. This approach, however, loses all information along the discarded axes: data at different heights or times will be mixed.

Fortunately, there are alternatives that calculate zonal means while keeping the dimensional structure of the dataset. Here, we use xr.apply_ufunc, which lifts a function (in this case _compute_varsum) from (NumPy) arrays to (xarray) DataArrays. This lifting requires describing the dimensions, shapes and data types the function operates on (the core dimensions). xarray then applies its usual looping and broadcasting rules over all other dimensions. Any necessary looping may be carried out in parallel (e.g. using dask).
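Conceptually, vectorize=True makes xarray call the function once for every combination of the non-core dimensions. The NumPy-only sketch below mimics this idea for a hypothetical (time, cell) array with an explicit loop over np.ndindex; it is an illustration, not xarray's actual implementation, and the helper name histogram_sums is made up. It also shows that flattening all dimensions first mixes the time steps into a single profile.

```python
import numpy as np


def histogram_sums(values, clat, **hist_opts):
    """Per-bin sums over the last (cell) axis of `values`, for every other index."""
    nbins = len(np.histogram_bin_edges(clat, **hist_opts)) - 1
    out = np.empty(values.shape[:-1] + (nbins,))
    # Explicit loop over all non-core indices, similar in spirit to vectorize=True.
    for idx in np.ndindex(values.shape[:-1]):
        out[idx], _ = np.histogram(clat, weights=values[idx], **hist_opts)
    return out


rng = np.random.default_rng(3)
clat = np.arcsin(rng.uniform(-1.0, 1.0, size=2_000))
# Hypothetical (time, cell) field: time step 1 is 10 K warmer everywhere.
tas = 280.0 + 10.0 * np.arange(2)[:, None] + 20.0 * np.cos(clat)[None, :]

hist_opts = dict(bins=16, range=(-np.pi / 2, np.pi / 2))
counts, _ = np.histogram(clat, **hist_opts)

# Keeping the time dimension: one zonal mean per time step, shape (2, 16).
zonal_means = histogram_sums(tas, clat, **hist_opts) / counts

# Flattening instead mixes both time steps into one averaged profile.
flat_sums, _ = np.histogram(np.tile(clat, 2), weights=tas.ravel(), **hist_opts)
flat_mean = flat_sums / (2 * counts)
```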

[10]:
import xarray as xr


def calc_zonal_mean(variable, **kwargs):
    """Compute a zonal mean (along `clat`) for multi-dimensional input.

    Keyword arguments (e.g. `bins` and `range`) are passed on to `np.histogram`.
    """
    clat = np.asarray(variable.clat)  # cell latitudes in radians, loaded once
    # The number of cells per bin is the same for every time step.
    counts_per_bin, bin_edges = np.histogram(clat, **kwargs)

    def _compute_varsum(var, **kwargs):
        """Sum the values of a single 1D field (over `cell`) per latitude bin."""
        varsum_per_bin, _ = np.histogram(clat, weights=var, **kwargs)
        return varsum_per_bin

    # For more information see:
    # https://docs.xarray.dev/en/stable/generated/xarray.apply_ufunc.html
    varsum = xr.apply_ufunc(
        _compute_varsum,  # function to map
        variable,  # variable to loop over
        kwargs=kwargs,  # keyword arguments passed to the function
        input_core_dims=[["cell"]],  # dimension consumed by the function
        dask="parallelized",  # evaluate lazily, chunk by chunk
        vectorize=True,  # loop over all other dimensions (e.g. time)
        # Description of the output
        output_core_dims=[("lat",)],
        dask_gufunc_kwargs={
            "output_sizes": {"lat": len(bin_edges) - 1},
        },
        output_dtypes=["f8"],
    )

    return varsum / counts_per_bin, bin_edges
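Inside _compute_varsum, np.histogram determines the latitude bin of every cell again for every time step, although clat does not change. A possible alternative, not part of the original workflow and sketched here with NumPy only, computes each cell's bin index once and then sums all time steps in a single np.bincount call by offsetting the indices of each row. The helper names bin_indices and binned_sums and the synthetic data are hypothetical; the bin convention is chosen to match np.histogram.

```python
import numpy as np


def bin_indices(clat, **hist_opts):
    """Bin index of each cell (np.histogram convention); -1 if outside the range."""
    edges = np.histogram_bin_edges(clat, **hist_opts)
    nbins = len(edges) - 1
    idx = np.searchsorted(edges, clat, side="right") - 1  # below range -> -1
    idx[clat == edges[-1]] = nbins - 1  # the last bin is closed on the right
    idx[idx >= nbins] = -1  # above range
    return idx, edges


def binned_sums(values, idx, nbins):
    """Per-bin sums of `values` (shape (..., ncell)) for every leading index."""
    valid = idx >= 0
    rows = values[..., valid].reshape(-1, valid.sum())
    nrows = rows.shape[0]
    # Offset the bin index of each row so that all rows share one bincount call.
    flat_idx = (np.arange(nrows)[:, None] * nbins + idx[valid]).ravel()
    sums = np.bincount(flat_idx, weights=rows.ravel(), minlength=nrows * nbins)
    return sums.reshape(values.shape[:-1] + (nbins,))


rng = np.random.default_rng(4)
clat = np.arcsin(rng.uniform(-1.0, 1.0, size=5_000))
# Hypothetical (time, cell) field with three time steps.
tas = 250.0 + 50.0 * np.cos(clat) + rng.normal(0.0, 1.0, size=(3, 5_000))
hist_opts = dict(bins=32, range=(-np.pi / 2, np.pi / 2))

idx, edges = bin_indices(clat, **hist_opts)
nbins = len(edges) - 1
sums = binned_sums(tas, idx, nbins)
counts = np.bincount(idx[idx >= 0], minlength=nbins)
zonal_means = sums / counts

# Cross-check against the np.histogram approach for the first time step.
reference, _ = np.histogram(clat, weights=tas[0], **hist_opts)
print(np.allclose(sums[0], reference))  # True
```

The design trade-off: the bin lookup is done once, at the cost of holding the offset indices for all rows in memory at the same time.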

Using this function, we can calculate zonal means for several time steps at once (here every 48th time step, starting at index 24) while keeping the time dimension.

[11]:
zonal_means, lat_bins = calc_zonal_mean(
    ds.tas.isel(time=slice(24, None, 48), height_2=0), **hist_opts
)

We can now plot the zonal means of the individual time steps as well as their average over time (thick black line).
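Because the number of cells per bin does not depend on time, the zonal mean is linear in the data: averaging the zonal means over time gives the same result as first averaging the field over time and then computing a single zonal mean. The thick black line could therefore equivalently be obtained from one histogram of the time-averaged field. A quick NumPy check with hypothetical data:

```python
import numpy as np

rng = np.random.default_rng(5)
clat = np.arcsin(rng.uniform(-1.0, 1.0, size=3_000))
# Hypothetical (time, cell) field with four time steps.
tas = 270.0 + 30.0 * np.cos(clat) + rng.normal(0.0, 2.0, size=(4, 3_000))
hist_opts = dict(bins=16, range=(-np.pi / 2, np.pi / 2))

counts, _ = np.histogram(clat, **hist_opts)


def zonal_mean(field):
    """Zonal mean of a single 1D field over cells."""
    sums, _ = np.histogram(clat, weights=field, **hist_opts)
    return sums / counts


mean_of_zonal_means = np.mean([zonal_mean(tas[t]) for t in range(tas.shape[0])], axis=0)
zonal_mean_of_time_mean = zonal_mean(tas.mean(axis=0))

print(np.allclose(mean_of_zonal_means, zonal_mean_of_time_mean))  # True
```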

[12]:
fig, ax = plt.subplots()
for zonal_mean in zonal_means:
    ax.plot(0.5 * (lat_bins[1:] + lat_bins[:-1]), zonal_mean)
ax.plot(0.5 * (lat_bins[1:] + lat_bins[:-1]), zonal_means.mean("time"), lw=3, c="k")
ax.set_ylabel("tas / K")
ax.set_ylim(270, 305)

# Scale the x-axis to account for differences in area with latitude.
ax.set_xscale("function", functions=(lambda d: np.sin(d), lambda d: np.arcsin(d)))
ax.set_xlim(np.deg2rad(-80), np.deg2rad(80))
ax.set_xlabel("latitude")
[Figure: zonal-mean tas (K) versus latitude for the selected time steps, with their time average as a thick black line; x-axis scaled by sin(latitude)]

Tip

The concept of using histograms to compute zonal means can also be generalized to other dimensions, e.g., to compute meridional means or distributions in temperature or humidity space.
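A minimal sketch of these generalizations with hypothetical synthetic data: binning by longitude (clon) instead of latitude yields a meridional mean, and binning by the temperature itself yields a distribution in temperature space. Weighting the same temperature bins with a second, made-up humidity field gives the mean humidity per temperature bin. Empty bins produce NaN, which is silenced here with np.errstate.

```python
import numpy as np

rng = np.random.default_rng(6)
ncell = 5_000
# Hypothetical cell-centre coordinates in radians and two synthetic fields.
clat = np.arcsin(rng.uniform(-1.0, 1.0, size=ncell))
clon = rng.uniform(-np.pi, np.pi, size=ncell)
tas = 250.0 + 50.0 * np.cos(clat)  # K
hus = 0.02 * np.cos(clat) ** 4  # made-up specific humidity, kg/kg

# Meridional mean: the same trick as before, but binned by longitude.
lon_opts = dict(bins=36, range=(-np.pi, np.pi))
cells_per_lon, lon_edges = np.histogram(clon, **lon_opts)
tas_sum_per_lon, _ = np.histogram(clon, weights=tas, **lon_opts)
meridional_mean = tas_sum_per_lon / cells_per_lon

# Distribution in temperature space and mean humidity per temperature bin.
t_opts = dict(bins=25, range=(250.0, 300.0))
cells_per_t, t_edges = np.histogram(tas, **t_opts)
hus_sum_per_t, _ = np.histogram(tas, weights=hus, **t_opts)
with np.errstate(invalid="ignore"):  # empty temperature bins give NaN
    hus_mean_per_t = hus_sum_per_t / cells_per_t
```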