Manually set values shown in legend for continuous variable of seaborn/matplotlib scatterplot


Is there a way to manually set the values shown in the legend of a seaborn (or matplotlib) scatterplot when the legend contains a continuous variable (hue)?

For example, in the plot below I would like to show the colors corresponding to the values [0, 1, 2, 3] rather than [1.5, 3, 4.5, 6, 7.5].

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

np.random.seed(123)
x = np.random.randn(500)
y = np.random.randn(500)
z = np.random.exponential(1, 500)

fig, ax = plt.subplots()
hue_norm = (0, 3)
sns.scatterplot(
    x=x,
    y=y,
    hue=z,
    hue_norm=hue_norm,
    palette='coolwarm',
)

ax.grid()
ax.set(xlabel="x", ylabel="y")
ax.legend(title="z")
sns.despine()


JohanC On BEST ANSWER

Seaborn creates its scatterplot a bit differently from matplotlib, which allows the scatterplot to be customized in more ways. For the legend, Seaborn 0.13 employs custom Line2D elements (older Seaborn versions use PathCollections).

The following approach:

  • replaces Seaborn's hue_norm=(0, 3) with an equivalent matplotlib norm
  • creates dummy Line2D elements to serve as legend handles
  • copies all properties (size, edgecolor, ...) of the legend handle created by Seaborn
  • then changes the marker color depending on the norm and colormap
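The last step, mapping a legend value to a color, can be sketched in isolation. This is only an illustration of the norm and colormap lookup; the `legend_colors` dictionary and the Agg backend are assumptions added so the snippet runs on its own.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend so the sketch runs anywhere
import matplotlib.pyplot as plt

norm = plt.Normalize(vmin=0, vmax=3)  # same range as hue_norm=(0, 3)
cmap = plt.get_cmap("coolwarm")       # same colormap as palette='coolwarm'

legend_keys = [0, 1, 2, 3]
# norm maps each key into [0, 1]; cmap turns that fraction into an RGBA tuple
legend_colors = {key: cmap(norm(key)) for key in legend_keys}  # hypothetical helper dict
for key, rgba in legend_colors.items():
    print(key, rgba)
```

Each RGBA tuple is what later gets passed to set_markerfacecolor for the corresponding legend handle.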

The approach might need some tweaks if your scatterplot differs. The code has been tested with Matplotlib 3.8.3 and Seaborn 0.13.2 (and 0.12.2).

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from matplotlib.lines import Line2D

np.random.seed(123)
x = np.random.randn(500)
y = np.random.randn(500)
z = np.random.exponential(1, 500)

fig, ax = plt.subplots()
hue_norm = plt.Normalize(vmin=0, vmax=3)
sns.scatterplot(x=x, y=y, hue=z, hue_norm=hue_norm, palette='coolwarm', ax=ax)

legend_keys = [0, 1, 2, 3]
handles = [Line2D([], []) for _ in legend_keys]
cmap = plt.get_cmap('coolwarm')
# legend handle created by Seaborn, used as a template for the new handles
template = ax.legend_.legend_handles[0]
for h, key in zip(handles, legend_keys):
    if isinstance(template, Line2D):
        # Seaborn 0.13: copy all properties of the Line2D handle
        h.update_from(template)
    else:
        # older Seaborn: the handle is a PathCollection, so copy the properties one by one
        h.set_linestyle('')
        h.set_marker('o')
        h.set_markeredgecolor(template.get_edgecolor())
        h.set_markeredgewidth(template.get_linewidth())
    # color the marker according to the norm and colormap
    h.set_markerfacecolor(cmap(hue_norm(key)))
    h.set_label(f'{key}')
ax.legend(handles=handles, title='z')
sns.despine()
plt.show()

[Figure: seaborn scatterplot with custom legend]

P M On

What you are looking for is legend_elements:

plt.legend(*scatter.legend_elements(num=[1, 2, 3, 4]))

Here is my full code (using only matplotlib):

import numpy as np
import matplotlib.pyplot as plt

np.random.seed(123)
x = np.random.randn(500)
y = np.random.randn(500)
z = np.random.exponential(1, 500)

fig, ax = plt.subplots()
scatter = ax.scatter(x=x, y=y, c=z)

ax.grid()
ax.set(xlabel="x", ylabel="y")
ax.legend(*scatter.legend_elements(num=[1, 2, 3, 4]), title="z")
plt.tight_layout()
plt.show()
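
A variant of the same idea that mirrors the color scale from the question might look as follows. The cmap, vmin and vmax arguments are assumptions chosen to match palette='coolwarm' and hue_norm=(0, 3); they are not part of the answer above, and the Agg backend is only there so the sketch runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend
import numpy as np
import matplotlib.pyplot as plt

np.random.seed(123)
x = np.random.randn(500)
y = np.random.randn(500)
z = np.random.exponential(1, 500)

fig, ax = plt.subplots()
# assumption: clip the color scale to 0..3, like hue_norm=(0, 3) in the question
scatter = ax.scatter(x, y, c=z, cmap="coolwarm", vmin=0, vmax=3)

# passing a list as num requests legend entries at exactly these values
handles, labels = scatter.legend_elements(num=[1, 2, 3])
ax.legend(handles, labels, title="z")
```

In the matplotlib versions I have looked at, legend_elements appears to skip requested values that lie outside the range of the data, so a value such as 0 may not show up when all z values are positive. If that entry is needed, building the handles by hand is probably the safer route.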
James On

If you want to modify the legend after creating the graphic, you can capture the Legend object in a variable and then iterate over the text elements it contains. The code below replaces each label with its position in the legend (0, 1, 2, ...).

leg = ax.legend(title="z")

...

for i, txt_obj in enumerate(leg.get_texts()):
    txt_obj.set_text(i)
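
A self-contained sketch of this relabeling, using a plain matplotlib scatter so it runs without seaborn, could look like the following. The `new_labels` list is hypothetical, and the legend is built with legend_elements only to have something to relabel.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend
import numpy as np
import matplotlib.pyplot as plt

np.random.seed(123)
x = np.random.randn(500)
y = np.random.randn(500)
z = np.random.exponential(1, 500)

fig, ax = plt.subplots()
scatter = ax.scatter(x, y, c=z, cmap="coolwarm", vmin=0, vmax=3)
# capture the Legend object so its text elements can be edited afterwards
leg = ax.legend(*scatter.legend_elements(num=[1, 2, 3]), title="z")

new_labels = ["low", "medium", "high"]  # hypothetical replacement labels
for txt_obj, new_label in zip(leg.get_texts(), new_labels):
    txt_obj.set_text(new_label)
```

Only the label text changes here; the legend markers keep the colors they were created with, so the new labels should still describe those same values.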