Constructing a highly customized neural network in Keras (weight sharing, custom connectivity)


I'm trying to create an NN model for a specific problem in the physical sciences, and my motivation is to reduce the number of weights and to share weights based on physical insights. The neural net looks something like this:

[diagram of the network architecture]

The input size is 2n, where the first n inputs (X1 .. Xn) are fed into the first hidden layer. However, the connectivity of the network is unusual in that only one input is fed into each unit of the first hidden layer. Moreover, all of these connections share the same weight.

Each unit in the second hidden layer has two inputs: one from the previous layer and one taken directly from the raw inputs (Xn+1 .. X2n). The weights and biases are again shared across units.

Finally, the output layer has one unit with 100 inputs (all of the outputs of the second hidden layer), and the same weight and bias are applied to each of them.
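
In code terms, what I am after is roughly the following numpy sketch of the forward pass (the parameter names and values are just placeholders, and I have left out any activation):

import numpy as np

n = 5
x = np.random.rand(2 * n)          # inputs X1..Xn followed by Xn+1..X2n

w1 = 0.5                           # single weight shared by all first-hidden-layer units
w21, w22, b2 = 1.2, -0.7, 0.3      # shared weights and bias of the second-hidden-layer units
w3, b3 = 2.0, 0.1                  # shared weight and bias of the output unit

h1 = w1 * x[:n]                    # each first-layer unit sees exactly one of X1..Xn
h2 = w21 * h1 + w22 * x[n:] + b2   # each second-layer unit: one h1 value plus one of Xn+1..X2n
y = w3 * np.sum(h2) + b3           # single output unit over all second-layer units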

There is 1 answer below.

This code should solve your issue (under the assumption that you want all of the inputs to the output unit to be added up at the end). By the way, if you don't use any activation, the operation you described is linear and can easily be simplified.

import tensorflow.keras as keras
import tensorflow.keras.layers as layers
input_shape = [1]
num_inputs = 4

inputs = [layers.Input(shape=input_shape, name=f"input_{i}") for i in range(num_inputs)]
x = list(inputs)
# Each Dense layer below is created once and then called on several tensors,
# so its weights are shared across all branches.
dense_1 = layers.Dense(units=1, use_bias=False, name="1")    # shared weight of the first hidden layer
dense_21 = layers.Dense(units=1, use_bias=True, name="21")   # shared weight + bias for the first-layer output
dense_22 = layers.Dense(units=1, use_bias=False, name="22")  # shared weight for the raw inputs Xn+1..X2n
dense_3 = layers.Dense(units=1, use_bias=True, name="3")     # shared weight + bias before the final sum

for i in range(num_inputs // 2):
    # First hidden layer: each unit sees exactly one of the first n inputs
    x[i] = dense_1(x[i])
    # Second hidden layer: contribution from the first hidden layer...
    x[i] = dense_21(x[i])
    # ...plus the contribution from the corresponding raw input Xn+1..X2n
    x[i + num_inputs // 2] = dense_22(x[i + num_inputs // 2])
    x[i] = layers.Add()([x[i], x[i + num_inputs // 2]])
    # Output contribution of this branch (shared weight and bias)
    x[i] = dense_3(x[i])

# Sum all of the branch outputs into the single network output
x = layers.Add()(x[:num_inputs // 2])

model = keras.Model(inputs=inputs,
                    outputs=x)

keras.utils.plot_model(model=model,
                       to_file="model.png",
                       show_shapes=True)
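
In case it helps, here is a quick sketch of how such a multi-input model could be fed; the data below is random and purely illustrative, and it reuses model and num_inputs from the code above:

import numpy as np

# one array per named input; a batch of 8 scalar samples each (random, for illustration)
data = [np.random.rand(8, 1) for _ in range(num_inputs)]
targets = np.random.rand(8, 1)

model.compile(optimizer="adam", loss="mse")
model.fit(data, targets, epochs=2, verbose=0)
print(model.predict(data).shape)   # (8, 1)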

The plot produced by plot_model (saved as model.png) is: [model graph image]
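
Regarding the remark about linearity: here is a small numpy check (parameter names and values are invented for illustration) that, with no activations, the whole network collapses to an affine function of two sums of the inputs:

import numpy as np

w1, w21, w22, b2, w3, b3 = 0.5, 1.2, -0.7, 0.3, 2.0, 0.1   # placeholder shared parameters
n = 4
x = np.random.rand(2 * n)

# branch by branch, as in the model above (dense_3 applied per branch, then summed)
h2 = w21 * (w1 * x[:n]) + w22 * x[n:] + b2
y_model = np.sum(w3 * h2 + b3)

# collapsed form: two effective weights and one effective bias
y_flat = (w3 * w21 * w1) * x[:n].sum() + (w3 * w22) * x[n:].sum() + n * (w3 * b2 + b3)
print(np.isclose(y_model, y_flat))   # True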