I am trying to create a neural network for policy-based RL. I have written the following class to build the network and generate actions:
class Oracle(object):
    """Softmax policy network with a hand-built policy-gradient train function.

    The network is a stack of ReLU ``Dense`` layers followed by a softmax
    ``Dense`` layer with ``output_dim`` units. After construction the instance
    exposes:

    * ``self.model`` -- the Keras ``Model`` mapping states to action
      probabilities.
    * ``self.train_fn([state, action_one_hot, discount_reward])`` -- a backend
      function that applies one Adam update of the policy-gradient loss.
    """

    def __init__(self, input_dim, output_dim, hidden_dims=None):
        """Build the network and its training function.

        Args:
            input_dim: size of a single state vector.
            output_dim: number of discrete actions (softmax width).
            hidden_dims: sizes of the hidden layers; defaults to ``[32, 32]``.
        """
        if hidden_dims is None:
            hidden_dims = [32, 32]
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.__build_network(input_dim, output_dim, hidden_dims)
        self.__build_train_fn()

    def __build_network(self, input_dim, output_dim, hidden_dims):
        """Create the base network and store it in ``self.model``.

        Returns:
            The created Keras ``Model``.
        """
        inputs = Input(shape=(input_dim,))
        net = inputs
        # A layer instance is callable on a tensor, and returns a tensor.
        for h_dim in hidden_dims:
            net = Dense(h_dim,
                        activation='relu',
                        kernel_initializer='RandomNormal',
                        bias_initializer='zeros')(net)
        net = Dense(output_dim,
                    activation='softmax',
                    kernel_initializer='RandomNormal',
                    bias_initializer='zeros')(net)
        # The model consists of the Input layer, one Dense layer per entry of
        # `hidden_dims`, and the final softmax Dense layer.
        self.model = Model(inputs=inputs, outputs=net)
        return self.model

    def __build_train_fn(self):
        """Create the train function.

        It replaces `model.fit(X, y)` because the loss is built from the model
        output together with two extra placeholders:

        * `action_onehot` stores which action was taken at state `s`, so that
          only the probability of that action is updated.
        * `discount_reward` stores the discounted reward weighting each sample.

        This function creates
        `self.train_fn([state, action_one_hot, discount_reward])`
        which trains the model.
        """
        action_prob_placeholder = self.model.output
        action_onehot_placeholder = K.placeholder(shape=(None, self.output_dim),
                                                  name="action_onehot")
        discount_reward_placeholder = K.placeholder(shape=(None,),
                                                    name="discount_reward")

        # Probability the policy assigned to the action that was actually taken.
        action_prob = K.sum(action_prob_placeholder * action_onehot_placeholder,
                            axis=1)
        # Guard the logarithm: a softmax output can underflow to exactly 0, and
        # log(0) = -inf produces inf/NaN gradients that corrupt the weights.
        action_prob = K.clip(action_prob, K.epsilon(), 1.0)
        log_action_prob = K.log(action_prob)

        loss = - log_action_prob * discount_reward_placeholder
        loss = K.mean(loss)

        adam = optimizers.Adam()
        updates = adam.get_updates(params=self.model.trainable_weights,
                                   constraints=[],
                                   loss=loss)

        self.train_fn = K.function(inputs=[self.model.input,
                                           action_onehot_placeholder,
                                           discount_reward_placeholder],
                                   outputs=[],
                                   updates=updates)

    def get_action(self, state):
        """Return an action sampled from the policy at the given `state`.

        Args:
            state (1-D or 2-D array): either a 1-D array of shape
                (state_dimension,) or a 2-D array of shape
                (1, state_dimension). The prediction is squeezed and must have
                exactly `output_dim` entries, so only a single state is
                supported per call.

        Returns:
            action: an integer action value ranging from 0 to (n_actions - 1)

        Raises:
            AssertionError: if the state or the prediction has the wrong size.
            TypeError: if `state` is neither 1-D nor 2-D.
            ValueError: if the predicted probabilities contain NaN.
        """
        # Accept any array-like, not only ndarrays.
        state = np.asarray(state)
        shape = state.shape
        if len(shape) == 1:
            assert shape == (self.input_dim,), "{} != {}".format(shape, self.input_dim)
            state = np.expand_dims(state, axis=0)
        elif len(shape) == 2:
            assert shape[1] == (self.input_dim), "{} != {}".format(shape, self.input_dim)
        else:
            raise TypeError("Wrong state shape is given: {}".format(state.shape))

        action_prob = np.squeeze(self.model.predict(state))
        assert len(action_prob) == self.output_dim, "{} != {}".format(len(action_prob), self.output_dim)

        # Debug output: the state fed to the network and the current weights.
        print(state)
        print(state.shape)
        weights = self.model.get_weights()
        print(weights)

        # Fail with an explicit message instead of numpy's generic complaint
        # when the network has diverged.
        if np.isnan(action_prob).any():
            raise ValueError(
                "Action probabilities contain NaN: {}; the network weights are "
                "probably corrupted".format(action_prob))
        return np.random.choice(np.arange(self.output_dim), p=action_prob)
