I'm working with Trax, a deep learning framework built by the Google Brain team as an alternative to TensorFlow. As a TensorFlow developer, I'm used to calling the model.summary()
method (documented here) to display a full model summary, for example:
model.summary()
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         [(None, 16, 303)]         0
_________________________________________________________________
bidirectional (Bidirectional (None, 16, 256)           442368
_________________________________________________________________
time_distributed (TimeDistri (None, 16, 22)            5654
=================================================================
Total params: 448,022
Trainable params: 448,022
Non-trainable params: 0
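The Param # column in that summary is plain arithmetic over layer shapes. The sketch below reproduces it, assuming (this is not stated above) that the Bidirectional layer wraps an LSTM with 128 units per direction, which matches the 256-wide output, and that the TimeDistributed layer wraps Dense(22). The helper names are hypothetical.

```python
def lstm_params(input_dim: int, units: int) -> int:
    # Keras LSTM: 4 gates, each with an input kernel, a recurrent kernel and a bias.
    return 4 * (units * (input_dim + units) + units)


def dense_params(input_dim: int, units: int) -> int:
    # Weight matrix plus one bias per output unit.
    return input_dim * units + units


# Assumption: Bidirectional(LSTM(128)) on 303 input features, forward + backward copies.
bidirectional = 2 * lstm_params(303, 128)
# Assumption: TimeDistributed(Dense(22)); the same Dense weights are shared across all 16 steps.
time_distributed = dense_params(256, 22)
total = bidirectional + time_distributed

print(bidirectional, time_distributed, total)  # 442368 5654 448022
```

Under those assumptions the three numbers match the Param # and Total params entries of the summary exactly.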
Is there something equivalent in Trax?
Currently, there does not appear to be a method similar to
.summary()
in Trax; the closest option is to print the model, for example by adapting the introductory example from the Trax documentation.
Although it is nowhere near as detailed as TensorFlow's
model.summary()
, the printout still contains useful information: the embedding layer's constructor arguments (vocabulary size and embedding dimension) are included in its printed name, and if you change the model's last layer to tl.Dense(3)
, the corresponding entry changes to Dense_3
.