I want to create an encoder-decoder model with the following structure:
- bert-base-uncased for encoding the input (https://huggingface.co/google-bert/bert-base-uncased)
- A linear layer connecting the two models, taking BERT's CLS token embedding as its input
- OPT-125M for decoding, taking the output of the linear layer as its input (https://huggingface.co/facebook/opt-125m)
I want to do this to implement the idea from the In-Context Autoencoder paper (https://arxiv.org/abs/2307.06945) and test it out myself.
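Roughly, this is the data flow I have in mind (768 is bert-base-uncased's hidden size; the decoder-side dimension is the part I am unsure about):

# intended data flow, shapes per my understanding:
# "some input text"
#   -> BERT tokenizer      -> input_ids          (batch, seq_len)
#   -> bert-base-uncased   -> last_hidden_state  (batch, seq_len, 768)
#   -> take CLS token      -> cls_embedding      (batch, 768)
#   -> linear layer        -> compressed vector  (batch, decoder_dim)
#   -> OPT-125M            -> reconstructed text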
I would like to build this with the Hugging Face transformers library on top of PyTorch, since it greatly reduces the programming effort, and I would not know where to find raw implementations of OPT-125M or BERT, let alone how to implement them by hand. Hugging Face's optimizations also matter a lot for running this on a normal desktop PC.
My problem is that the OPT-125M model expects tokenized input, and I have not found a way to bypass this.
Does anyone know of a way to feed the output of the linear layer directly into OPT-125M without tokenizing it, or a different, equally performant way of implementing this outside of Hugging Face? (After the skeleton code below I have sketched one idea involving inputs_embeds, but I am not sure it is correct.)
This is the skeleton code I have written so far; it fails because OPT receives the wrong kind of input:
import torch
from torch import nn
from transformers import AutoModelForCausalLM, BertModel, BertTokenizer


class Encoder(nn.Module):
    """Runs BERT and returns the CLS token embedding for a batch of texts."""

    def __init__(self):
        super().__init__()
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.model = BertModel.from_pretrained("bert-base-uncased")

    def forward(self, input_text):
        inputs = self.tokenizer(input_text, return_tensors="pt", padding=True,
                                truncation=True, max_length=512)
        outputs = self.model(**inputs)
        return outputs.last_hidden_state[:, 0, :]  # CLS token embedding, shape (batch, 768)


class LinearTransformation(nn.Module):
    """Projects the CLS embedding into the dimension the decoder should consume."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.linear(x)


class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

    def forward(self, x):
        # This is where it breaks: input_ids expects integer token IDs,
        # but x is the float tensor produced by the linear layer.
        output = self.model(input_ids=x)
        return output


class BertOptPipeline(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.linear_transformation = LinearTransformation(768, 512)
        self.decoder = Decoder()

    def forward(self, input_text):
        encoded = self.encoder(input_text)
        transformed = self.linear_transformation(encoded)
        print(transformed.shape)  # torch.Size([1, 512]) for a single input string
        # Further processing is presumably needed here to match the decoder's input requirements
        decoded = self.decoder(transformed)
        return decoded


pipeline = BertOptPipeline()
input_text = "thank you for your help"
output = pipeline(input_text)
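One idea I had after skimming the transformers docs: the forward method of the causal LM classes seems to accept precomputed embeddings via an inputs_embeds argument instead of input_ids. If that also holds for OPT, I could treat the linear layer's output as a single "soft token". Here is a minimal sketch of what I mean; the shape handling and the assumption that the linear layer has to project to OPT-125M's embedding dimension (768, as far as I can tell) are my guesses, not verified:

import torch
from torch import nn
from transformers import AutoModelForCausalLM

opt = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
opt_embed_dim = opt.get_input_embeddings().embedding_dim  # 768 for opt-125m, I believe

# The linear layer would then map BERT's 768-dim CLS vector to OPT's embedding dim
linear = nn.Linear(768, opt_embed_dim)

cls_embedding = torch.randn(1, 768)  # placeholder standing in for the real BERT CLS output
soft_token = linear(cls_embedding).unsqueeze(1)  # (batch, seq_len=1, opt_embed_dim)

outputs = opt(inputs_embeds=soft_token)  # no tokenizer involved
print(outputs.logits.shape)  # I would expect (1, 1, vocab_size)

Is inputs_embeds the right mechanism for this, and would generation then also work via model.generate(inputs_embeds=...)?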
Thanks for your help!