Exception raised while handling an exception during numerical integration

I am writing a basic orbital mechanics simulation in TensorFlow. When a 'planet' gets too close to the 'sun' (that is, when (x, y) is close to (0, 0)), TensorFlow raises an exception during the division, which could make sense. However, it then raises a second exception while handling the first, causing the run to fail entirely.

I've tried using tf.where to conditionally replace the results of these divisions by zero with NaN, but it runs into effectively the same error. I've also tried tf.div_no_nan to get zero instead of NaN, but that produces exactly the same error.

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

def gravity(state, t):
    print(len(tf.unstack(state)))
    x, y, vx, vy = tf.unstack(state)
    # Error is related to next two lines
    fx = -x/tf.pow(tf.reduce_sum(tf.square([x,y]),axis=0),3/2)
    fy = -y/tf.pow(tf.reduce_sum(tf.square([x,y]),axis=0),3/2)
    dvx = fx
    dvy = fy
    return tf.stack([vx, vy, dvx, dvy])

# Num simulations
size = 100

# Initialize at same position with varying y-velocity
init_state = tf.stack([tf.constant(-1.0,shape=(size,)),tf.zeros((size)),tf.zeros((size)),tf.range(0,10,.1)])

t = np.linspace(0, 10, num=5000)
tensor_state, tensor_info = tf.contrib.integrate.odeint(
    gravity, init_state, t, full_output=True)

init = tf.global_variables_initializer()
with tf.Session() as sess:   
    state, info = sess.run([tensor_state, tensor_info])
    state = tf.transpose(state, perm=[1,2,0]).eval()

x, y, vx, vy = state
for i in range(10):
    plt.figure()
    plt.plot(x[i], y[i])
    plt.scatter([0],[0])

I actually get

...
InvalidArgumentError: assertion failed: [underflow in dt] [9.0294095248318226e-17]
...
During handling of the above exception, another exception occurred:
...
InvalidArgumentError: assertion failed: [underflow in dt] [9.0294095248318226e-17]
...

I would like the division to result in NaN or Infinity and then propagate normally, as one would expect in numerical integration.

There is 1 best solution below

You can try catching the exception:

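with tf.Session() as sess:
    sess.run(init)
    try:
        state, info = sess.run([tensor_state, tensor_info])
    except tf.errors.InvalidArgumentError:
        # The integrator failed (e.g. "underflow in dt"); fall back to placeholders.
        state, info = None, None  # replace with whatever values/shape you need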

I don't know if it's appropriate in your case, but maybe you could just add a small constant to the denominator to avoid dividing by zero.
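As a rough sketch of the small-constant idea, here is the same force law written in plain NumPy with a softening term added to the squared distance. The names softened_gravity and EPS, and the value 1e-3, are my own assumptions and not anything from the question. The same expression should carry over to the TensorFlow version by swapping in the tf ops. Note that it changes the physics very close to the origin, and I have not checked whether it is enough to stop odeint from shrinking dt.

```python
import numpy as np

# Hypothetical softening length (an assumption); tune it for your problem.
EPS = 1e-3

def softened_gravity(state, eps=EPS):
    """Return d(state)/dt for state = [x, y, vx, vy], each of shape (n,)."""
    x, y, vx, vy = state
    # Adding eps**2 keeps the denominator strictly positive, even at (0, 0).
    r3 = (x**2 + y**2 + eps**2) ** 1.5
    return np.stack([vx, vy, -x / r3, -y / r3])

# Two test bodies: one exactly at the origin, one at (-1, 0) as in the question.
state = np.array([
    [0.0, -1.0],  # x
    [0.0,  0.0],  # y
    [0.0,  0.0],  # vx
    [1.0,  1.0],  # vy
])
deriv = softened_gravity(state)
print(np.isfinite(deriv).all())
```

With the softening term, the acceleration stays finite even for the body sitting exactly on the 'sun', while the body at distance 1 feels almost the same force as before.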