[Reposted from the PyTorch forums because no one responded.]
Hello, I’m experimenting with transformers right now, and I’m trying to modify the encoded representation with a modified LSTM (the goal is to continue text in a specific style). I’ve found an example of how to use T.nn.TransformerEncoder, but no examples of how to properly use T.nn.TransformerDecoder. How am I supposed to use it? I’ve read about how decoders work in general, but I can’t find anything about the specific PyTorch implementation. How should I use it for training vs. inference? Do I have to manually feed the output of the transformer back into tgt during inference, or is that done automatically? What does tgt_is_causal do?
I’ve included a snippet of my code in case it’s useful:
# Method of an nn.Module subclass; assumes `import math`, `import torch as T`, and `import torch.nn as nn`.
def forward(self, x, mhx, tgt=None, src_mask=None):
    # Scale the embeddings by sqrt(emb_dim), then add the positional encodings.
    embedded_seq = self.embedding(x) * math.sqrt(self.emb_dim)
    embedded_seq = self.pos_encoder(embedded_seq)
    if src_mask is None:
        # Generate a square causal mask for the sequence. Masked positions are
        # filled with float('-inf'); unmasked positions are filled with float(0.0).
        src_mask = nn.Transformer.generate_square_subsequent_mask(len(embedded_seq)).to(embedded_seq.device)
    encoded_seq = self.transformer_encoder(embedded_seq, src_mask)
    # Take the encoded sequence, repeat it once over the time axis, and feed that into the DNC (to give the DNC time to analyze and plan).
    processed_seq, (chx, mhx, rv) = self.vector_machine(T.cat((encoded_seq, encoded_seq), 1), (None, mhx, None), reset_experience=True, pass_through_memory=True)
    # Split the processed sequence into two halves, take the second half, and add it to the encoded sequence as a skip connection.
    processed_seq = T.chunk(processed_seq, 2, dim=1)[1] + encoded_seq
    # TODO: put decoder here (it should produce decoded_seq).
    return decoded_seq, (chx, mhx, rv)
I pretty much copied the first half of this code from the PyTorch transformer encoder example, but I can't find a good encoder-decoder example to look at. All I can find are decoder-only models that don't fit what I'm trying to do. Can someone please give me an example, or at least an explanation of how I should go about this?
I've tried looking at the PyTorch documentation for nn.TransformerDecoder, but it doesn't show how to actually use the decoder. I've also looked for examples on GitHub, Stack Overflow, and the PyTorch forums, but none of them fit what I'm trying to do.
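
For anyone landing here with the same question, below is a minimal sketch of how the decoder half is commonly wired up. It is not a confirmed answer for the model above. It assumes batch-first tensors, omits positional encoding for brevity, and uses hypothetical names that are not from my code (VOCAB, EMB, BOS, embedding, to_logits, decode_train, decode_greedy). The memory argument stands in for whatever the encoder side produces (processed_seq in my snippet). In training, the whole shifted target is fed at once with a causal tgt_mask (teacher forcing). In inference, nn.TransformerDecoder does not feed its own output back in, so the loop has to be written by hand.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes and start-token id, chosen only for illustration.
VOCAB, EMB, BOS = 50, 32, 1

embedding = nn.Embedding(VOCAB, EMB)
decoder_layer = nn.TransformerDecoderLayer(
    d_model=EMB, nhead=4, dim_feedforward=64, batch_first=True
)
decoder = nn.TransformerDecoder(decoder_layer, num_layers=2)
to_logits = nn.Linear(EMB, VOCAB)  # projects decoder output to vocabulary scores


def decode_train(memory, tgt_tokens):
    """Teacher forcing: feed the target shifted right, all positions at once."""
    tgt_in = embedding(tgt_tokens[:, :-1])  # drop the last token from the input
    # The causal mask stops position i from attending to positions after i.
    mask = nn.Transformer.generate_square_subsequent_mask(tgt_in.size(1))
    out = decoder(tgt_in, memory, tgt_mask=mask)
    return to_logits(out)  # the loss would compare these with tgt_tokens[:, 1:]


@torch.no_grad()
def decode_greedy(memory, max_len=10):
    """Inference: generate one token at a time, feeding the output back manually."""
    tokens = torch.full((memory.size(0), 1), BOS, dtype=torch.long)
    for _ in range(max_len):
        tgt_in = embedding(tokens)
        mask = nn.Transformer.generate_square_subsequent_mask(tgt_in.size(1))
        out = decoder(tgt_in, memory, tgt_mask=mask)
        # Only the last position predicts the next token.
        next_token = to_logits(out[:, -1]).argmax(dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_token], dim=1)
    return tokens
```

A call would look like decode_train(memory, tgt_tokens) with memory of shape (batch, src_len, EMB) and tgt_tokens of shape (batch, tgt_len), or decode_greedy(memory) at inference time. As far as I can tell from the documentation, tgt_is_causal is only a hint that the tgt_mask being passed is a causal mask; it does not replace the mask, which is why the sketch passes tgt_mask explicitly and leaves tgt_is_causal at its default.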