I cannot put PyTorch model to device (.to(device))


So I was writing my first ever autoencoder. Here is the code (it may be a little goofy, but I believe I have written all of it correctly):

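import torch
import torch.nn as nn

# dim_code (the latent size) is assumed to be defined earlier in the notebook
class Autoencoder(nn.Module):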
    def __init__(self):
        super(Autoencoder, self).__init__()
        
        self.flatten = nn.Flatten()
        
        self.enc_conv0 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=(1, 1)),
            nn.ReLU(),
            nn.BatchNorm2d(64),

            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=(1, 1)),
            nn.ReLU(),
            nn.BatchNorm2d(128)
        )
        
        self.enc_conv1 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=(1, 1)),
            nn.ReLU(),
            nn.BatchNorm2d(256),

            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=(1, 1)),
            nn.ReLU(),
            nn.BatchNorm2d(512)
        )
        
        self.enc_fc = nn.Sequential(
            nn.Linear(in_features=512*64*64, out_features=4096),
            nn.ReLU(),
            nn.BatchNorm1d(4096),
            
            nn.Linear(in_features=4096, out_features=2048),
            nn.ReLU(),
            nn.BatchNorm1d(2048),
            
            nn.Linear(in_features=2048, out_features=dim_code)
        )
        
        self.dec_fc = nn.Sequential(
            nn.Linear(in_features=dim_code, out_features=2048),
            nn.ReLU(),
            nn.BatchNorm1d(2048),
            
            nn.Linear(in_features=2048, out_features=4096),
            nn.ReLU(),
            nn.BatchNorm1d(4096),
            
            nn.Linear(in_features=4096, out_features=512*64*64),
            nn.ReLU(),
            nn.BatchNorm1d(512*64*64)
        )
        
        self.dec_conv0 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=(3,3), padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            
            nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=(3,3), padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
        )
        
        self.dec_conv1 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=(3,3), padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            
            nn.ConvTranspose2d(in_channels=64, out_channels=3, kernel_size=(3,3), padding=1)
        )

    def forward(self, x):
        e0 = self.enc_conv0(x)
        e1 = self.enc_conv1(e0)
        latent_code = self.enc_fc(self.flatten(e1))
        
        d0 = self.dec_fc(latent_code)
        d1 = self.dec_conv0(d0.view(-1, 512, 64, 64))
        reconstruction = self.dec_conv1(d1)

        return reconstruction, latent_code

And then I was preparing to train it with the next cell of code:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

criterion = nn.BCELoss()
print('crit')

autoencoder = Autoencoder().to(device)
print('deviced')

The cell prints cuda and crit.

And then it just stalls indefinitely, filling the RAM and maxing out the CPU (I'm doing everything in a Kaggle notebook). And I don't get why. :(

I tried launching the same notebook in Google Colab instead of Kaggle, but it just crashed with an error about trying to allocate resources that are not accessible.

Also, I thought the issue could have something to do with the first line after the class definition, so I replaced

def __init__(self):
        super().__init__()

with

def __init__(self):
        super(Autoencoder, self).__init__()

like I saw in some tutorials (honestly, I don't know what these lines do; they just appear in every similar case). But that didn't work either.

Hadi Daman On

Here's an updated version of your training code incorporating the suggestions:

import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

criterion = nn.MSELoss()  # Use Mean Squared Error for image reconstruction
print('crit')

autoencoder = Autoencoder().to(device)
print('deviced')

If the problem persists, try addressing the points mentioned above and let me know if you encounter any specific errors or if you have additional details about the issue.

Nikita On

So the issue was the size of the model. Once I tried a smaller one, all the problems disappeared.
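
To illustrate why shrinking helps, here is a rough back-of-the-envelope sketch in plain Python. It assumes, hypothetically, that the encoder downsamples the 64x64 feature map before flattening (for example with strided convolutions or pooling, which the original model does not use). The function name and the downsampling factors are made up for illustration.

```python
# Weights of the first fully connected layer, Linear(512 * side * side, 4096),
# as a function of the spatial size of the feature map that gets flattened.
# Hypothetical helper, not part of PyTorch.
def first_fc_weights(side, channels=512, hidden=4096):
    return channels * side * side * hidden

# The original model keeps the full 64x64 map (factor 1);
# the other factors are illustrative downsampling choices.
for factor in (1, 2, 4, 8):
    side = 64 // factor
    weights = first_fc_weights(side)
    gib = weights * 4 / 2**30  # float32 = 4 bytes per weight
    print(f"{side:2d}x{side:<2d} feature map -> {weights:,} weights ({gib:.2f} GiB)")
```

Under these assumptions, each halving of the feature map cuts this layer to a quarter of its size, so a few downsampling steps are enough to bring it into a range that fits in memory.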

frad On

I see you've already found the problem, but this answer could still be helpful to others.

First count how many parameters each layer of your model has:

  • Linear(in_features=N, out_features=M) has MxN weights (plus, by default, M biases, which the tally below leaves out; for this model they add only roughly 2.1 million parameters in total);
  • Flatten and ReLU have no weights;
  • Both BatchNorm1d(C) and BatchNorm2d(C) have 2xC weights (2 per channel);
  • Both Conv2d and ConvTranspose2d(in_channels=N, out_channels=M, kernel_size=K, ...) have M different filters, each of size NxKxK. Each filter also has (by default) its own bias (another weight). Therefore, for these layers, you end up with MxNxKxK + M weights.
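
As a rough sketch, the counting rules above can be written as small Python helpers. The function names are made up for illustration and are not part of PyTorch. The `bias` flag on the Linear helper is also an assumption of this sketch: it lets you compare the weight-only count used in this answer with the count that includes the bias PyTorch adds to Linear layers by default.

```python
# Hypothetical helpers (not PyTorch API) that mirror the counting rules above.

def linear_params(n_in, n_out, bias=False):
    """Weights of Linear(n_in, n_out); bias=True adds one bias per output."""
    return n_in * n_out + (n_out if bias else 0)

def batchnorm_params(channels):
    """BatchNorm1d/BatchNorm2d: one scale and one shift per channel."""
    return 2 * channels

def conv_params(c_in, c_out, k):
    """Conv2d/ConvTranspose2d with a k x k kernel and one bias per output channel."""
    return c_in * c_out * k * k + c_out

print(conv_params(3, 64, 3))                 # first encoder convolution
print(batchnorm_params(64))                  # BatchNorm2d(64)
print(linear_params(512 * 64 * 64, 4096))    # first fully connected layer
```

Even from these three calls it is clear that the fully connected layer dwarfs everything else in the model.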

If you compute the total number of weights of your model, you should end up with (8,599,888,384 + 2,048 x dim_code) + (8,604,081,155 + 2,048 x dim_code) = 17,203,969,539 + 4,096 x dim_code parameters (I hope I didn't miscalculate anything! I leave the computations at the bottom of the answer). This is even larger than many recent LLMs, such as Mistral 7B (which, as the name itself suggests, has around 7B parameters).

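Now, considering that PyTorch defaults to the float32 data type for tensors, your model needs about 17.2B x 4 bytes, which is roughly 69 GB (just over 64 GiB), of memory for the parameters alone. This calculation disregards the contribution of the 4,096 x dim_code term, assuming that dim_code is relatively small in comparison. Therefore, make sure your machine has enough RAM if you are using the CPU as device, or enough VRAM if you use a GPU. Note also that Autoencoder() first creates all parameters in CPU RAM and only then does .to(device) copy them to the GPU, which is consistent with the notebook filling the RAM before 'deviced' is ever printed.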
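
The memory estimate can be reproduced with a few lines of plain Python. This is only a sketch of the arithmetic: it uses the total from this answer, ignores the dim_code term, and assumes float32 storage at 4 bytes per parameter.

```python
# Rough memory estimate for the parameters alone, assuming float32 storage.
total_params = 17_203_969_539   # total from this answer, without the dim_code term
bytes_per_param = 4             # float32 = 32 bits = 4 bytes

param_bytes = total_params * bytes_per_param
param_gib = param_bytes / 2**30

print(f"{param_bytes:,} bytes = {param_gib:.2f} GiB")
```

Training needs considerably more than this, since gradients (and any optimizer state) are stored in addition to the parameters themselves.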

Number of parameters calculation

Encoder has 8,599,888,384 + 2048 x dim_code weights:

  • enc_conv0 has:
    • Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1) has 3x64x3x3+64 = 1,792 weights;
    • BatchNorm2d(64) has 2x64 = 128 weights;
    • Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1) has 128x64x3x3+128 = 73,856 weights;
    • BatchNorm2d(128) has 2x128 = 256 weights;
  • enc_conv1 has:
    • Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1) has 256x128x3x3+256 = 295,168 weights;
    • BatchNorm2d(256) has 2x256 = 512 weights;
    • Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1) has 512x256x3x3+512 = 1,180,160 weights;
    • BatchNorm2d(512) has 2x512 = 1,024 weights
  • enc_fc has:
    • Linear(in_features=512*64*64, out_features=4096) has 512x64x64x4096 = 8,589,934,592 weights;
    • BatchNorm1d(4096) has 2x4096 = 8,192 weights;
    • Linear(in_features=4096, out_features=2048) has 4096x2048 = 8,388,608 weights;
    • BatchNorm1d(2048) has 2x2048 = 4,096 weights;
    • Linear(in_features=2048, out_features=dim_code) has 2048 x dim_code weights.

Decoder has 8,604,081,155 + 2048 x dim_code weights:

  • dec_fc has:
    • Linear(in_features=dim_code, out_features=2048) has dim_code x 2048 weights;
    • BatchNorm1d(2048) has 2x2048 = 4,096 weights;
    • Linear(in_features=2048, out_features=4096) has 2048x4096 = 8,388,608 weights;
    • BatchNorm1d(4096) has 2x4096 = 8,192 weights;
    • Linear(in_features=4096, out_features=512*64*64) has 4096x512x64x64 = 8,589,934,592 weights;
    • BatchNorm1d(512*64*64) has 2x512x64x64 = 4,194,304 weights;
  • dec_conv0 has:
    • ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=3, padding=1) has 256x512x3x3 + 256 = 1,179,904 weights;
    • BatchNorm2d(256) has 2x256 = 512 weights;
    • ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=3, padding=1) has 128x256x3x3 + 128 = 295,040 weights;
    • BatchNorm2d(128) has 2x128 = 256 weights;
  • dec_conv1 has:
    • ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=3, padding=1) has 64x128x3x3 + 64 = 73,792 weights;
    • BatchNorm2d(64) has 2x64 = 128 weights;
    • ConvTranspose2d(in_channels=64, out_channels=3, kernel_size=3, padding=1) has 3x64x3x3 + 3 = 1,731 weights.
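
The breakdown above can be double-checked with a short pure-Python tally. This is a sketch that follows the same convention as the list (Linear layers counted without their bias); the helper names are illustrative and not part of PyTorch.

```python
# Pure-Python tally of the breakdown above (weight-only Linear layers, as in the list).
# Helper names are illustrative, not PyTorch API.

def lin(n_in, n_out):
    return n_in * n_out

def bn(c):
    return 2 * c

def conv(c_in, c_out, k=3):
    return c_in * c_out * k * k + c_out

def count_params(dim_code):
    flat = 512 * 64 * 64
    encoder = (
        conv(3, 64) + bn(64) + conv(64, 128) + bn(128)           # enc_conv0
        + conv(128, 256) + bn(256) + conv(256, 512) + bn(512)    # enc_conv1
        + lin(flat, 4096) + bn(4096)                             # enc_fc
        + lin(4096, 2048) + bn(2048) + lin(2048, dim_code)
    )
    decoder = (
        lin(dim_code, 2048) + bn(2048) + lin(2048, 4096) + bn(4096)  # dec_fc
        + lin(4096, flat) + bn(flat)
        + conv(512, 256) + bn(256) + conv(256, 128) + bn(128)        # dec_conv0
        + conv(128, 64) + bn(64) + conv(64, 3)                       # dec_conv1
    )
    return encoder, decoder

# dim_code=0 isolates the constant terms quoted in the answer.
encoder, decoder = count_params(dim_code=0)
print(f"encoder: {encoder:,}")
print(f"decoder: {decoder:,}")
print(f"total:   {encoder + decoder:,}")
```

Calling count_params with an actual latent size adds the 2048 x dim_code term to each half.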