deepspeed getting output shape wrong when num_stages > 1

I'm trying to create a small hello-world example (to also donate back to the community) that demonstrates pipeline parallelism with DeepSpeed.

import torch
from torch import nn
from deepspeed.pipe import PipelineModule
import deepspeed
import json
from deepspeed.accelerator import get_accelerator


class hello_net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(2, 5),
            nn.ReLU(),
            nn.Linear(5, 1)
        )

        # Wrap the Sequential in a PipelineModule split across two stages.
        self.fc = PipelineModule(layers=self.fc, num_stages=2)

    def forward(self, x):
        x = self.fc(x)
        return x


deepspeed.init_distributed()

net = hello_net()

with open('zero_stage2.json') as f_in:
    zero_config = json.load(f_in)

model_engine, optimizer, trainloader, __ = deepspeed.initialize(
    model=net,
    config=zero_config,
)

# Get the local device name (str) and local rank (int).
local_device = get_accelerator().device_name(model_engine.local_rank)
local_rank = model_engine.local_rank

torch.cuda.set_device(local_rank)

# Toy regression problem: learn y = x1 * x2.
X = torch.rand((1000, 2))
x1 = X[:, 0]
x2 = X[:, 1]
y = x1 * x2
NUM_EPOCHS = 1000
criterion = torch.nn.MSELoss(reduction='mean')

for micro_step in range(NUM_EPOCHS):
    inputs = X.to(model_engine.device)
    labels = y.to(model_engine.device)
    outputs = net(inputs)
    print(f'outputs: {outputs.shape}')
    print(f'labels: {labels.shape}')

    loss = criterion(outputs, labels)
    print(f'loss: {loss}')
    model_engine.backward(loss)
    model_engine.step()

Interestingly, the output shape is torch.Size([1000, 5]) when PipelineModule is created with num_stages=2, but torch.Size([1000, 1]) with num_stages=1.

To me it looks like net(inputs) behaves differently depending on num_stages: with num_stages=2 it returns torch.Size([1000, 5]), which matches the output of the first nn.Linear(2, 5) layer rather than the final nn.Linear(5, 1).
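If that's right, each rank would only build its own partition of the layers, and the [1000, 5] I see would just be the local output of the stage that owns nn.Linear(2, 5). A quick way to check what each rank actually holds would be plain nn.Module introspection (this assumes PipelineModule registers its local layers as ordinary child modules, which I haven't verified in the DeepSpeed source):

# net.fc is the PipelineModule; as a regular nn.Module it should only
# expose the layers that were built on this rank.
for name, child in net.fc.named_children():
    print(f'rank {local_rank}: {name} -> {child}')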

Is this an issue in DeepSpeed, or did I get something wrong here? I'm on deepspeed 0.13.1 and Python 3.11.
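For comparison, my reading of the DeepSpeed pipeline-parallelism tutorial is that the PipelineModule is meant to be passed to deepspeed.initialize directly and driven via train_batch(), rather than being wrapped in an outer nn.Module and called like a regular model. A minimal sketch of that pattern, assuming zero_stage2.json already defines train_batch_size and an optimizer; the TensorDataset wrapping is my own:

import json
import torch
import deepspeed
from deepspeed.pipe import PipelineModule

deepspeed.init_distributed()

# The layers go to PipelineModule directly, and the loss moves into the
# module so the engine can compute it on the last stage.
net = PipelineModule(
    layers=[torch.nn.Linear(2, 5), torch.nn.ReLU(), torch.nn.Linear(5, 1)],
    loss_fn=torch.nn.MSELoss(reduction='mean'),
    num_stages=2,
)

# Toy regression data, wrapped so the engine can feed (input, label)
# pairs through the pipeline.
X = torch.rand((1000, 2))
y = (X[:, 0] * X[:, 1]).unsqueeze(1)
trainset = torch.utils.data.TensorDataset(X, y)

with open('zero_stage2.json') as f_in:
    zero_config = json.load(f_in)

model_engine, _, _, _ = deepspeed.initialize(
    model=net,
    model_parameters=[p for p in net.parameters() if p.requires_grad],
    training_data=trainset,
    config=zero_config,
)

for step in range(1000):
    # train_batch() runs forward, backward and the optimizer step across
    # all pipeline stages and returns the loss.
    loss = model_engine.train_batch()

With that structure there is no manual forward/backward loop at all, so I'm unsure whether calling net(inputs) directly is even supposed to work across stages.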
