Python plotly subplots with multiple columns

115 Views Asked by At

I have the code below, which uses plotly to plot multiple columns of a dataframe. The code works fine, but I want to use this graph as a subplot. Imagine a grid with 3 rows and 2 columns, where each cell shows a graph like the one this function produces.

def custom_graph_objects(df, item, show=False):
    """Build a line chart of the price columns for a single item.

    Rows of ``df`` whose ``wh_item_code`` equals ``item`` are selected, and
    one line trace per column (``list_price``, ``current_sell_zone1``,
    ``percentile_difference``) is drawn against ``vpl_start_week``.

    Args:
        df: DataFrame containing ``wh_item_code``, ``vpl_start_week`` and the
            three plotted columns.
        item: Value of ``wh_item_code`` whose rows are plotted.
        show: When true, display the figure before returning it.

    Returns:
        The ``go.Figure`` holding the three line traces.
    """
    # Filter once instead of rebuilding the same boolean mask for every
    # x and y series.
    item_rows = df[df['wh_item_code'] == item]
    weeks = item_rows['vpl_start_week']

    fig = go.Figure()

    # One line trace per column; the column name doubles as the legend label.
    for column in ('list_price', 'current_sell_zone1', 'percentile_difference'):
        fig.add_trace(go.Scatter(x=weeks, y=item_rows[column], mode='lines', name=column))

    # Title, axis labels, and a single hover box shared by all traces at a
    # given x position.
    fig.update_layout(title='Item Price Analysis',
                      xaxis_title='vpl_start_date',
                      yaxis_title='price',
                      hovermode="x unified")
    # Clear the per-trace hover template so the unified hover uses defaults.
    fig.update_traces(mode="lines", hovertemplate=None)

    if show:
        fig.show()
    return fig

enter image description here

Below is the code I tried for subplots, but it does not produce what I need.

# Grid dimensions of the subplot figure.
r = 2
c = 2

# Pick one random item per grid cell.
sample_items = random.sample(grand_aggregate['wh_item_code'].unique().tolist(), r*c)
sub_fig = make_subplots(rows=r, cols=c)

# 1-based (row, col) coordinates of every cell, in row-major order.
pos = [(x+1, y+1) for x in range(r) for y in range(c)]

# Columns drawn as separate lines inside every subplot.
price_columns = ('list_price', 'current_sell_zone1', 'percentile_difference')

for i, item in enumerate(sample_items):
    # Filter once per item rather than once per x/y series.
    item_rows = grand_aggregate[grand_aggregate['wh_item_code'] == item]
    for column in price_columns:
        # Every series needs add_trace: update_traces only modifies traces
        # that already exist in the cell and never adds a new one, so using
        # it for the second and third columns left those lines undrawn.
        sub_fig.add_trace(go.Scatter(x=item_rows['vpl_start_week'],
                                     y=item_rows[column], mode='lines', name=column),
                          row=pos[i][0], col=pos[i][1])

sub_fig.update_layout(height=900, width=1200, title_text="Subplots")
sub_fig.show()

enter image description here

Any help on this would be highly appreciated. Thanks in advance.

1

There is 1 best solution below

2
On

In plotly, make_subplots(rows=r, cols=c) creates a single figure object, and each graph is placed in one cell of that figure's grid. So there is no need for fig=go.Figure() inside the function; instead, the subplot figure has to be passed to the function as an argument. The sample data has multiple columns, so I wrote a function that draws one of them in a given subplot cell, and then wrote a loop that calls it to fill in the whole grid.

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Sample data set shipped with plotly express: a 'date' column followed by
# one column per ticker (see the df.head() output below).
df = px.data.stocks()

df.head()
date    GOOG    AAPL    AMZN    FB  NFLX    MSFT
0   2018-01-01  1.000000    1.000000    1.000000    1.000000    1.000000    1.000000
1   2018-01-08  1.018172    1.011943    1.061881    0.959968    1.053526    1.015988
2   2018-01-15  1.032008    1.019771    1.053240    0.970243    1.049860    1.020524
3   2018-01-22  1.066783    0.980057    1.140676    1.016858    1.307681    1.066561
4   2018-01-29  1.008773    0.917143    1.163374    1.018357    1.273537    1.040708

# Grid dimensions: 3 rows by 2 columns.
r, c = 3, 2

# Every column after the first one ('date') is a ticker to plot.
sample_tickers = df.columns.tolist()[1:]

# One figure object holding the whole r-by-c grid.
sub_fig = make_subplots(rows=r, cols=c)

# 1-based (row, col) coordinates of each cell, in row-major order.
pos = [
    (row_idx + 1, col_idx + 1)
    for row_idx in range(r)
    for col_idx in range(c)
]

def custom_graph(df, tick, sub_fig, pos, show=False):
    """Add one column of ``df`` as a line trace to a cell of ``sub_fig``.

    Args:
        df: DataFrame with a ``date`` column and a column named ``tick``.
        tick: Name of the column to plot; also used as the legend label.
        sub_fig: Subplot figure that receives the trace (modified in place).
        pos: ``(row, col)`` pair, 1-based, selecting the target cell.
        show: When true, display ``sub_fig`` after the trace is added.

    Returns:
        The same ``sub_fig`` object, for convenient reassignment or chaining.
    """
    sub_fig.add_trace(go.Scatter(
        x=df['date'],
        y=df[tick],
        mode='lines',
        name=tick,  # label the legend entry instead of the default "trace N"
    ), row=pos[0], col=pos[1])
    # Honour the flag, which was previously accepted but ignored.
    if show:
        sub_fig.show()
    return sub_fig

# Pair each ticker with its grid cell; zip stops at the shorter sequence, so
# extra tickers or extra cells are silently left out. The enumerate index
# was never used, so it is dropped.
for tick, p in zip(sample_tickers, pos):
    sub_fig = custom_graph(df, tick, sub_fig, p, show=False)

# Size and title the combined figure, then render it once.
sub_fig.update_layout(height=900, width=1200, title_text='Subplots')
sub_fig.show()

enter image description here