Problem in semi-supervised learning with a CNN


I am using semi-supervised learning to label the unlabelled images in a dataset. Given an unlabelled image as input, the CNN model produces a probability for each class after the softmax. If any of these probabilities exceeds a certain value (0.65, for example), I label the image with that class and add it to the training set. The code to build the pseudo-labelled dataset:

def get_pseudo_labels(trainset, dataset, model, threshold=0.65):
    # This function generates pseudo-labels for a dataset using the given model.
    # It returns an instance of DatasetFolder containing images whose prediction confidences exceed a given threshold.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Construct a data loader.
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    # The dataset contains the unlabelled images.

    # Make sure the model is in eval mode.
    model.eval()
    # Define the softmax function.
    softmax = nn.Softmax(dim=-1)

    # Iterate over the dataset by batches.
    for batch in tqdm(data_loader):

        img, labels = batch
        # Forward the data.
        # Using torch.no_grad() accelerates the forward pass.
        with torch.no_grad():
            logits = model(img.to(device))

        # Obtain the probability distributions by applying softmax to the logits.
        probs = softmax(logits)

        for j in range(0, batch_size):
            for i in range(0, 11):
                if probs[j][i].item() > threshold:
                    batch[1][j] = torch.Tensor([i])  # Label the image
                    temp = batch[0][j] + batch[1][j]  # Concatenate the two tensors
                    trainset = ConcatDataset([trainset, temp])  # Add this labelled image to the train set

    model.train()
    return trainset

Running it raises the following error:

if probs[j][i].item() > threshold:

IndexError: index 2 is out of bounds for dimension 0 with size 2

However, I can print the probs normally.

        for j in range(0, batch_size):
            for i in range(0, 11):
                print('batch:', j)
                print('The value of label', i)
                print(probs[j][i])
                if probs[j][i].item() > threshold:
                    batch[1][j] = torch.Tensor([i])
                    temp = batch[0][j] + batch[1][j]
                    trainset = ConcatDataset([trainset, temp])

Output:

...
batch: 63
The value of label 9
tensor(0.0859, device='cuda:0')
batch: 63
The value of label 10
tensor(0.0977, device='cuda:0')

I don't know what the IndexError means....

The format of the img is:

tensor([...(img)],[...(label)])

There is 1 solution below


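Make sure that len(dataset) % batch_size == 0. Otherwise the last batch holds only the leftover samples and has fewer than batch_size rows, so the loop over range(0, batch_size) indexes past the end of probs. That is most likely what the IndexError (index 2, size 2) is reporting.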
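To see how the last batch can come up short, here is a minimal sketch on made-up data. The names toy_dataset and batch_sizes are hypothetical and not from the question; only DataLoader, TensorDataset and the drop_last argument are actual PyTorch API.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy data: 10 samples, which is not a multiple of batch_size = 4.
toy_dataset = TensorDataset(torch.zeros(10, 3), torch.zeros(10, dtype=torch.long))


def batch_sizes(dataset, batch_size, drop_last=False):
    """Return the number of samples in each batch the loader yields."""
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                        drop_last=drop_last)
    return [features.size(0) for features, _ in loader]


# Default behaviour: the final batch keeps the 2 leftover samples.
sizes_default = batch_sizes(toy_dataset, batch_size=4)
# drop_last=True discards the incomplete final batch instead.
sizes_dropped = batch_sizes(toy_dataset, batch_size=4, drop_last=True)

print(sizes_default)
print(sizes_dropped)
```

Passing drop_last=True is one way to avoid the short batch without changing the dataset size, at the cost of never pseudo-labelling the leftover images.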

For a given batch_size (4, for example), the number of examples in the dataset should be a multiple of it (8, 12, 16, and so on).

For example, 16 % 4 == 0.
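If you would rather not constrain the dataset size, another option is to loop over however many rows the batch actually has. Below is a rough sketch of that idea, not a drop-in replacement for the function in the question. select_confident, toy_model, unlabeled and labeled are hypothetical names, and the sketch assumes the dataset yields (image, label) pairs and is not empty. It also collects the kept images into a TensorDataset before calling ConcatDataset, since ConcatDataset expects datasets rather than single tensors.

```python
import torch
from torch import nn
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset


def select_confident(model, unlabeled, threshold=0.65, batch_size=4):
    """Return a TensorDataset of (image, pseudo_label) pairs above the threshold."""
    device = next(model.parameters()).device
    loader = DataLoader(unlabeled, batch_size=batch_size, shuffle=False)
    kept_images, kept_labels = [], []

    model.eval()
    with torch.no_grad():
        for images, _ in loader:
            probs = torch.softmax(model(images.to(device)), dim=-1)
            # Highest class probability per image and the class it belongs to.
            confidence, predicted = probs.max(dim=-1)
            # The mask has one entry per image actually present in this batch,
            # so a short final batch is handled without any index arithmetic.
            mask = (confidence > threshold).cpu()
            kept_images.append(images[mask])
            kept_labels.append(predicted.cpu()[mask])
    model.train()

    return TensorDataset(torch.cat(kept_images), torch.cat(kept_labels))


# Hypothetical stand-ins for the real CNN and image folders.
torch.manual_seed(0)
toy_model = nn.Linear(3, 11)  # 11 classes, as in the question
unlabeled = TensorDataset(torch.randn(10, 3), torch.zeros(10, dtype=torch.long))
labeled = TensorDataset(torch.randn(6, 3), torch.zeros(6, dtype=torch.long))

pseudo = select_confident(toy_model, unlabeled, threshold=0.0)
combined = ConcatDataset([labeled, pseudo])
print(len(pseudo), len(combined))
```

For any threshold of 0.5 or more, at most one class can exceed it, so taking the maximum should select the same images as the nested loop over the 11 classes. If your labelled set returns plain int labels rather than tensors, you may need to make the label types match before batching the combined dataset.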