I want to use this in policy-based RL. The problem is that even though I initialize the weights with RandomNormal
(or other initializers), the printed weights contain a lot of NaNs. The action_prob
also comes out as NaN. A representative output for the weights is given below. Can anyone please tell me how this can be remedied?
[array([[ 1.97270699e-02, nan, -1.53264655e-02,
nan, nan, 9.83271226e-02,
nan, 1.67111661e-02, nan,
-5.40489666e-02, nan, -3.19434591e-02,
nan, -8.62319861e-03, nan,
3.90832238e-02, nan, nan,
nan, -3.34417708e-02, nan,
4.17598374e-02, 1.23961531e-02, 1.13383524e-01,
1.52971387e-01, -7.35234842e-02, 4.81316447e-03,
nan, nan, 9.02018696e-02,
-5.64984754e-02, nan],
[ 3.42946462e-02, nan, -2.32576765e-02,
nan, nan, -1.62454545e-02,
nan, 7.62931630e-02, nan,
7.09382221e-02, nan, -9.45277140e-02,
nan, 6.81431815e-02, nan,
5.43346964e-02, nan, nan,
nan, -5.25366806e-04, nan,
-3.03930230e-02, 1.90449376e-02, -6.84814155e-02,
-4.24950942e-02, -4.82842028e-02, 3.00289365e-03,
nan, nan, 1.14762083e-01,
-1.53483404e-02, nan],
[ 1.11763954e-01, nan, -2.40741558e-02,
nan, nan, -2.25515720e-02,
nan, 8.37199837e-02, nan,
8.01791809e-03, nan, 4.11959179e-02,
nan, -8.09677169e-02, nan,
1.09827537e-02, nan, nan,
nan, 3.24306265e-03, nan,
-4.61481474e-02, -4.44600247e-02, 5.97798042e-02,
-2.80357362e-03, 4.99138907e-02, -3.16888206e-02,
nan, nan, 4.79343869e-02,
-3.04902103e-02, nan],
[ 9.96000832e-04, nan, 7.03881904e-02,
nan, nan, 3.29129435e-02,
nan, 2.59399302e-02, nan,
3.94702554e-02, nan, 5.41977606e-05,
nan, -8.05872083e-02, nan,
7.35593066e-02, nan, nan,
nan, -3.20138596e-02, nan,
-4.88653146e-02, -3.05510052e-02, 1.61004122e-02,
3.60239707e-02, -2.89578568e-02, -8.55704099e-02,
nan, nan, -4.69469689e-02,
5.44301942e-02, nan],
[ 2.39880346e-02, nan, 1.02485856e-02,
nan, nan, -3.28975841e-02,
nan, 3.20423655e-02, nan,
7.26358453e-03, nan, -3.04405931e-02,
nan, 1.31638274e-02, nan,
-6.58982694e-02, nan, nan,
nan, -8.48279800e-03, nan,
5.07000796e-02, -3.43187563e-02, 1.69583317e-02,
5.02665602e-02, 6.59292564e-02, 5.91163523e-03,
nan, nan, 1.64841004e-02,
1.03674673e-01, nan],
[ 2.22617369e-02, nan, -9.83130708e-02,
nan, nan, -8.62144455e-02,
nan, -1.24993315e-03, nan,
-3.39315496e-02, nan, -3.71638462e-02,
nan, -2.51251217e-02, nan,
-3.30121554e-02, nan, nan,
nan, 6.95239231e-02, nan,
3.96330692e-02, -7.67886639e-02, 3.19798961e-02,
-7.02575818e-02, 5.36917103e-03, -7.84784183e-02,
nan, nan, -1.12238321e-02,
5.90852983e-02, nan],
[ -1.23783462e-02, nan, 8.54373630e-03,
nan, nan, 2.71492247e-02,
nan, -4.39056493e-02, nan,
1.54177221e-02, nan, 8.08294937e-02,
nan, -2.47991290e-02, nan,
-4.90374281e-04, nan, nan,
nan, -2.03785431e-02, nan,
-2.94432435e-02, -4.85701524e-02, -5.98664656e-02,
5.03640659e-02, -1.06101505e-01, -5.01858108e-02,
nan, nan, 1.59794372e-02,
-5.52875735e-03, nan],
[ -6.50038645e-02, nan, -2.88410280e-02,
nan, nan, 5.70952846e-03,
nan, 2.29494330e-02, nan,
2.96308636e-03, nan, -1.30019784e-02,
nan, 1.38891954e-02, nan,
9.82243866e-02, nan, nan,
nan, -4.53725718e-02, nan,
7.28782360e-03, -1.97060239e-02, 1.30356764e-02,
-1.77630689e-02, -5.27498014e-02, -5.70283793e-02,
nan, nan, -4.40920331e-03,
-8.47700890e-03, nan],
[ -7.09274644e-03, nan, -2.85792332e-02,
nan, nan, 1.90456193e-02,
nan, 2.33339947e-02, nan,
-7.10851625e-02, nan, -2.07360443e-02,
nan, -8.23910628e-03, nan,
1.53461788e-02, nan, nan,
nan, 8.74896254e-03, nan,
-1.04130013e-02, -8.23952537e-03, 3.29020806e-02,
-8.53802171e-03, -5.38858548e-02, 2.94392351e-02,
nan, nan, 2.28152424e-03,
3.86046581e-02, nan],
[ 6.32084534e-02, nan, 1.79775548e-03,
nan, nan, -5.96092641e-02,
nan, 1.74504239e-03, nan,
9.05414373e-02, nan, -3.55534554e-02,
nan, -3.89753282e-02, nan,
8.71098042e-03, nan, nan,
nan, 7.47531727e-02, nan,
5.26362322e-02, 1.46157984e-02, 3.21042910e-03,
-7.87475239e-03, 4.22325032e-03, 1.58537421e-02,
nan, nan, 3.45352525e-03,
9.88092553e-03, nan],
[ 8.60697851e-02, nan, 7.76077956e-02,
nan, nan, 1.35996595e-01,
nan, 7.12691769e-02, nan,
-2.70256456e-02, nan, 9.95257962e-03,
nan, -2.21844148e-02, nan,
4.18028049e-02, nan, nan,
nan, 6.15538433e-02, nan,
-3.34422104e-02, 7.96959698e-02, 3.36392457e-03,
-9.79953539e-03, 1.52911739e-02, -9.56133530e-02,
nan, nan, 3.26185785e-02,
-5.18142292e-03, nan],
[ -7.14878365e-02, nan, 3.30364555e-02,
nan, nan, -7.56359026e-02,
nan, -8.38122815e-02, nan,
3.50784622e-02, nan, 6.51308149e-02,
nan, -8.44882503e-02, nan,
1.97267421e-02, nan, nan,
nan, -4.02851999e-02, nan,
-3.84002179e-02, 3.23568434e-02, 9.30055231e-03,
2.97283176e-02, -3.93995969e-03, 1.24160219e-02,
nan, nan, -5.86424842e-02,
-5.61306179e-02, nan],
[ 5.52838258e-02, nan, -2.10575890e-02,
nan, nan, -1.46265700e-02,
nan, -6.19944222e-02, nan,
-4.26368900e-02, nan, -1.77203845e-02,
nan, 7.23404884e-02, nan,
1.19749429e-02, nan, nan,
nan, -1.97013188e-02, nan,
-9.93668661e-03, -1.43543081e-02, -1.89676192e-02,
-3.46484780e-02, -2.41095871e-02, 2.64016148e-02,
nan, nan, 3.39512643e-03,
-2.40868814e-02, nan],
[ 4.85769324e-02, nan, -2.96661835e-02,
nan, nan, -1.16411140e-02,
nan, -9.32439044e-03, nan,
-2.47888379e-02, nan, -2.11149845e-02,
nan, 1.55771989e-02, nan,
-3.60703245e-02, nan, nan,
nan, -8.21380615e-02, nan,
7.12675974e-02, 3.52902263e-02, 5.15214726e-03,
4.55725230e-02, -3.67484652e-02, -1.13544762e-02,
nan, nan, -3.86700444e-02,
-3.91620398e-02, nan],
[ -5.83947077e-03, nan, 5.90741597e-02,
nan, nan, -4.57256138e-02,
nan, -8.41458961e-02, nan,
-7.60969743e-02, nan, 2.50754189e-02,
nan, 2.75974572e-02, nan,
2.27455739e-02, nan, nan,
nan, -1.64209884e-02, nan,
-2.64473110e-02, -1.31150903e-02, 3.04512922e-02,
-5.81411598e-03, 1.68283712e-02, -1.44851422e-02,
nan, nan, -2.56322809e-02,
1.11139610e-01, nan],
[ 8.34780037e-02, nan, 6.61360845e-03,
nan, nan, -1.08085848e-01,
nan, -1.87303626e-03, nan,
-2.97805574e-02, nan, -4.96098958e-02,
nan, -2.47526560e-02, nan,
5.78494631e-02, nan, nan,
nan, 9.74192936e-03, nan,
-4.88330796e-02, 1.02368537e-02, -2.99407393e-02,
-3.94638889e-02, -1.45375028e-01, -8.38985574e-03,
nan, nan, -2.59864815e-02,
-5.39724007e-02, nan],
[ 2.34477259e-02, nan, 6.47758618e-02,
nan, nan, -2.06562635e-02,
nan, -1.50227742e-02, nan,
-4.99106087e-02, nan, -8.75398964e-02,
nan, -1.91738885e-02, nan,
9.81663391e-02, nan, nan,
nan, 8.30503032e-02, nan,
-6.02204986e-02, -5.43463342e-02, -2.73545366e-02,
-3.97464111e-02, -1.08450698e-03, 1.27358735e-02,
nan, nan, -6.65350258e-02,
-7.63151273e-02, nan],
[ -1.75849702e-02, nan, 5.18983677e-02,
nan, nan, 2.52664816e-02,
nan, -7.14112073e-03, nan,
2.89890468e-02, nan, -3.46427821e-02,
nan, 1.85990240e-02, nan,
-4.50296048e-03, nan, nan,
nan, -5.50862215e-02, nan,
1.02454759e-01, 9.34040993e-02, 1.45452050e-02,
2.90963929e-02, 3.19026299e-02, 1.89037640e-02,
nan, nan, -1.68684160e-03,
9.94853582e-03, nan],
[ -9.39413719e-03, nan, -3.46053950e-03,
nan, nan, 3.13128680e-02,
nan, -2.45536752e-02, nan,
4.08208035e-02, nan, 2.67537422e-02,
nan, 8.34849998e-02, nan,
-2.65908819e-02, nan, nan,
nan, -2.63154972e-03, nan,
4.54281829e-02, 1.24697601e-02, 5.25561944e-02,
5.75856939e-02, -8.61058664e-03, 2.86082458e-02,
nan, nan, -4.48538922e-02,
6.58497736e-02, nan],
[ -4.35961820e-02, nan, 5.22863083e-02,
nan, nan, -8.59688129e-03,
nan, -5.25927730e-02, nan,
7.24843144e-02, nan, -4.00458984e-02,
nan, -2.85069328e-02, nan,
2.43122727e-02, nan, nan,
nan, 1.57326814e-02, nan,
4.99758229e-04, 1.23931235e-02, 1.90575924e-02,
-4.64425469e-03, 5.54191284e-02, 2.38004271e-02,
nan, nan, -7.39056617e-03,
3.59723084e-02, nan],
[ 6.80808276e-02, nan, -1.49172200e-02,
nan, nan, -1.84247848e-02,
nan, 7.11160824e-02, nan,
4.74170335e-02, nan, -8.48565064e-03,
nan, 6.96734041e-02, nan,
1.07453577e-01, nan, nan,
nan, 3.21782194e-02, nan,
3.53086367e-02, -2.57775784e-02, -3.70149538e-02,
8.49922895e-02, 4.88188267e-02, 4.43161186e-03,
nan, nan, 7.35458219e-03,
-4.75145914e-02, nan],
[ -1.23953104e-01, nan, -4.27762084e-02,
nan, nan, 2.04169434e-02,
nan, 5.78987077e-02, nan,
-6.60712123e-02, nan, -2.07597148e-02,
nan, 3.00809499e-02, nan,
1.40863642e-01, nan, nan,
nan, -4.05914113e-02, nan,
-4.87232655e-02, 1.49445562e-02, 3.01859360e-02,
2.01087426e-02, 7.96428975e-03, 2.58545913e-02,
nan, nan, -3.26734572e-03,
2.30945610e-02, nan]], dtype=float32), array([ 0., nan, 0., nan, nan, 0., nan, 0., nan, 0., nan,
0., nan, 0., nan, 0., nan, nan, nan, 0., nan, 0.,
0., 0., 0., 0., 0., nan, nan, 0., 0., nan], dtype=float32), array([[ nan, nan, nan, ..., nan,
nan, 0.08562656],
[ nan, nan, nan, ..., nan,
nan, -0.03227361],
[ nan, nan, nan, ..., nan,
nan, -0.1371294 ],
...,
[ nan, nan, nan, ..., nan,
nan, 0.01600872],
[ nan, nan, nan, ..., nan,
nan, -0.0156843 ],
[ nan, nan, nan, ..., nan,
nan, -0.036583 ]], dtype=float32), array([ nan, nan, nan, nan, nan, nan, 0., 0., nan, 0., 0.,
0., 0., 0., nan, nan, nan, 0., nan, 0., 0., 0.,
nan, 0., nan, nan, nan, nan, nan, nan, nan, 0.], dtype=float32), array([[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan],
[ nan, nan, nan]], dtype=float32), array([ nan, nan, nan], dtype=float32)]
I was facing the same problem. While I was trying to implement three stacked GRU layers using Keras, I found that one of the layers had nan values every time, so even the calculated loss was nan. The initializer was 'glorot_uniform'.
I couldn't solve the problem for a while. But recently, when I updated Keras and TensorFlow to their latest versions (for example with `pip install --upgrade keras tensorflow`), the problem was solved, and I was then able to reduce my loss to approximately 1.3 on the MSCOCO dataset.
The issue might run much deeper than compatibility between different versions, but doing this helped me, and I thought it might help you too.