How to add legend when using twinx

174 Views Asked by At

I am trying to plot 2 values with same x axis values (time) but different y axis values. I am trying to add legend to the plot, however, I keep getting 2 legends for the same variable instead of one legend for each variable.

This is the code i have used.

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from datetime import datetime, timedelta

# Assuming you have the time series data in numpy arrays named x_values, y_values, and fire_counts
df=pd.read_csv('FireCountPenchMar2012.csv')
print(df)

xaxis=np.arange(0,248,1)
yaxis=hdwi
yaxis2=df.fire_count
print(yaxis2)

# Step 1: Create the dates for the x-axis based on the starting date (1 March 00:00 UTC)
start_date = datetime(2023, 3, 1, 0, 0, 0)
dates = [start_date + timedelta(hours=3*i) for i in range(len(xaxis))]

plt.figure(figsize=(15,10))

# Step 2: Plot the first time series with dates on the x-axis
plt.plot(dates, yaxis, 'o-g', label='HDWI')

# Step 3: Format the first y-axis and add axis labels and a title
plt.ylabel('HDW')
plt.title('HDWI v/s Fire Counts')

# Step 4: Create a second y-axis for the 'fire_counts' variable
ax2 = plt.gca().twinx()
ax2.scatter(dates, yaxis2, color='red', label='Fire Counts')
ax2.set_ylabel('Fire Counts')

# Step 5: Show the legend for both lines (y_values and fire_counts)
lines, labels = plt.gca().get_legend_handles_labels()
lines2, labels2 = ax2.get_legend_handles_labels()

# Combine the handles and labels for both legends
all_lines = lines + lines2
all_labels = labels + labels2

# Display the combined legend
plt.gca().legend(all_lines, all_labels)

# Step 6: Format the x-axis to display dates at regular intervals (e.g., every 2 days)
date_format = mdates.DateFormatter('%b %d')
plt.gca().xaxis.set_major_locator(mdates.DayLocator(interval=2))
plt.gca().xaxis.set_major_formatter(date_format)

# Step 7: Rotate the x-axis date labels for better readability
plt.xticks(rotation=45)

# Step 8: Adjust the layout and display the plot
plt.grid()
plt.tight_layout()
plt.show()here

The legend in this figure should highlight the issue

The figure as obtained in the output with repeating legend

2

There are 2 best solutions below

0
gboffi On BEST ANSWER

enter image description here

For a detailed explanation, please refer to this answer.

from matplotlib.pyplot import show, subplots

# make the twin Axes
fig, ax0 = subplots()
ax1 = ax0.twinx()

# plot two lines in the 1st Axes, label them
ax0.plot((0, 1, 2, 3, 4), color='g', label='green')
ax0.plot((4, 3, 2, 1, 0), color='k', label='black')

# make a scatter in the 2nd Axes, label it
ax1.scatter(( 1,  2,  3),
            (41, 37, 23), color='r', label='red')

# collect handles and labels in a list of lists [[h0, l0], [h1, l1]]
handles_labels = [ax.get_legend_handles_labels() for ax in (ax0, ax1)]

# but Axes.legend needs h=h0+h1 and l=l0+l1, so
# ① transpose hl → [[h0, h1, [l0, l1]]
hl_transposed = zip(*handles_labels)
# ② sum the sub-lists in the list of handles and in the list of labels
handles, labels = [sum(handles_or_labels, [])
                      for handles_or_labels in hl_transposed]

ax0.legend(handles, labels)
show()
2
Vora Arshit On

The code looks mostly correct, but there's a minor issue in the way you are creating the legend. Instead of combining the handles and labels for both legends, you should create the legend using the handles and labels from the individual plots directly.

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from datetime import datetime, timedelta
import pandas as pd

# Assuming you have the time series data in numpy arrays named x_values, y_values, and fire_counts
df = pd.read_csv('FireCountPenchMar2012.csv')
print(df)

xaxis = np.arange(0, 248, 1)
yaxis = hdwi
yaxis2 = df.fire_count
print(yaxis2)

# Step 1: Create the dates for the x-axis based on the starting date (1 March 00:00 UTC)
start_date = datetime(2023, 3, 1, 0, 0, 0)
dates = [start_date + timedelta(hours=3 * i) for i in range(len(xaxis))]

plt.figure(figsize=(15, 10))

# Step 2: Plot the first time series with dates on the x-axis
line1, = plt.plot(dates, yaxis, 'o-g', label='HDWI')

# Step 3: Format the first y-axis and add axis labels and a title
plt.ylabel('HDW')
plt.title('HDWI v/s Fire Counts')

# Step 4: Create a second y-axis for the 'fire_counts' variable
ax2 = plt.gca().twinx()
line2, = ax2.plot(dates, yaxis2, 'o-r', label='Fire Counts')
ax2.set_ylabel('Fire Counts')

# Step 5: Create legends for both lines separately
legend1 = plt.legend(handles=[line1], loc='upper left')
legend2 = plt.legend(handles=[line2], loc='upper right')

# Step 6: Add the legends to the plot
plt.gca().add_artist(legend1)
plt.gca().add_artist(legend2)

# Step 7: Format the x-axis to display dates at regular intervals (e.g., every 2 days)
date_format = mdates.DateFormatter('%b %d')
plt.gca().xaxis.set_major_locator(mdates.DayLocator(interval=2))
plt.gca().xaxis.set_major_formatter(date_format)

# Step 8: Rotate the x-axis date labels for better readability
plt.xticks(rotation=45)

# Step 9: Adjust the layout and display the plot
plt.grid()
plt.tight_layout()
plt.show()