I am trying to define an inference decoder for a sequence-to-sequence prediction task, using an encoder-decoder architecture built around a MultiHeadAttention layer. I'm struggling to define the inference decoder and load the trained weights into it, so I'd very much appreciate advice on the best way of doing this. Below is my trained network summary:
Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                        Output Shape           Param #   Connected to
==================================================================================================
 en_input_layer (InputLayer)         [(None, 12, 1)]        0         []
 en_bidirect_gru1 (Bidirectional)    (None, 12, 512)        397824    ['en_input_layer[0][0]']
 de_input_layer (InputLayer)         [(None, 3, 1)]         0         []
 en_gru2_layer (GRU)                 [(None, 12, 256),      591360    ['en_bidirect_gru1[0][0]']
                                      (None, 256)]
 de_gru1_layer (GRU)                 [(None, 3, 256),       198912    ['de_input_layer[0][0]',
                                      (None, 256)]                     'en_gru2_layer[0][1]']
 multi_head_attn_layer               (None, 3, 256)         526080    ['de_gru1_layer[0][0]',
 (MultiHeadAttention)                                                  'en_gru2_layer[0][0]']
 attn_source_add_layer (Add)         (None, 3, 256)         0         ['de_gru1_layer[0][0]',
                                                                       'multi_head_attn_layer[0][0]']
 before_preds_layer_norm             (None, 3, 256)         512       ['attn_source_add_layer[0][0]']
 (LayerNormalization)
 time_distributed (TimeDistributed)  (None, 3, 1)           257       ['before_preds_layer_norm[0][0]']
==================================================================================================
Total params: 1714945 (6.54 MB)
Trainable params: 1714945 (6.54 MB)
Non-trainable params: 0 (0.00 Byte)
__________________________________________________________________________________________________
I have this model trained and it is called enc_de_model. Now I'm trying to write an inference decoder that can be called recursively to predict one sequence element at a time, but I'm struggling both to define it and to set the trained decoder weights on it. Below is my attempt:
from tensorflow.keras.layers import Input, GRU, MultiHeadAttention, Add, LayerNormalization, Dense
from tensorflow.keras.models import Model

def define_inf_decoder(context_vec, hsize=256):
    ### Single decoding step: one-element input plus the GRU state carried over from the previous step
    decoder_input = Input(shape=(1, 1))
    decoder_state_input = Input(shape=(hsize,))

    de_gru1 = GRU(hsize, return_sequences=True, return_state=True, name='de_gru1_layer')
    de_gru1_out, de_state_out = de_gru1(decoder_input, initial_state=decoder_state_input)

    ### Attend over the encoder output captured from the enclosing scope via context_vec
    de_mha_attn = MultiHeadAttention(num_heads=2, key_dim=hsize, name='multi_head_attn_layer')
    attn_out = de_mha_attn(query=de_gru1_out, value=context_vec)

    attn_added = Add(name='attn_source_add_layer')([de_gru1_out, attn_out])
    h_hat = LayerNormalization(name='before_preds_layer_norm')(attn_added)

    ### Output Layer
    preds = Dense(1, name='output_layer')(h_hat)

    decoder_model = Model(inputs=[decoder_input, decoder_state_input], outputs=[preds, de_state_out])
    return decoder_model
encoder_output = inf_encoder.predict(testX) ### testX shape: (1416, 12, 1)
inference_decoder = define_inf_decoder(encoder_output[0]) ### encoder_output shape: (1416, 12, 256)
### Set Inference Decoder Weights
trained_layers = [l.name for l in enc_de_model.layers]
print(f"No. of trained layers: {len(trained_layers)}")
for l in inference_decoder.layers:
    if l.name in trained_layers:
        trained_wts = enc_de_model.get_layer(l.name).get_weights()
        if len(trained_wts) > 0:
            inference_decoder.get_layer(l.name).set_weights(trained_wts)
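For context, this is roughly what a per-sample decoding loop looks like with the current setup (just a rough sketch; I'm assuming inf_encoder also returns the final GRU state as its second output, so encoder_output[1] can seed the decoder state, and the seed value for the first step is a placeholder):

import numpy as np

### Build a decoder with the context of one sample baked in (sample 0 here);
### this is exactly the part I'd like to avoid, since it means rebuilding the
### model and re-running the weight-copying loop above for every sample.
single_decoder = define_inf_decoder(encoder_output[0][0:1])

state = encoder_output[1][0:1]         ### (1, 256) final encoder state for that sample (assumption)
current_input = np.zeros((1, 1, 1))    ### seed value for the first decoding step (placeholder)

predictions = []
for _ in range(3):                     ### the decoder was trained on 3-step targets
    pred, state = single_decoder.predict([current_input, state])
    predictions.append(pred[0, 0, 0])
    current_input = pred               ### feed the prediction back in as the next step's input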
Since the MultiHeadAttention layer needs to be supplied the context vector coming from the encoder, how do I define this function so that I don't have to set the weights every time I call the inference decoder for a prediction? I could bake in the context vectors of all the samples, but that doesn't seem right (or does it?) when I'm decoding a single sample at a time. Any suggestions to improve this or make it cleaner would be very helpful.
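One idea I've been toying with is to make the encoder context an explicit third Input of the decoder Model, so that the model is built (and has its weights copied) only once, and the per-sample context is fed in at predict time. Here is a rough, untested sketch of what I mean (define_inf_decoder_v2 and de_context_input are names I made up; the (12, hsize) shape just mirrors the encoder output above):

def define_inf_decoder_v2(hsize=256):
    ### Single-step input, carried-over GRU state, and the encoder context as an explicit Input
    decoder_input = Input(shape=(1, 1))
    decoder_state_input = Input(shape=(hsize,))
    de_context_input = Input(shape=(12, hsize))   ### encoder sequence output for one sample

    de_gru1 = GRU(hsize, return_sequences=True, return_state=True, name='de_gru1_layer')
    de_gru1_out, de_state_out = de_gru1(decoder_input, initial_state=decoder_state_input)

    de_mha_attn = MultiHeadAttention(num_heads=2, key_dim=hsize, name='multi_head_attn_layer')
    attn_out = de_mha_attn(query=de_gru1_out, value=de_context_input)

    attn_added = Add(name='attn_source_add_layer')([de_gru1_out, attn_out])
    h_hat = LayerNormalization(name='before_preds_layer_norm')(attn_added)
    preds = Dense(1, name='output_layer')(h_hat)

    return Model(inputs=[decoder_input, decoder_state_input, de_context_input],
                 outputs=[preds, de_state_out])

With this, the decoding loop above would pass the single sample's context (e.g. encoder_output[0][0:1]) as the third input on every call, but I'm not sure whether this is the right way to wire the MultiHeadAttention layer at inference time, or whether there is a cleaner pattern.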
With my original definition, if I don't pass context_vec while defining the inference decoder, Keras won't let me build the Model at all. I'd very much appreciate your help here. Thank you.