I have this code, which works fine:
import torch
import torch.nn as nn
import numpy as np

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# attention module for FRU
class attention_FRU(nn.Module):
    def __init__(self, num_channels_down, pad='reflect'):
        super(attention_FRU, self).__init__()
        # layers to generate conditional convolution weights
        self.gen_se_weights1 = nn.Sequential(
            nn.Conv2d(num_channels_down, num_channels_down, 1, padding_mode=pad),
            nn.LeakyReLU(0.2, inplace=True),  # don't use Softplus here
            nn.Sigmoid())
        # create conv layers
        self.conv_1 = nn.Conv2d(num_channels_down, num_channels_down, 1, padding_mode=pad)
        self.norm_1 = nn.BatchNorm2d(num_channels_down, affine=False)
        self.actvn = nn.LeakyReLU(0.2, inplace=True)
        # self.actvn = nn.Softplus()

    def forward(self, guide, x):
        # generate channel-wise weights from the guide and modulate x with them
        se_weights1 = self.gen_se_weights1(guide)
        dx = self.conv_1(x)
        dx = self.norm_1(dx)
        dx = torch.mul(dx, se_weights1)
        out = self.actvn(dx)
        return out
class hs_net(nn.Module):
    def __init__(self, ym_channel, yh_channel, num_channels_down, num_channels_up, num_channels_skip,
                 filter_size_down, filter_size_up, filter_skip_size):
        super(hs_net, self).__init__()
        self.FRU = attention_FRU(num_channels_down)
        self.up_bic = nn.Upsample(scale_factor=4, mode='bicubic')
        self.up_trans = nn.ConvTranspose2d(yh_channel, yh_channel, filter_size_down, stride=4, padding=1)
        self.guide_ms = nn.Sequential(
            nn.Conv2d(ym_channel, num_channels_down, filter_size_down, padding='same', padding_mode='reflect'),
            nn.BatchNorm2d(num_channels_down),
            # nn.LeakyReLU(0.2))
            nn.Softplus())
        self.enc = nn.Sequential(
            nn.Conv2d(num_channels_down, num_channels_down, filter_size_down, padding='same', padding_mode='reflect'),
            nn.BatchNorm2d(num_channels_down),
            # nn.LeakyReLU(0.2))
            nn.Softplus())
        self.skip = nn.Sequential(
            nn.Conv2d(num_channels_down, num_channels_skip, filter_skip_size, padding='same', padding_mode='reflect'),
            nn.BatchNorm2d(num_channels_skip),
            # nn.LeakyReLU(0.2))
            nn.Softplus())
        self.dc = nn.Sequential(
            nn.Conv2d((num_channels_skip + num_channels_up), num_channels_up, filter_size_up, padding='same', padding_mode='reflect'),
            nn.BatchNorm2d(num_channels_up),
            # nn.LeakyReLU(0.2))
            nn.Softplus())
        self.out_layer = nn.Sequential(
            nn.Conv2d(num_channels_up, yh_channel, 1, padding_mode='reflect'),
            nn.Sigmoid())
        self.conv_hs = nn.Sequential(
            nn.Conv2d(yh_channel, num_channels_down, filter_size_down, padding='same', padding_mode='reflect'))
            # nn.BatchNorm2d(num_channels_down),
            # nn.Softplus())
        self.conv_bn = nn.Sequential(
            nn.Conv2d(num_channels_down, num_channels_down, filter_size_down, padding='same', padding_mode='reflect'),
            # nn.BatchNorm2d(num_channels_down),
            # nn.LeakyReLU(0.2))
            nn.Softplus())
        self.ym_channels = ym_channel

    def forward(self, inputs):
        # split the input along the channel axis into the MS guide and the HS stream
        ym = inputs[:, :self.ym_channels, :, :]
        yh = inputs[:, self.ym_channels:, :, :]
        ym_en0 = self.guide_ms(ym)
        ym_en1 = self.enc(ym_en0)
        ym_en2 = self.enc(ym_en1)
        ym_en3 = self.enc(ym_en2)
        ym_en4 = self.enc(ym_en3)
        ym_dc0 = self.enc(ym_en4)
        ym_dc1 = self.enc(ym_dc0)
        ym_dc2 = self.dc(torch.cat((self.skip(ym_en4), ym_dc1), dim=1))
        ym_dc3 = self.dc(torch.cat((self.skip(ym_en3), ym_dc2), dim=1))
        ym_dc4 = self.dc(torch.cat((self.skip(ym_en2), ym_dc3), dim=1))
        ym_dc5 = self.dc(torch.cat((self.skip(ym_en1), ym_dc4), dim=1))
        ym_dc6 = self.dc(torch.cat((self.skip(ym_en0), ym_dc5), dim=1))
        yh_6 = self.FRU(self.conv_hs(yh), ym_dc0)
        yh_7 = self.FRU(self.conv_bn(yh_6), ym_dc1)
        yh_8 = self.FRU(self.conv_bn(yh_7), ym_dc2)
        yh_9 = self.FRU(self.conv_bn(yh_8), ym_dc3)
        yh_10 = self.FRU(self.conv_bn(yh_9), ym_dc4)
        yh_11 = self.FRU(self.conv_bn(yh_10), ym_dc5)
        yh_12 = self.FRU(self.conv_bn(yh_11), ym_dc6)
        out = self.out_layer(yh_12)
        return out
class MyDataGenerator(torch.utils.data.Dataset):
    def __init__(self, data, batch_size):
        self.data = torch.squeeze(data)
        self.batch_size = batch_size

    def __len__(self):
        return int(np.ceil(self.data.shape[0] / self.batch_size))

    def __getitem__(self, index):
        # Calculate the start and end indices of the batch
        start_idx = index * self.batch_size
        end_idx = min((index + 1) * self.batch_size, self.data.shape[0])
        # Select samples for the batch
        batch_data = self.data[start_idx:end_idx]
        return batch_data
inputs = torch.from_numpy(np.random.rand(1, 181, 512, 512)).to(device, dtype=torch.float)
num_iter = 1000
LR = 0.001
n_channels = 172
net = hs_net(ym_channel=9,
             yh_channel=n_channels,
             num_channels_down=16,
             num_channels_up=16,
             num_channels_skip=16,
             filter_size_down=1,
             filter_size_up=1,
             filter_skip_size=1).to(device)
msi = torch.from_numpy(np.random.rand(9, 512, 512)).to(device, dtype=torch.float)
hsi = torch.from_numpy(np.random.rand(172, 128, 128)).to(device, dtype=torch.float)
targets = [msi, hsi]
gband = msi.shape[0]
optimizer = torch.optim.Adam(net.parameters(), lr=LR, eps=1e-3, amsgrad=True)

def hs_loss(model, inputs, targets):
    ym = targets[0]
    yh = targets[1]
    xhat = model(inputs)
    # placeholder objective: the actual loss terms against ym/yh are omitted here
    loss = xhat.mean()
    return loss, xhat

for it in range(num_iter):
    optimizer.zero_grad()
    loss, out_HR = hs_loss(net, inputs, targets)
    loss.backward()
    optimizer.step()
But when I try to use batches:
train_set = MyDataGenerator(inputs, batch_size=4)
data_generator = torch.utils.data.DataLoader(train_set)
for it in range(num_iter):
    optimizer.zero_grad()
    for batch in data_generator:
        loss, out_HR = hs_loss(net, batch, targets)
it gives me this error:
Given groups=1, weight of size [16, 9, 1, 1], expected input[1, 4, 512, 512] to have 9 channels, but got 4 channels instead
If I try:
train_set = MyDataGenerator(inputs, batch_size=9)
I receive:
3D or 4D (batch mode) tensor expected for input, but got: [ torch.cuda.FloatTensor{1,0,512,512} ]
The issue is that you are not using torch.utils.data.Dataset as intended: your __getitem__ assembles a whole batch itself, and the DataLoader then stacks those batches again. You don't have to worry about assembling the batch yourself; the point is for your dataset's __getitem__ to return a single element at a time. It's the job of torch.utils.data.DataLoader to collate the data properly depending on a batch size (please read the documentation page for more information).

You can see this double batching in your errors. After torch.squeeze, your (1, 181, 512, 512) input becomes (181, 512, 512), so __getitem__ slices along what is really the channel axis. With batch_size=4, the DataLoader stacks one such slice into a (1, 4, 512, 512) tensor, and the first convolution complains it got 4 channels instead of 9. With batch_size=9, the slice happens to provide exactly the 9 channels that ym consumes, which leaves yh empty, (1, 0, 512, 512), hence the second error.

Here is a demonstration following your example. First, define dummy data (make sure the number of elements is larger than 1, of course) and a dataset whose __getitem__ returns a single sample; then initialize the dataset and wrap it with a data loader.
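A minimal sketch (the MyDataset name and the small spatial size are illustrative choices, not from your snippet):

import torch

# ten dummy samples, each with 181 channels; a small spatial size keeps the tensors light
data = torch.rand(10, 181, 8, 8)

class MyDataset(torch.utils.data.Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        # one entry per sample, not per batch
        return len(self.data)

    def __getitem__(self, index):
        # return a single sample; the DataLoader stacks these into batches
        return self.data[index]

train_set = MyDataset(data)
data_generator = torch.utils.data.DataLoader(train_set, batch_size=4)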
Now you can iterate over the data loader, which uses a sampler to navigate through the dataset and collates each batch for you:
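For example, under the dummy setup above (the last batch is smaller because 10 is not a multiple of 4; pass drop_last=True to skip it):

for batch in data_generator:
    print(batch.shape)
# torch.Size([4, 181, 8, 8])
# torch.Size([4, 181, 8, 8])
# torch.Size([2, 181, 8, 8])

Each batch then has shape (batch_size, 181, H, W), so the channel slicing of ym and yh inside hs_net.forward behaves as intended; in your training loop you would move each batch to the device and pass it to hs_loss(net, batch, targets).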