secondary_xaxis with global variable in subplot

I'm trying to make a figure that shows Doppler velocities for different spectra, but the script does not seem to like the fact that I am changing a global variable. Is there a way around this? Basically, it only plots the secondary axis correctly for the latest value of the global variable; see the figure below, where the top one does not even have a 0. I guess that it retroactively changes the previous plots somehow.

The reason there is a global is that I could not find a way to pass that value to the function without crashing the secondary_xaxis function.

[Figure: the four subplots produced by the example below]

Minimal working example:

import matplotlib.pyplot as plt
import numpy as np

def doppler(wavelengths):
    c = 299792.458 # speed of light in km/s
    lambda_0 = linecore # central wavelength in Angstrom
    doppler_shifts = c * ((wavelengths-lambda_0) / lambda_0)
    return doppler_shifts

def idoppler(doppler_shifts):
    c = 299792.458 # speed of light in km/s
    lambda_0 = linecore # central wavelength in Angstrom
    # inverse of doppler(): must return wavelengths, so no offset is subtracted
    wavelengths = lambda_0 * (1 + doppler_shifts / c)
    return wavelengths

global linecore

plt.subplot(221)
plt.plot(np.linspace(-1,1,10)+6000, np.random.random([10]))
linecore = 6000
ax1 = plt.gca()  # Get the current axis (i.e., the one just created)
ax1a = ax1.secondary_xaxis('top', functions=(doppler, idoppler))
ax1a.set_xticks([-50,0,50])

plt.subplot(222)
plt.plot(np.linspace(-1,1,10)+6000, np.random.random([10]))
linecore = 6000
ax2 = plt.gca()  # Get the current axis (i.e., the one just created)
ax2a = ax2.secondary_xaxis('top', functions=(doppler, idoppler))
ax2a.set_xticks([-50,0,50])

plt.subplot(223)
plt.plot(np.linspace(-1,1,10)+8000, np.random.random([10]))
linecore = 8000
ax3 = plt.gca()  # Get the current axis (i.e., the one just created)
ax3a = ax3.secondary_xaxis('top', functions=(doppler, idoppler))
ax3a.set_xticks([-50,0,50])

plt.subplot(224)
plt.plot(np.linspace(-1,1,10)+8000, np.random.random([10]))
linecore = 8000
ax4 = plt.gca()  # Get the current axis (i.e., the one just created)
ax4a = ax4.secondary_xaxis('top', functions=(doppler, idoppler))
ax4a.set_xticks([-50,0,50])

plt.tight_layout()
plt.show()

There is 1 best solution below

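You should use currying (strictly speaking, partial application): bind the line core into each function when you create the secondary axis instead of reading it from a global. A global is looked up each time the function is called, and matplotlib still calls these functions when the figure is drawn, so every secondary axis ends up using the last value assigned to linecore.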
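To see the effect in isolation, here is a minimal sketch without matplotlib. The names `offset`, `make_offset` and `bound_6000` are made up for illustration; only `linecore` comes from the question.

```python
linecore = 6000

def offset(wavelength):
    # linecore is looked up when offset() is called, not when it is defined
    return wavelength - linecore

def make_offset(lambda0):
    # lambda0 is fixed for each call of make_offset()
    def offset_bound(wavelength):
        return wavelength - lambda0
    return offset_bound

bound_6000 = make_offset(linecore)  # captures 6000 right now
linecore = 8000                     # reassigned later, as in the question

late = offset(6000)        # reads the current global, i.e. 8000
early = bound_6000(6000)   # still uses the captured 6000
print(late, early)         # prints: -2000 0
```

The function that reads the global follows every later reassignment, while the one built with a bound value does not.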

import matplotlib.pyplot as plt
import numpy as np

C = 299792.458 # speed of light in km/s

def doppler_shifts(wavelengths, λ0):
    # wavelengths (Angstrom) -> Doppler velocities (km/s) relative to λ0
    return C*wavelengths/λ0 - C

def wavelengths(doppler_shifts, λ0):
    # inverse: Doppler velocities (km/s) -> wavelengths (Angstrom)
    return λ0*doppler_shifts/C + λ0

y = iter([np.random.random(10) for _ in range(4)])

fig, ax_2D = plt.subplots(2,2, layout='constrained')

for ax_row, λ0 in zip(ax_2D, (6000, 8000)):
    wl = λ0+np.linspace(-1, 1, 10)
    # here comes the currying
    # note we must trick lambda, otherwise late binding of λ0 is a problem
    f0 = lambda wl, λ0=λ0: doppler_shifts(wl, λ0)
    f1 = lambda ds, λ0=λ0: wavelengths(ds, λ0)
    for ax1 in ax_row:
        ax1.plot(wl, next(y))
        ax2 = ax1.secondary_xaxis('top', functions=(f0, f1))
        ax1.set_xlabel('Wavelengths')
        ax2.set_xlabel('Doppler Shifts')

plt.show()
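
If you would rather avoid the default-argument trick, the same binding can be done with `functools.partial`. This is only a sketch of an alternative, not part of the code above. The helper name `make_doppler_pair` and the function names are my own, and the formulas are the ones from the question.

```python
from functools import partial

C = 299792.458  # speed of light in km/s

def doppler_shift(wavelength, lambda0):
    """Wavelength (Angstrom) -> Doppler velocity (km/s) relative to lambda0."""
    return C * (wavelength - lambda0) / lambda0

def wavelength_from_shift(shift, lambda0):
    """Doppler velocity (km/s) -> wavelength (Angstrom); inverse of doppler_shift."""
    return lambda0 * (1 + shift / C)

def make_doppler_pair(lambda0):
    """Return (forward, inverse) with lambda0 bound now, not at draw time."""
    return (partial(doppler_shift, lambda0=lambda0),
            partial(wavelength_from_shift, lambda0=lambda0))

# Each pair keeps its own line core, whatever is created afterwards.
pair_6000 = make_doppler_pair(6000.0)
pair_8000 = make_doppler_pair(8000.0)
```

Each subplot would then get its own pair, e.g. `ax.secondary_xaxis('top', functions=make_doppler_pair(6000))`. The functions use plain arithmetic only, so they accept NumPy arrays as well as floats.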