I am trying to create a simple neural network using Flax, as shown below.
However, the params FrozenDict returned by model.init is empty instead of containing the parameters of the neural network. Also, type(predictions) is a flax.linen.combinators.Sequential object instead of a DeviceArray.
Can someone help me understand what is wrong with this code snippet?
import jax
import jax.numpy as jnp
import flax.linen as nn
import matplotlib.pyplot as plt
class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Sequential(
            [
                nn.Dense(40),
                nn.relu,
                nn.Dense(40),
                nn.Dense(1),
            ]
        )
model = MLP()
dummy_input = jnp.ones((40, 40, 1))
params = model.init(jax.random.PRNGKey(0), dummy_input)
jax.tree_util.tree_map(lambda x: x.shape, params)
n = 100
x_inputs = jnp.linspace(-10, 10, n).reshape(1, -1)
y_targets = jnp.sin(x_inputs)
predictions = model.apply(params, x_inputs)
plt.plot(x_inputs.reshape(-1), y_targets.reshape(-1))
plt.plot(x_inputs.reshape(-1), predictions.reshape(-1))
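The problem is that nn.Sequential(...) only constructs a module; it does not apply the layers to anything. MLP.__call__ returns that module object instead of calling it on x, so no Dense layer ever runs during model.init, no parameters are created, and model.apply hands back the Sequential object itself. Replacing

    return nn.Sequential([...])

with

    return nn.Sequential([...])(x)

solves the problem.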