Shared secondary axes

3k Views Asked by At

How to set a shared secondary axes using subplots in matplotlib.

Here is the minimal code to display the issue:

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt


def countour_every(ax, every, x_data, y_data,
                   color='black', linestyle='-', marker='o', **kwargs):
    """Draw a line with countour marks at each every points"""
    line, = ax.plot(x_data, y_data, linestyle)
    return line


def prettify_axes(ax, data):
    """Makes my plot pretty"""

    if 'title' in data:
        ax.set_title(data['title'])

    if 'y_lim' in data:
        ax.set_ylim(data['y_lim'])

    if 'x_lim' in data:
        ax.set_xlim(data['x_lim'])

    # Draw legend only if labels were set (HOW TO DO IT?)
    # if ax("has_some_label_set"):
    ax.legend(loc='upper right', prop={'size': 6})

    ax.title.set_fontsize(7)
    ax.xaxis.set_tick_params(labelsize=6)
    ax.xaxis.set_tick_params(direction='in')
    ax.xaxis.label.set_size(7)

    ax.yaxis.set_tick_params(labelsize=6)
    ax.yaxis.set_tick_params(direction='in')
    ax.yaxis.label.set_size(7)


def prettify_second_axes(ax):
    ax.yaxis.set_tick_params(labelsize=7)
    ax.yaxis.set_tick_params(labelcolor='red')
    ax.yaxis.label.set_size(7)


def compare_plot(ax, data):
    line1 = countour_every(ax, 10, **data[0])
    if 'label' in data[0]:
        line1.set_label(data[0]['label'])

    line2 = countour_every(ax, 10, **data[1])
    if 'label' in data[1]:
        line2.set_label(data[1]['label'])

    ax2 = ax.twinx()
    line3 = ax.plot(
            data[0]['x_data'],
            data[0]['y_data']-data[1]['y_data'], '-',
            color='red', alpha=.2, zorder=1)

    prettify_axes(ax, data[0])
    prettify_second_axes(ax2)


d0 = {'x_data': np.arange(0, 10), 'y_data': abs(np.random.random(10)), 'y_lim': [-1, 1], 'color': '.7', 'linestyle': '-', 'label': 'd0'}
d1 = {'x_data': np.arange(0, 10), 'y_data': -abs(np.random.random(10)), 'y_lim': [-1, 1], 'color': '.7', 'linestyle': '--', 'label': 'd1'}
d2 = {'x_data': np.arange(0, 10), 'y_data': np.random.random(10), 'y_lim': [-1, 1], 'color': '.7', 'linestyle': '-.'}
d3 = {'x_data': np.arange(0, 10), 'y_data': -np.ones(10), 'y_lim': [-1, 1], 'color': '.7', 'linestyle': '-.'}

fig, axes = plt.subplots(2, 2, sharex=True, sharey=True)
fig.set_size_inches(6, 6)

compare_plot(axes[0][0], [d0, d1])
compare_plot(axes[0][1], [d0, d2])
compare_plot(axes[1][0], [d1, d0])
compare_plot(axes[1][1], [d3, d2])

fig.suptitle('A comparison chart')
fig.set_tight_layout({'rect': [0, 0.03, 1, 0.95]})
fig.text(0.5, 0.03, 'Position', ha='center')
fig.text(0.005, 0.5, 'Amplitude', va='center', rotation='vertical')
fig.text(0.975, 0.5, 'Error', color='red', va='center', rotation='vertical')

fig.savefig('demo.png', dpi=300)

That generates the following image

Shared axes issue

We can see that the X axis and the Y axis is correctly shared, but the secondary twin axis, is repeated in all subplots.

Also the secondary axis isn't scaling correctly to fit the data. (that should occurs independently of the principal y axis being limited).

1

There are 1 best solutions below

4
On BEST ANSWER

You will need to share the twin axes manually and also remove the ticklabels

def compare_plot(ax, data):
    # ...
    ax2 = ax.twinx()
    # ...
    return ax2

sax1 = compare_plot(axes[0][0], [d0, d1])
sax2 = compare_plot(axes[0][1], [d0, d2])
sax3 = compare_plot(axes[1][0], [d1, d0])
sax4 = compare_plot(axes[1][1], [d3, d2])

for sax in [sax2, sax3, sax4]:
    sax1.get_shared_y_axes().join(sax1, sax)
sax1.autoscale()
for sax in [sax1,sax3]:
    sax.yaxis.set_tick_params(labelright=False)

enter image description here