How to write PyTorch-like code in JAX Flax


I am trying to build a neural network with a dropout layer to avoid overfitting, but I ran into trouble when writing it in JAX Flax.

Here is the original model I built in PyTorch:

import torch.nn as nn


class MLPModel(nn.Module):

    def __init__(self, layer, dp_rate=0.1):
        super().__init__()
        layers = []
        # One Linear -> ReLU -> Dropout block per pair of consecutive sizes in `layer`
        for idx in range(len(layer) - 1):
            layers += [
                nn.Linear(layer[idx], layer[idx + 1]),
                nn.ReLU(inplace=True),
                nn.Dropout(dp_rate)
            ]
        self.layers = nn.Sequential(*layers)

    def forward(self, x, *args, **kwargs):
        return self.layers(x)

This code works well. But when I port it to Flax, something goes wrong:

from typing import Sequence

import flax.linen as nn


class CNN(nn.Module):
    hidden_size: Sequence[int]
    dp_rate: float
    training: bool

    def setup(self):
        # One Dense layer per entry of hidden_size
        layers = []
        for idx in range(len(self.hidden_size)):
            layers.append(nn.Dense(self.hidden_size[idx]))
        self.linear_layers = layers

    @nn.compact
    def __call__(self, x):
        for layer in self.linear_layers:
            x = layer(x)
            x = nn.relu(x)
            x = nn.Dropout(self.dp_rate)(x, deterministic=not self.training)
        # Final projection to hidden_size[-1] outputs, then log-probabilities
        x = nn.Dense(self.hidden_size[-1])(x)
        x = nn.log_softmax(x)
        return x

The error message is 'Incompatible shapes for broadcasting: ((1, 1, 128, 10), (128, 28, 28, 10))' (I use MNIST as my dataset). It is raised in:

import jax
import jax.numpy as jnp


@jax.jit
def train_step(state, imgs, gt_labels, key):
    def loss_fn(params):
        # Use the per-step `key` for dropout (instead of a fixed PRNGKey)
        # so that each step draws a different dropout mask
        logits = CNN(training=True, hidden_size=[50, 50, 10], dp_rate=0.1).apply(
            params, imgs, rngs={'dropout': key})
        one_hot_gt_labels = jax.nn.one_hot(gt_labels, num_classes=10)
        # Negative log-likelihood: the model already outputs log_softmax
        loss = -jnp.mean(jnp.sum(one_hot_gt_labels * logits, axis=-1))
        return loss, logits
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (_, logits), grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)  # this is the whole update now! concise!
    metrics = compute_metrics(logits=logits, gt_labels=gt_labels)  # duplicating loss calculation but it's a bit cleaner
    return state, metrics

I guess the shape (1, 1, 128, 10) is the prediction, while (128, 28, 28, 10) is the input. I followed the tutorial in the official Flax documentation (my code is almost identical to it), so I am confused by this error.
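One plausible reading of the two shapes, assuming imgs is an unflattened MNIST batch of shape (128, 28, 28, 1) (the notebook's preprocessing is not shown here): nn.Dense contracts only the last axis, so the model output would be (128, 28, 28, 10), and NumPy-style broadcasting treats the (128, 10) one-hot labels as (1, 1, 128, 10) before multiplying, which clashes on 128 vs 28. Under that assumption, the first shape in the message would be the labels and the second the logits. The sketch below reproduces this with placeholder zero arrays instead of a real model.

```python
import jax
import jax.numpy as jnp

batch = 128
imgs = jnp.zeros((batch, 28, 28, 1))                  # assumed NHWC MNIST batch
gt_labels = jnp.zeros((batch,), dtype=jnp.int32)
one_hot = jax.nn.one_hot(gt_labels, num_classes=10)   # shape (128, 10)

# Dense layers applied to unflattened images keep the (128, 28, 28) leading axes
logits_4d = jnp.zeros((batch, 28, 28, 10))

try:
    # (128, 10) is aligned as (1, 1, 128, 10): 128 vs 28 cannot broadcast
    one_hot * logits_4d
    broadcast_failed = False
except (TypeError, ValueError):   # exact exception type varies across JAX versions
    broadcast_failed = True

# Flattening the images first gives (128, 784) inputs and hence (128, 10) logits
flat = imgs.reshape((imgs.shape[0], -1))
logits_2d = jnp.zeros((batch, 10))
loss = -jnp.mean(jnp.sum(one_hot * logits_2d, axis=-1))
```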

Here is a link to my Colab notebook: https://colab.research.google.com/drive/1o6_FgW7AO2XvhuM9NGfLMFOWBFOgbF6G?usp=sharing
