I am trying to build a NN with a dropout layer in case to avoid overfitting. But I met some trouble when I wrote it in Jax Flax.
Here is the original model I built in Pytorch:
class MLPModel(nn.Module):
def __init__(self, layer, dp_rate=0.1):
super().__init__()
layers = []
for idx in range(len(layer) - 1):
layers += [
nn.Linear(layer[idx], layer[idx + 1]),
nn.ReLU(inplace=True),
nn.Dropout(dp_rate)
]
self.layers = nn.Sequential(*layers)
def forward(self, x, *args, **kwargs):
return self.layers(x)
This code works well. But when I adapt it into Flax, something went wrong:
class CNN(nn.Module):
hidden_size: Sequence[int]
dp_rate: float
training: bool
def setup(self):
layers = []
for idx in range(len(self.hidden_size)):
layers.append(nn.Dense(self.hidden_size[idx]))
self.linear_layers = layers
@nn.compact
def __call__(self, x):
for layer in self.linear_layers:
x = layer(x)
x = nn.relu(x)
x = nn.Dropout(self.dp_rate)(x, deterministic=not self.training)
x = nn.Dense(self.hidden_size[-1])(x)
x = nn.log_softmax(x)
return x
The error message is: 'Incompatible shapes for broadcasting: ((1, 1, 128, 10), (128, 28, 28, 10))' (I used MNIST as my dataset). And it occurs in:
@jax.jit
def train_step(state, imgs, gt_labels, key):
def loss_fn(params):
logits = CNN(training=True, hidden_size = [50,50,10], dp_rate = 0.1).apply(params, imgs, rngs={'dropout': random.PRNGKey(2)})
one_hot_gt_labels = jax.nn.one_hot(gt_labels, num_classes=10)
loss = -jnp.mean(jnp.sum(one_hot_gt_labels * logits, axis=-1))
return loss, logits
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(_, logits), grads = grad_fn(state.params)
state = state.apply_gradients(grads=grads) # this is the whole update now! concise!
metrics = compute_metrics(logits=logits, gt_labels=gt_labels) # duplicating loss calculation but it's a bit cleaner
return state, metrics
The size (1, 1, 128, 10), I guess should be the prediction, while (128, 28, 28, 10) should be the input size. I followed the tutorial in Official Documentation (almost the same codes), and I am a little confused about the error.
I shared the document link here: https://colab.research.google.com/drive/1o6_FgW7AO2XvhuM9NGfLMFOWBFOgbF6G?usp=sharing