How to define an inference decoder with MultiHeadAttention and set trained weights


I am trying to define an inference decoder for a sequence-to-sequence prediction task, using an Encoder-Decoder architecture with a MultiHeadAttention layer (you know, all the good stuff). I'm struggling to define the inference decoder and load the trained weights into it, so I'd very much appreciate advice on the best way of doing this. Below is my trained network summary:

Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
==================================================================================================
 en_input_layer (InputLayer  [(None, 12, 1)]              0         []                            
 )                                                                                                
                                                                                                  
 en_bidirect_gru1 (Bidirect  (None, 12, 512)              397824    ['en_input_layer[0][0]']      
 ional)                                                                                           
                                                                                                  
 de_input_layer (InputLayer  [(None, 3, 1)]               0         []                            
 )                                                                                                
                                                                                                  
 en_gru2_layer (GRU)         [(None, 12, 256),            591360    ['en_bidirect_gru1[0][0]']    
                              (None, 256)]                                                        
                                                                                                  
 de_gru1_layer (GRU)         [(None, 3, 256),             198912    ['de_input_layer[0][0]',      
                              (None, 256)]                           'en_gru2_layer[0][1]']       
                                                                                                  
 multi_head_attn_layer (Mul  (None, 3, 256)               526080    ['de_gru1_layer[0][0]',       
 tiHeadAttention)                                                    'en_gru2_layer[0][0]']       
                                                                                                  
 attn_source_add_layer (Add  (None, 3, 256)               0         ['de_gru1_layer[0][0]',       
 )                                                                   'multi_head_attn_layer[0][0]'
                                                                    ]                             
                                                                                                  
 before_preds_layer_norm (L  (None, 3, 256)               512       ['attn_source_add_layer[0][0]'
 ayerNormalization)                                                 ]                             
                                                                                                  
 time_distributed (TimeDist  (None, 3, 1)                 257       ['before_preds_layer_norm[0][0
 ributed)                                                           ]']                           
                                                                                                  
==================================================================================================
Total params: 1714945 (6.54 MB)
Trainable params: 1714945 (6.54 MB)
Non-trainable params: 0 (0.00 Byte)
__________________________________________________________________________________________________

This model is trained and is called enc_de_model. Now I'm trying to write an inference decoder that can be called recursively to predict one sequence element at a time, but I'm struggling to define it and set the trained decoder weights on it. Below is my attempt to define it.

from tensorflow.keras.layers import (Input, GRU, MultiHeadAttention, Add,
                                     LayerNormalization, Dense)
from tensorflow.keras.models import Model

def define_inf_decoder(context_vec, hsize=256):
  decoder_input = Input(shape=(1, 1))          # one target element per step
  decoder_state_input = Input(shape=(hsize,))  # GRU state carried between steps

  de_gru1 = GRU(hsize, return_sequences=True, return_state=True, name='de_gru1_layer')
  de_gru1_out, de_state_out = de_gru1(decoder_input, initial_state=decoder_state_input)

  de_mha_attn = MultiHeadAttention(num_heads=2, key_dim=hsize, name='multi_head_attn_layer')
  ### context_vec is baked into the graph here as a constant
  attn_out = de_mha_attn(query=de_gru1_out, value=context_vec)

  attn_added = Add(name='attn_source_add_layer')([de_gru1_out, attn_out])
  h_hat = LayerNormalization(name='before_preds_layer_norm')(attn_added)

  ### Output Layer (note: its name differs from the trained model's
  ### 'time_distributed' layer, so the weight-copy loop below skips it)
  preds = Dense(1, name='output_layer')(h_hat)

  decoder_model = Model(inputs=[decoder_input, decoder_state_input], outputs=[preds, de_state_out])
  return decoder_model


encoder_output = inf_encoder.predict(testX) ### testX shape: (1416, 12, 1)
inference_decoder = define_inf_decoder(encoder_output[0]) ### encoder_output shape: (1416, 12, 256)

### Set Inference Decoder Weights
trained_layers = [l.name for l in enc_de_model.layers]
print(f"No. of trained layers: {len(trained_layers)}")

for l in inference_decoder.layers:
  if l.name in trained_layers:
    trained_wts = enc_de_model.get_layer(l.name).get_weights()
    if len(trained_wts)>0:
      inference_decoder.get_layer(l.name).set_weights(trained_wts)
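For context, this is roughly the recursive loop I want to run once the weights are set (a sketch only; the zero start token and the function name are my own conventions, and `init_state` would come from the encoder's final state output):

```python
import numpy as np

def decode_sequence(inference_decoder, init_state, n_steps=3):
  # Greedy one-step-at-a-time decoding: each prediction is fed back
  # in as the next decoder input.
  state = init_state                # shape (1, hsize), from the encoder
  next_input = np.zeros((1, 1, 1))  # start token for the first step
  preds = []
  for _ in range(n_steps):
    step_pred, state = inference_decoder.predict([next_input, state], verbose=0)
    preds.append(step_pred[0, 0, 0])
    next_input = step_pred          # feed the prediction back in
  return np.array(preds)
```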

Since the MultiHeadAttention layer needs to be supplied a context vector coming from the encoder, how do I define this function so that I don't have to rebuild the model and set weights every time I call the inference decoder for a prediction? I could bake in the context vectors of all the samples at once, but that doesn't seem right (or does it?) when I'm decoding a single sample at a time. Any suggestions to improve this or make it better will be very helpful.

If I don't pass context_vec while defining the inference decoder, then it won't let me build the Model. I'd very much appreciate your help here. Thank you.
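One direction I'm considering (a sketch, untested; the `enc_steps` argument and the extra context Input are my own additions, not part of the trained model) is to make the context vector a third model input, so the decoder is built, and its weights set, exactly once:

```python
import numpy as np
from tensorflow.keras.layers import (Input, GRU, MultiHeadAttention, Add,
                                     LayerNormalization, Dense)
from tensorflow.keras.models import Model

def define_inf_decoder(hsize=256, enc_steps=12):
  decoder_input = Input(shape=(1, 1))
  decoder_state_input = Input(shape=(hsize,))
  # The encoder output becomes a regular input instead of a baked-in constant
  context_input = Input(shape=(enc_steps, hsize))

  de_gru1 = GRU(hsize, return_sequences=True, return_state=True, name='de_gru1_layer')
  de_gru1_out, de_state_out = de_gru1(decoder_input, initial_state=decoder_state_input)

  attn_out = MultiHeadAttention(num_heads=2, key_dim=hsize,
                                name='multi_head_attn_layer')(
      query=de_gru1_out, value=context_input)

  attn_added = Add(name='attn_source_add_layer')([de_gru1_out, attn_out])
  h_hat = LayerNormalization(name='before_preds_layer_norm')(attn_added)
  preds = Dense(1, name='output_layer')(h_hat)

  return Model(inputs=[decoder_input, decoder_state_input, context_input],
               outputs=[preds, de_state_out])
```

Then each prediction step would pass the single sample's context alongside the other inputs, e.g. `inference_decoder.predict([next_input, state, context[np.newaxis, ...]])`, but I'm not sure whether this is the idiomatic way to do it.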
