I am doing a basic orbital mechanics simulation in TensorFlow. When a 'planet' gets too close to the 'sun' (when (x, y) is close to (0, 0)), TensorFlow raises an exception during the division (which could make sense). Somehow it raises another exception while handling the first one, causing it to fail entirely.
I've tried using `tf.where` to conditionally replace these divide-by-zeros with NaN; however, it then runs into effectively the same error. I've also tried using `tf.div_no_nan` to get zero instead of NaN, but that gets the exact same error.
```python
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

def gravity(state, t):
    print(len(tf.unstack(state)))
    x, y, vx, vy = tf.unstack(state)
    # Error is related to next two lines
    fx = -x/tf.pow(tf.reduce_sum(tf.square([x,y]),axis=0),3/2)
    fy = -y/tf.pow(tf.reduce_sum(tf.square([x,y]),axis=0),3/2)
    dvx = fx
    dvy = fy
    return tf.stack([vx, vy, dvx, dvy])

# Num simulations
size = 100
# Initialize at same position with varying y-velocity
init_state = tf.stack([tf.constant(-1.0,shape=(size,)),tf.zeros((size)),tf.zeros((size)),tf.range(0,10,.1)])
t = np.linspace(0, 10, num=5000)
tensor_state, tensor_info = tf.contrib.integrate.odeint(
    gravity, init_state, t, full_output=True)

init = tf.global_variables_initializer()
with tf.Session() as sess:
    state, info = sess.run([tensor_state, tensor_info])
    state = tf.transpose(state, perm=[1,2,0]).eval()

x, y, vx, vy = state
for i in range(10):
    plt.figure()
    plt.plot(x[i], y[i])
    plt.scatter([0],[0])
```
Instead, I get:
```
...
InvalidArgumentError: assertion failed: [underflow in dt] [9.0294095248318226e-17]
...
During handling of the above exception, another exception occurred:
...
InvalidArgumentError: assertion failed: [underflow in dt] [9.0294095248318226e-17]
...
```
I would like the division to result in NaN or Infinity and then have that propagate normally, as one would expect of numerical integration.
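The `underflow in dt` message does not seem to come from the division itself. My understanding (worth checking against the `tf.contrib.integrate` source) is that the adaptive solver asserts `t0 + dt > t0` on each step: as the planet nears r = 0 the acceleration grows like 1/r², the error controller keeps shrinking `dt`, and eventually adding `dt` to the current time no longer changes it. The plain-Python sketch below checks that arithmetic, assuming the failure happens near the free-fall time of the vy = 0 planet (an assumption; the traceback does not show the time):

```python
import math
import sys

# Radial free-fall time from rest at r0 into a point mass (GM = 1):
# t_ff = (pi / 2) * sqrt(r0**3 / (2 * GM)). With r0 = 1 this is the
# vy = 0 member of the sweep, which falls straight into the sun.
r0, gm = 1.0, 1.0
t_ff = (math.pi / 2) * math.sqrt(r0**3 / (2 * gm))

# Step size quoted in the error message.
dt = 9.0294095248318226e-17

# For values in [1, 2) the gap between neighbouring float64 numbers is
# sys.float_info.epsilon (about 2.2e-16). dt is less than half that gap,
# so t_ff + dt rounds back to t_ff and the step no longer advances time.
spacing = sys.float_info.epsilon
print(f"t_ff = {t_ff:.4f}, spacing = {spacing:.3g}")
print("t + dt == t:", t_ff + dt == t_ff)
```

If that reading is right, returning NaN or Infinity from `gravity` would not help either, because the solver still cannot take a meaningful step past the singularity.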
I don't know if it's appropriate in your case, but you could try adding a small constant to the denominator so that it never reaches zero.
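A sketch of the small-constant idea, usually called gravitational softening: replace r³ with (r² + eps²)^(3/2) so the force stays finite at r = 0. To keep it runnable without TensorFlow 1.x, it uses NumPy and a hand-written fixed-step RK4 in place of `tf.contrib.integrate.odeint`; the softening length `eps = 0.05` and the helper names `softened_gravity` and `rk4` are illustrative assumptions, not part of the original code.

```python
import numpy as np

def softened_gravity(state, eps=0.05):
    """Derivative of [x, y, vx, vy] under a softened 1/r^2 force (GM = 1).

    eps is a hypothetical softening length: the force uses
    (r^2 + eps^2)^(3/2) instead of r^3, so it stays finite at r = 0.
    """
    x, y, vx, vy = state
    r3 = (x**2 + y**2 + eps**2) ** 1.5
    return np.stack([vx, vy, -x / r3, -y / r3])

def rk4(f, state, t):
    """Fixed-step classical Runge-Kutta; returns the state at every time in t."""
    out = np.empty((len(t),) + state.shape)
    out[0] = state
    for i in range(len(t) - 1):
        h = t[i + 1] - t[i]
        k1 = f(state)
        k2 = f(state + 0.5 * h * k1)
        k3 = f(state + 0.5 * h * k2)
        k4 = f(state + h * k3)
        state = state + h / 6.0 * (k1 + 2 * k2 + 2 * k3 + k4)
        out[i + 1] = state
    return out

size = 100
init_state = np.stack([
    np.full(size, -1.0),    # x: all planets start at (-1, 0)
    np.zeros(size),         # y
    np.zeros(size),         # vx
    np.arange(0, 10, 0.1),  # vy: same sweep as tf.range(0, 10, .1)
])
t = np.linspace(0, 10, num=5000)
states = rk4(softened_gravity, init_state, t)  # shape (5000, 4, 100)
x, y, vx, vy = np.transpose(states, (1, 2, 0))  # each of shape (100, 5000)
```

In the original TensorFlow code the equivalent change would be adding `eps**2` to `tf.reduce_sum(tf.square([x,y]),axis=0)` before the `tf.pow`. Softening alters the physics within roughly `eps` of the sun, so it trades accuracy at close approach for an integration that keeps going.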