xarray discrete scatter plot: specifying legend/colour order


Plotting a discrete DataArray variable in an xarray Dataset with xr.plot.scatter() yields a legend in which the discrete values appear in an unpredictable order, so the colour assigned to each level is unpredictable too. Is it possible to assign a particular colour or legend position to a given discrete value?

A simple reproducible example:

import xarray as xr

# get a predefined dataset
uvz = xr.tutorial.open_dataset("eraint_uvz")

# select a 2-D subset of the data
uvzr = uvz.isel(level=0, month=0, latitude=slice(150, 242),
                longitude=slice(240, 300))

# define a discrete variable based on levels of a continuous variable
uvzr['zone'] = 'A'
uvzr['zone'] = uvzr.zone.where(uvzr.u > 30, other='C')
uvzr['zone'] = uvzr.zone.where(uvzr.u > 10, other='B')

# do the plot
xr.plot.scatter(uvzr, x='longitude', y='latitude', hue='zone')

This produces the following plot

Is there a way to ensure that the legend entries are arranged 'A', 'B', 'C' from top to bottom, say? Or ensure that A is assigned to blue, and B to orange, for example?

I know I can reset the values of the matplotlib color cycler, but for that to be useful I first need to know which order the discrete values will be plotted in.

I'm using xarray v2022.3.0 on Python 3.8.6. With an earlier version of xarray (I think 0.16) the levels were arranged alphabetically.

Best answer

I found an ugly workaround using xarray.Dataset.stack and xarray.Dataset.where(..., drop=True), in case anyone else is stuck with a similar problem.

import numpy as np   # for np.unique, to cycle through the zone values
import matplotlib.pyplot as plt   # to draw the legend

# stack once into a 1-D dataset whose 'location' co-ordinate holds
# every latitude-longitude combination
stacked = uvzr.stack({'location': ('longitude', 'latitude')})

# np.unique returns the values sorted; pass any iterable of your choice
# instead to specify a different order
for value in np.unique(uvzr.zone):
    # keep only the grid points in this zone
    subset = stacked.where(stacked.zone == value, drop=True)

    # latitude and longitude are now levels of the stacked 'location'
    # index, so copy them into data variables under new names that
    # the plotting function can use
    subset['lat'] = subset.latitude
    subset['lon'] = subset.longitude

    # one scatter call per zone, in loop order; no per-call guide
    xr.plot.scatter(subset, x='lon', y='lat', hue='zone',
                    add_guide=False)

# a single legend, with entries in the order the zones were plotted
plt.legend(title='zone')
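
If you also want to pin each zone to a fixed colour, one option is to bypass xr.plot.scatter and call matplotlib directly, once per zone, in the order you choose. The following is only a minimal sketch of that idea, not the approach used above. It builds a small synthetic dataset so that it runs without downloading the tutorial data. The names ds, order and colours, and the grid values, are made up for illustration; with the real data you would use uvzr in place of ds.

```python
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt

# synthetic stand-in for uvzr (assumption: a 2-D latitude/longitude grid)
rng = np.random.default_rng(0)
lat = np.linspace(30.0, 60.0, 20)
lon = np.linspace(0.0, 40.0, 30)
u_vals = rng.uniform(0.0, 40.0, size=(lat.size, lon.size))

# same zone rule as in the question: A if u > 30, B if u <= 10, else C
zone_vals = np.where(u_vals > 30, 'A', np.where(u_vals > 10, 'C', 'B'))

ds = xr.Dataset(
    {
        'u': (('latitude', 'longitude'), u_vals),
        'zone': (('latitude', 'longitude'), zone_vals),
    },
    coords={'latitude': lat, 'longitude': lon},
)

# explicit legend order and colour for each zone (illustrative choices)
order = ['A', 'B', 'C']
colours = {'A': 'tab:blue', 'B': 'tab:orange', 'C': 'tab:green'}

# flatten to a table with one row per grid point
df = ds[['zone']].to_dataframe().reset_index()

fig, ax = plt.subplots()
for value in order:
    sel = df[df['zone'] == value]
    # one scatter call per zone, so legend entries follow `order`
    ax.scatter(sel['longitude'], sel['latitude'],
               color=colours[value], label=value)

ax.set_xlabel('longitude')
ax.set_ylabel('latitude')
ax.legend(title='zone')
```

Because each zone is drawn by its own scatter call, the legend order and the colour mapping are both set by your own code rather than by the order in which the values happen to occur in the data.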