Plot multiple lines with plotnine

Please consider this code for plotting multiple lines:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

x = [1,2,3]
y = [ [30, 4, 50], [300,400,500], [350,450,550] ]
plt.plot(x, y)

which produces:

[image: plot produced by matplotlib]

I could not figure out how to do the same in plotnine, so I asked a well-known LLM and simplified its rather complex answer to the following:

import numpy as np
import plotnine as p9
import pandas as pd
import matplotlib.pyplot as plt

xx = np.array(x * len(y))
yy = np.ravel(y)
yyy = [val for sublist in y for val in sublist]
gg = [i+1 for i in range(len(y)) for _ in range(len(x))]
data = pd.DataFrame({'x':xx, 'y':yy, 'gg':gg})

plot = (
    p9.ggplot(data, p9.aes(x='x', y='y', color='factor(gg)')) +
    p9.geom_line() 
)
plot.draw(True)

which produces:

[image: plot produced by plotnine]

The two plots differ, and the first one, produced by matplotlib, is the one I want.

So the question: how am I supposed to do this simple plot with plotnine?

2 answers

Answer by Quang Hoang (best answer)

Another reason not to use the infamous LLM. Here's what you can do:

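# y has one row per x value and one column per line, so np.ravel(y)
# lists every line's value at x[0], then at x[1], and so on.
xx = np.repeat(x, len(y[0]))                      # each x once per line
yy = np.ravel(y)
gg = np.tile(np.arange(len(y[0])), len(x))        # line index cycles within each x
data = pd.DataFrame({'x':xx, 'y':yy, 'gg':gg})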

Or with pure pandas:

data = (pd.DataFrame(y, index=x)       # rows indexed by x, one column per line
        .stack()                     # wide to long form for p9
        .rename_axis(['x', 'gg'])    # name the index levels: x value and group
        .reset_index(name='y')       # levels become columns, values are named `y`
       )

Then

plot = (
    p9.ggplot(data, p9.aes(x='x',y='y',color='factor(gg)')) + 
    p9.geom_line()
)
plot.draw(True)

Output:

[image: plot produced by plotnine]
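The repeat and tile counts are easy to mix up when y is not square, so here is a minimal sketch that takes them from the array's shape instead. The helper name to_long and the non-square sample numbers are hypothetical, made up for illustration; only numpy and pandas calls are used.

```python
import numpy as np
import pandas as pd

def to_long(x, y):
    """Hypothetical helper: reshape (x, 2D y) into long form, one row per point."""
    y_arr = np.asarray(y)
    n_rows, n_series = y_arr.shape            # rows follow x, columns are lines
    if n_rows != len(x):
        raise ValueError("y needs exactly one row per value in x")
    return pd.DataFrame({
        "x": np.repeat(x, n_series),                  # each x once per line
        "y": y_arr.ravel(),                           # row-major: all lines at x[0], then x[1], ...
        "gg": np.tile(np.arange(n_series), n_rows),   # line index cycles within each x
    })

# Made-up non-square input: 3 x values, 2 lines
data = to_long([1, 2, 3], [[30, 4], [300, 400], [350, 450]])
print(data)
```

The resulting frame can be passed to the same ggplot call as above, with color='factor(gg)'.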

Answer by has2k1

It looks hard because the data (x and y) are not in tidy form. matplotlib accepts that form, but from the data alone it is not clear which rule it applies, i.e. how it expands the values in x to match y.
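For reference, the rule described in matplotlib's plot documentation is that a 2D y is drawn as one data set per column, each paired with the same x. A small numpy-only sketch of that pairing, using the data from the question (the names y_arr and lines are only illustrative):

```python
import numpy as np

x = [1, 2, 3]
y = [[30, 4, 50], [300, 400, 500], [350, 450, 550]]

y_arr = np.asarray(y)   # shape (3, 3): rows follow x, columns are lines

# One line per COLUMN of y, each paired with the full x.
lines = [(list(x), y_arr[:, j].tolist()) for j in range(y_arr.shape[1])]

for xs, ys in lines:
    print(xs, ys)
# [1, 2, 3] [30, 300, 350]
# [1, 2, 3] [4, 400, 450]
# [1, 2, 3] [50, 500, 550]
```

Grouping by rows instead, as the code in the question does, would connect 30, 4, 50 into one line, which is why the two plots differ.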

With plotnine, the variables (columns) and the observations across them (rows) are explicit.

import pandas as pd
import numpy as np
from plotnine import ggplot, aes, geom_line

x = [1, 2, 3]
y = [[30, 4, 50], [300, 400, 500], [350, 450, 550]]

df = pd.DataFrame({
    "x": np.repeat(x, 3),
    "y": np.ravel(y),
    "g": list("abc" * 3)
})

print(df)

(ggplot(df, aes("x", "y", color="g"))
+ geom_line()
)

Output of print(df):

   x    y  g
0  1   30  a
1  1    4  b
2  1   50  c
3  2  300  a
4  2  400  b
5  2  500  c
6  3  350  a
7  3  450  b
8  3  550  c

[image: plot produced by plotnine]
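The same tidy table can also be reached by starting from a wide DataFrame and reshaping it with pandas' melt. This is a minimal sketch, assuming the labels a, b, c from above; the names wide and long_df are only illustrative, and the plotnine call is left as a comment so the snippet runs with pandas alone.

```python
import pandas as pd

x = [1, 2, 3]
y = [[30, 4, 50], [300, 400, 500], [350, 450, 550]]

wide = pd.DataFrame(y, columns=list("abc"))   # one column per line
wide["x"] = x                                 # one row per x value

# Wide to long: one row per (x, line) observation.
long_df = wide.melt(id_vars="x", var_name="g", value_name="y")
print(long_df)

# Plot as before:
# (ggplot(long_df, aes("x", "y", color="g")) + geom_line())
```

The rows come out ordered by line rather than by x, but they are the same nine observations, so the plot should be unchanged.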