I am trying to build an attention model, but the tl.Relu and tl.ShiftRight layers get nested inside a Serial combinator by default. This later gives me errors during training.
import numpy as np
from trax import layers as tl
from trax import shapes

layer_block = tl.Serial(
    tl.Relu(),
    tl.LayerNorm(),
)
x = np.array([[-2, -1, 0, 1, 2],
              [-20, -10, 0, 10, 20]]).astype(np.float32)
layer_block.init(shapes.signature(x))
y = layer_block(x)
print(f'layer_block: {layer_block}')
Output
layer_block: Serial[
  Serial[
    Relu
  ]
  LayerNorm
]
Expected Output
layer_block: Serial[
  Relu
  LayerNorm
]
The same problem arises with tl.ShiftRight().
The code above is taken from Example 5 in the official documentation.
Thanks in advance
I could not find an exact solution to the above problem, but you can create custom layers with tl.Fn() and put the Relu and ShiftRight logic inside them, as sketched below.
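Here is a minimal sketch of that idea, assuming the tl.Fn API from trax.layers and trax.fastmath.numpy as the array backend. The function bodies mirror what the built-in Relu and ShiftRight compute, but the names, the n_positions default, and the dropped mode='predict' handling are my own simplifications, so treat it as a starting point rather than Trax's exact implementation.

from trax import layers as tl
from trax.fastmath import numpy as jnp

def Relu():
    # A bare tl.Fn layer: unlike the built-in tl.Relu(), it is not
    # wrapped in an extra Serial combinator.
    return tl.Fn('Relu', lambda x: jnp.maximum(x, 0.0))

def ShiftRight(n_positions=1):
    # Pad n_positions zeros at the front of the time axis (axis 1)
    # and drop the same number of steps from the end.
    def shift(x):
        pad_widths = [(0, 0)] * x.ndim
        pad_widths[1] = (n_positions, 0)
        padded = jnp.pad(x, pad_widths, mode='constant')
        return padded[:, :-n_positions]
    return tl.Fn(f'ShiftRight({n_positions})', shift)

layer_block = tl.Serial(
    Relu(),
    tl.LayerNorm(),
)
print(f'layer_block: {layer_block}')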
Output
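layer_block: Serial[
  Relu
  LayerNorm
]

The custom Relu now sits directly inside the outer Serial, matching the expected output above; the same approach applies to ShiftRight.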