Transformer image captioning model produces just padding rather than a caption

I am trying to build a model that generates a caption for an image, using a ResNet as the encoder, a Transformer as the decoder, and COCO as the dataset.

After training the model for 10 epochs, it fails to produce anything other than the word <pad>: every token that comes out of it is 0, which is the index of <pad> in my vocabulary.

Stepping through with the debugger, the problem shows up at the argmax: the predicted index is always zero rather than anything else. I don't know how to fix it. Is it an issue with my model, or with the way it is trained?
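
To show what I mean, this is the kind of check I was doing at that step, sketched here with a dummy tensor standing in for the real decoder output:

import torch

# Dummy logits standing in for `out = self.linear(out)` inside pred():
# shape (batch_size, vocab_size). With my trained model the real tensor
# always has its maximum in column 0, i.e. <pad>.
logits = torch.randn(4, 10000)

topk = logits.topk(5, dim=1)
print("top-5 indices:", topk.indices)
print("top-5 scores: ", topk.values)
print("argmax:", logits.argmax(dim=1))  # the step that returns all zeros for me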

In case it helps, I based my model on this GitHub repository.

The script I use to download the COCO data is below:

Download.sh

mkdir data
wget http://msvocds.blob.core.windows.net/annotations-1-0-3/captions_train-val2014.zip -P ./data/
wget http://images.cocodataset.org/zips/train2014.zip -P ./data/
wget http://images.cocodataset.org/zips/val2014.zip -P ./data/

unzip ./data/captions_train-val2014.zip -d ./data/
rm ./data/captions_train-val2014.zip
unzip ./data/train2014.zip -d ./data/
rm ./data/train2014.zip 
unzip ./data/val2014.zip -d ./data/ 
rm ./data/val2014.zip 

Any help is much appreciated.

Here is my code:

model.py

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math, copy, time
import torchvision.models as models
from torch.nn import TransformerDecoderLayer, TransformerDecoder
from torch.nn.utils.rnn import pack_padded_sequence
from torch.autograd import Variable

class EncoderCNN(nn.Module):
    def __init__(self, embed_size):
        super(EncoderCNN, self).__init__()
        resnet = models.resnet152(pretrained=True)
        self.resnet = nn.Sequential(*list(resnet.children())[:-2])
        self.conv1 = nn.Conv2d(2048, embed_size, 1)
        self.embed_size = embed_size

        self.fine_tune()
        
    def forward(self, images):
        features = self.resnet(images)                             # (batch, 2048, H/32, W/32)
        batch_size, _, _, _ = features.shape
        features = self.conv1(features)                            # (batch, embed_size, H/32, W/32)
        features = features.view(batch_size, self.embed_size, -1)
        features = features.permute(2, 0, 1)                       # (num_pixels, batch, embed_size)

        return features

    def fine_tune(self, fine_tune=True):
        for p in self.resnet.parameters():
            p.requires_grad = False
        # If fine-tuning, only fine-tune convolutional blocks 2 through 4
        for c in list(self.resnet.children())[5:]:
            for p in c.parameters():
                p.requires_grad = fine_tune

class PositionEncoder(nn.Module):
    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionEncoder, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
    
class Embedder(nn.Module):
    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
    def forward(self, x):
        return self.embed(x)


class Transformer(nn.Module):
    def __init__(self, vocab_size, d_model, h, num_hidden, N, device, dropout_dec=0.1, dropout_pos=0.1):
        super(Transformer, self).__init__()
        decoder_layers = TransformerDecoderLayer(d_model, h, num_hidden, dropout_dec)
        self.source_mask = None
        self.device = device
        self.d_model = d_model
        self.pos_decoder = PositionEncoder(d_model, dropout_pos)
        self.decoder = TransformerDecoder(decoder_layers, N)
        self.embed = Embedder(vocab_size, d_model)
        self.linear = nn.Linear(d_model, vocab_size)

        self.init_weights()

    def forward(self, source, mem):
        source = source.permute(1, 0)  # (batch, seq_len) -> (seq_len, batch)
        if self.source_mask is None or self.source_mask.size(0) != len(source):
            self.source_mask = nn.Transformer.generate_square_subsequent_mask(self=self, sz=len(source)).to(self.device)

        source = self.embed(source)               # (seq_len, batch, d_model)
        source = source * math.sqrt(self.d_model)
        source = self.pos_decoder(source)
        output = self.decoder(source, mem, self.source_mask)
        output = self.linear(output)              # (seq_len, batch, vocab_size)
        return output

    def init_weights(self):
        initrange = 0.1
        self.linear.bias.data.zero_()
        self.linear.weight.data.uniform_(-initrange, initrange)

    def pred(self, memory, pred_len):
        batch_size = memory.size(1)
        src = torch.ones((pred_len, batch_size), dtype=int) * 2
        if self.source_mask is None or self.source_mask.size(0) != len(src):
            self.source_mask = nn.Transformer.generate_square_subsequent_mask(self=self, sz=len(src)).to(self.device)
        output = torch.ones((pred_len, batch_size), dtype=int)
        src, output = src.cuda(), output.cuda()
        for i in range(pred_len):
            src_emb = self.embed(src) # src_len * batch size * embed size
            src_emb = src_emb*math.sqrt(self.d_model)
            src_emb = self.pos_decoder(src_emb)
            out = self.decoder(src_emb, memory, self.source_mask)
            out = out[i]
            out = self.linear(out) # batch_size * vocab_size
            out = out.argmax(dim=1)
            if i < pred_len-1:
                src[i+1] = out
            output[i] = out
        return output
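
For reference, this is the shape check I use on the two modules in isolation (a sketch with random tensors instead of real data; vocab size 1000 is a placeholder, and it assumes the pretrained resnet152 weights and the same PyTorch 1.x version the rest of the code runs on):

import torch
from model import EncoderCNN, Transformer

device = torch.device('cpu')
vocab_size, embed_size = 1000, 512                 # placeholder vocab size
encoder = EncoderCNN(embed_size)
decoder = Transformer(vocab_size, embed_size, 8, embed_size, 6, device)

images = torch.randn(2, 3, 224, 224)               # two fake images
captions = torch.randint(0, vocab_size, (2, 15))   # batch-first token ids

memory = encoder(images)            # (49, 2, 512): 7x7 feature map, sequence-first
logits = decoder(captions, memory)  # (15, 2, 1000): sequence-first logits over the vocab
print(memory.shape, logits.shape)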

data_loader.py

import torch
import torchvision.transforms as transforms
import torch.utils.data as data
import os
import pickle
import numpy as np
import nltk
from PIL import Image
from build_vocab import Vocabulary
from pycocotools.coco import COCO


class CocoDataset(data.Dataset):
    """COCO Custom Dataset compatible with torch.utils.data.DataLoader."""
    def __init__(self, root, json, vocab, transform=None):
        """Set the path for images, captions and vocabulary wrapper.
        
        Args:
            root: image directory.
            json: coco annotation file path.
            vocab: vocabulary wrapper.
            transform: image transformer.
        """
        self.root = root
        self.coco = COCO(json)
        self.ids = list(self.coco.anns.keys())
        self.vocab = vocab
        self.transform = transform

    def __getitem__(self, index):
        """Returns one data pair (image and caption)."""
        coco = self.coco
        vocab = self.vocab
        ann_id = self.ids[index]
        caption = coco.anns[ann_id]['caption']
        img_id = coco.anns[ann_id]['image_id']
        path = coco.loadImgs(img_id)[0]['file_name']

        image = Image.open(os.path.join(self.root, path)).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)

        # Convert caption (string) to word ids.
        tokens = nltk.tokenize.word_tokenize(str(caption).lower())
        caption = []
        caption.append(vocab('<start>'))
        caption.extend([vocab(token) for token in tokens])
        caption.append(vocab('<end>'))
        target = torch.Tensor(caption)
        return image, target

    def __len__(self):
        return len(self.ids)


def collate_fn(data):
    """Creates mini-batch tensors from the list of tuples (image, caption).
    
    We should build custom collate_fn rather than using default collate_fn, 
    because merging caption (including padding) is not supported in default.
    Args:
        data: list of tuple (image, caption). 
            - image: torch tensor of shape (3, 256, 256).
            - caption: torch tensor of shape (?); variable length.
    Returns:
        images: torch tensor of shape (batch_size, 3, 256, 256).
        targets: torch tensor of shape (batch_size, padded_length).
        lengths: list; valid length for each padded caption.
    """
    # Sort a data list by caption length (descending order).
    data.sort(key=lambda x: len(x[1]), reverse=True)
    images, captions = zip(*data)

    # Merge images (from tuple of 3D tensor to 4D tensor).
    images = torch.stack(images, 0)

    # Merge captions (from tuple of 1D tensor to 2D tensor).
    lengths = [len(cap) for cap in captions]
    targets = torch.zeros(len(captions), max(lengths)).long()
    for i, cap in enumerate(captions):
        end = lengths[i]
        targets[i, :end] = cap[:end]        
    return images, targets, lengths

def get_loader(root, json, vocab, transform, batch_size, shuffle, num_workers):
    """Returns torch.utils.data.DataLoader for custom coco dataset."""
    # COCO caption dataset
    coco = CocoDataset(root=root,
                       json=json,
                       vocab=vocab,
                       transform=transform)
    
    # Data loader for COCO dataset
    # This will return (images, captions, lengths) for each iteration.
    # images: a tensor of shape (batch_size, 3, 224, 224).
    # captions: a tensor of shape (batch_size, padded_length).
    # lengths: a list indicating valid length for each caption. length is (batch_size).
    data_loader = torch.utils.data.DataLoader(dataset=coco, 
                                              batch_size=batch_size,
                                              shuffle=shuffle,
                                              num_workers=num_workers,
                                              collate_fn=collate_fn)
    return data_loader
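
As a sanity check on the zero-padding, collate_fn can be exercised on its own with toy tensors (a sketch; the images are blank and the token ids are made up):

import torch
from data_loader import collate_fn

# Two fake (image, caption) pairs with captions of different length.
fake = [
    (torch.zeros(3, 224, 224), torch.Tensor([1, 5, 9, 2])),   # <start> w w <end>
    (torch.zeros(3, 224, 224), torch.Tensor([1, 7, 2])),      # <start> w <end>
]
images, targets, lengths = collate_fn(fake)
print(images.shape)  # torch.Size([2, 3, 224, 224])
print(targets)       # shorter caption is padded with 0, i.e. <pad>
print(lengths)       # [4, 3]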

build_vocab.py

import nltk
import pickle
import argparse
from collections import Counter
from pycocotools.coco import COCO

class Vocabulary(object):
    def __init__(self):
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0

    def add_word(self, word):
        if not word in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1

    def __call__(self, word):
        if not word in self.word2idx:
            return self.word2idx['<unk>']
        return self.word2idx[word]


    def __len__(self):
        return len(self.word2idx)

def build_vocab(json, threshold):
    coco = COCO(json)
    counter = Counter()
    ids = coco.anns.keys()
    for i, id in enumerate(ids):
        caption = str(coco.anns[id]['caption'])
        tokens = nltk.tokenize.word_tokenize(caption.lower())
        counter.update(tokens)

        if (i+1) % 1000 == 0:
            print("[{}/{}] Tokenized the captions.".format(i+1, len(ids)))

    # If the word frequency is less than 'threshold', then the word is discarded.
    words = [word for word, cnt in counter.items() if cnt >= threshold]

    # Create a vocab wrapper and add some special tokens.
    vocab = Vocabulary()
    vocab.add_word('<pad>')
    vocab.add_word('<start>')
    vocab.add_word('<end>')
    vocab.add_word('<unk>')

    # Add the words to the vocabulary.
    for i, word in enumerate(words):
        vocab.add_word(word)
    return vocab

def main(args):
    vocab = build_vocab(json=args.caption_path, threshold=args.threshold)
    vocab_path = args.vocab_path
    with open(vocab_path, 'wb') as f:
        pickle.dump(vocab, f)
    print("Total vocabulary size: {}".format(len(vocab)))
    print("Saved the vocabulary wrapper to '{}'".format(vocab_path))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--caption_path', type=str, 
                        default='./data/annotations/captions_train2014.json', 
                        help='path for train annotation file')
    parser.add_argument('--vocab_path', type=str, default='./data/vocab.pkl', 
                        help='path for saving vocabulary wrapper')
    parser.add_argument('--threshold', type=int, default=4, 
                        help='minimum word count threshold')
    args = parser.parse_args()
    main(args)
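
Since the whole problem revolves around token 0, for reference this is the index layout the script above gives the special tokens (a quick check on a fresh Vocabulary rather than the saved vocab.pkl):

from build_vocab import Vocabulary

vocab = Vocabulary()
for tok in ['<pad>', '<start>', '<end>', '<unk>']:
    vocab.add_word(tok)

print(vocab.word2idx)  # {'<pad>': 0, '<start>': 1, '<end>': 2, '<unk>': 3}
print(vocab('<pad>'))  # 0 -- the id my model keeps predicting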

train.py

import argparse
import torch
import torch.nn as nn
import numpy as np
import os
import pickle
import math
from tqdm import tqdm
from data_loader import get_loader 
from build_vocab import Vocabulary
from model import EncoderCNN, Transformer
from torch.nn.utils.rnn import pack_padded_sequence
from torchvision import transforms

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def main(args):
    batch_size = 64
    embed_size = 512
    num_heads = 8
    num_layers = 6
    num_workers = 2
    num_epoch = 5
    lr = 1e-3
    load = False
    # Create model directory
    if not os.path.exists('models/'):
        os.makedirs('models/')
    
    # Image preprocessing, normalization for the pretrained resnet
    transform = transforms.Compose([ 
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(), 
        transforms.ToTensor(), 
        transforms.Normalize((0.485, 0.456, 0.406), 
                             (0.229, 0.224, 0.225))])
    
    # Load vocabulary wrapper
    with open('data/vocab.pkl', 'rb') as f:
        vocab = pickle.load(f)
    
    # Build data loader
    data_loader = get_loader('data/resized2014', 'data/annotations/captions_train2014.json', vocab, 
                             transform, batch_size,
                             shuffle=True, num_workers=num_workers) 

    encoder = EncoderCNN(embed_size).to(device)
    encoder.fine_tune(False)
    decoder = Transformer(len(vocab), embed_size, num_heads, embed_size, num_layers, device).to(device)
    
    if(load):
        encoder.load_state_dict(torch.load(os.path.join('models/', 'encoder-{}-{}.ckpt'.format(5, 5000))))
        decoder.load_state_dict(torch.load(os.path.join('models/', 'decoder-{}-{}.ckpt'.format(5, 5000))))
        print("Load Successful")

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    encoder_optim = torch.optim.Adam(encoder.parameters(), lr=lr)
    decoder_optim = torch.optim.Adam(decoder.parameters(), lr=lr)
    
    # Train the models
    for epoch in range(num_epoch):
        encoder.train()
        decoder.train()
        for i, (images, captions, lengths) in tqdm(enumerate(data_loader), total=len(data_loader), leave=False):
            
            # Set mini-batch dataset
            images = images.to(device)
            captions = captions.to(device)

            # Forward, backward and optimize
            features = encoder(images)
            cap_input = captions[:, :-1]
            cap_target = captions[:, 1:]
            outputs = decoder(cap_input, features)
            outputs = outputs.permute(1,0,2)
            outputs_shape = outputs.reshape(-1, len(vocab))
            loss = criterion(outputs_shape, cap_target.reshape(-1))
            decoder.zero_grad()
            encoder.zero_grad()
            loss.backward()
            encoder_optim.step()
            decoder_optim.step()
                
            # Save the model checkpoints
            if (i+1) % args.save_step == 0:
                torch.save(decoder.state_dict(), os.path.join(
                    'models/', 'decoder-{}-{}.ckpt'.format(epoch+1, i+1)))
                torch.save(encoder.state_dict(), os.path.join(
                    'models/', 'encoder-{}-{}.ckpt'.format(epoch+1, i+1)))

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--log_step', type=int, default=10, help='step size for printing log info')
    parser.add_argument('--save_step', type=int , default=1000, help='step size for saving trained models')
        
    args = parser.parse_args()
    print(args)
    main(args)
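
For clarity, the cap_input/cap_target slicing in the training loop is just a one-step shift for teacher forcing; on a toy padded caption it looks like this (illustrative only, word ids made up):

import torch

# <start> a dog runs <end> <pad> <pad>, with <start>=1, <end>=2, <pad>=0
captions = torch.tensor([[1, 45, 318, 702, 2, 0, 0]])

cap_input = captions[:, :-1]   # fed to the decoder:          [1, 45, 318, 702, 2, 0]
cap_target = captions[:, 1:]   # compared against the output: [45, 318, 702, 2, 0, 0]
print(cap_input)
print(cap_target)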

sample.py


import torch
import matplotlib.pyplot as plt
import numpy as np 
import argparse
import pickle 
import os
from torchvision import transforms 
from build_vocab import Vocabulary
from data_loader import get_loader 
from model import EncoderCNN, Transformer
from PIL import Image


# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#

def token_sentence(decoder_out, itos):
    tokens = decoder_out
    tokens = tokens.transpose(1, 0)
    tokens = tokens.cpu().numpy()
    results = []
    for instance in tokens:
        result = ' '.join([itos[x] for x in instance])
        results.append(''.join(result.partition('<end>')[0])) # Cut before '<end>'
    return results

def load_image(image_path, transform=None):
    image = Image.open(image_path).convert('RGB')
    image = image.resize([224, 224], Image.LANCZOS)
    
    if transform is not None:
        image = transform(image).unsqueeze(0)
    
    return image

def main(args):
    batch_size = 64
    embed_size = 512
    num_heads = 8
    num_layers = 6
    num_workers = 2
    
    # Image preprocessing
    transform = transforms.Compose([
        transforms.ToTensor(), 
        transforms.Normalize((0.485, 0.456, 0.406), 
                             (0.229, 0.224, 0.225))])
    
    # Load vocabulary wrapper
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    data_loader = get_loader('data/resized2014', 'data/annotations/captions_train2014.json', vocab, 
                             transform, batch_size,
                             shuffle=True, num_workers=num_workers) 

    # Build models
    encoder = EncoderCNN(embed_size).to(device)
    encoder.fine_tune(False)
    decoder = Transformer(len(vocab), embed_size, num_heads, embed_size, num_layers, device).to(device)

    # Load trained models
    encoder.load_state_dict(torch.load(os.path.join('models/', 'encoder-{}-{}.ckpt'.format(1, 4000))))
    decoder.load_state_dict(torch.load(os.path.join('models/', 'decoder-{}-{}.ckpt'.format(1, 4000))))
    encoder.eval()
    decoder.eval()
    
    itos = vocab.idx2word
    pred_len = 100
    result_collection = []

    # Decode with greedy
    # with torch.no_grad():
    #     for i, (images, captions, lengths) in enumerate(data_loader):
    #         images = images.to(device)
    #         features = encoder(images)
    #         output = decoder.generator(features, pred_len)
    #         result_caption = token_sentence(output, itos)
    #         result_collection.extend(result_caption)

    # Decode with greedy
    with torch.no_grad():
        for batch_index, (inputs, captions, caplens) in enumerate(data_loader):
            inputs, captions = inputs.cuda(), captions.cuda()
            enc_out = encoder(inputs)
            captions_input = captions[:, :-1]
            captions_target = captions[:, 1:]
            output = decoder.pred(enc_out, pred_len)
            result_caption = token_sentence(output, itos)
            result_collection.extend(result_caption)
        
            
    print("Prediction-greedy:", result_collection[1])
    print("Prediction-greedy:", result_collection[2])
    print("Prediction-greedy:", result_collection[3])
    print("Prediction-greedy:", result_collection[4])
    print("Prediction-greedy:", result_collection[5])
    print("Prediction-greedy:", result_collection[6])
    print("Prediction-greedy:", result_collection[7])
    print("Prediction-greedy:", result_collection[8])
    print("Prediction-greedy:", result_collection[9])
    print("Prediction-greedy:", result_collection[10])
    print("Prediction-greedy:", result_collection[11])

    # # Prepare an image
    # image = load_image(args.image, transform)
    # image_tensor = image.to(device)
    
    # # Generate an caption from the image
    # feature = encoder(image_tensor)
    # sampled_ids = decoder.generator(feature, pred_len)
    # sampled_ids = sampled_ids[0].cpu().numpy()          # (1, max_seq_length) -> (max_seq_length)
    
    # # Convert word_ids to words
    # sampled_caption = []
    # for word_id in sampled_ids:
    #     word = vocab.idx2word[word_id]
    #     sampled_caption.append(word)
    #     if word == '<end>':
    #         break
    # sentence = ' '.join(sampled_caption)
    
    # # Print out the image and the generated caption
    # print (sentence)
    # image = Image.open(args.image)
    # plt.imshow(np.asarray(image))
    
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--image', type=str, required=False, help='input image for generating caption')
    parser.add_argument('--vocab_path', type=str, default='data/vocab.pkl', help='path for vocabulary wrapper')
    args = parser.parse_args()
    main(args)

resize.py


import argparse
import os
from PIL import Image


def resize_image(image, size):
    """Resize an image to the given size."""
    return image.resize(size, Image.ANTIALIAS)

def resize_images(image_dir, output_dir, size):
    """Resize the images in 'image_dir' and save into 'output_dir'."""
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    images = os.listdir(image_dir)
    num_images = len(images)
    for i, image in enumerate(images):
        with open(os.path.join(image_dir, image), 'r+b') as f:
            with Image.open(f) as img:
                img = resize_image(img, size)
                img.save(os.path.join(output_dir, image), img.format)
        if (i+1) % 100 == 0:
            print ("[{}/{}] Resized the images and saved into '{}'."
                   .format(i+1, num_images, output_dir))

def main(args):
    image_dir = args.image_dir
    output_dir = args.output_dir
    image_size = [args.image_size, args.image_size]
    resize_images(image_dir, output_dir, image_size)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_dir', type=str, default='./data/train2014/',
                        help='directory for train images')
    parser.add_argument('--output_dir', type=str, default='./data/resized2014/',
                        help='directory for saving resized images')
    parser.add_argument('--image_size', type=int, default=256,
                        help='size for image after processing')
    args = parser.parse_args()
    main(args)