How to define dynamic-shape variable when building computational graph with Tensorflow 1.15


System information

  1. Have I written custom code (as opposed to using a stock example script provided in TensorFlow): No

  2. OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux Ubuntu 18.04

  3. Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:

  4. TensorFlow installed from (source or binary): Conda repo

  5. TensorFlow version (use command below): 1.15

  6. Python version: 3.7.7

  7. Bazel version (if compiling from source):

  8. GCC/Compiler version (if compiling from source):

  9. CUDA/cuDNN version: 10.1

  10. GPU model and memory: Tesla V100-SXM3-32GB

Describe the current behavior

    tensorflow.python.framework.errors_impl.InvalidArgumentError: Assign requires shapes of both tensors to match. lhs shape= [] rhs shape= [1,1] [[{{node Variable/Assign}}]]

Describe the expected behavior

No error

Standalone code to reproduce the issue

import tensorflow as tf
import numpy as np
import os
os.environ["CUDA_VISIBLE_DEVICES"]='0'

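with tf.Session() as sess:
    # In TF 1.15 graph mode, tf.Variable creates a legacy RefVariable by default;
    # running the initializer below raises the InvalidArgumentError shown above.
    v = tf.Variable(np.zeros(shape=[1,1]),shape=tf.TensorShape(None))
    sess.run(tf.global_variables_initializer())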

Observation: The error does not appear when eager execution is enabled with tf.enable_eager_execution().

Code:

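import numpy as np
import tensorflow as tf

tf.enable_eager_execution()
# In eager mode, tf.Variable always creates a resource variable.
v = tf.Variable(np.zeros([1,1]),shape=tf.TensorShape(None))
tf.print(v)
v.assign(np.ones([2,2]))  # assigning a differently shaped value succeeds
tf.print(v)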

Output:

[[0]]
[[1 1]
 [1 1]]
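The eager run works because eager execution always creates resource variables, whereas TF 1.x graph mode defaults to the older reference variable class. As a rough sketch (assuming TF 1.15, or TF 2.x through `tf.compat.v1`), the two classes can be compared by building one variable of each in a graph with the per-variable `use_resource` argument; on 1.15 this should print `RefVariable ResourceVariable`.

```python
import numpy as np
import tensorflow as tf

# Build in an explicit graph so this runs in graph mode on TF 1.15 and TF 2.x.
graph = tf.Graph()
with graph.as_default():
    # use_resource=False: the legacy reference variable (TF 1.x graph-mode default).
    ref_v = tf.compat.v1.Variable(np.zeros([1, 1]), use_resource=False)
    # use_resource=True: a resource variable, the class eager mode always uses.
    res_v = tf.compat.v1.Variable(np.zeros([1, 1]), use_resource=True)

ref_name = type(ref_v).__name__
res_name = type(res_v).__name__
print(ref_name, res_name)
```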

Link to a MWE: https://colab.research.google.com/gist/amahendrakar/3fe8345db4092d520246205be4b97948/41620.ipynb

Answer by lengoanhcat:

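You just have to enable resource variables, because the dynamic-shape behavior is only available for this newer Variable class (ResourceVariable). In TF 1.15 graph mode, tf.Variable creates a legacy RefVariable by default, while eager mode always uses resource variables, which is why the eager version works.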

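import tensorflow as tf
import numpy as np
import os
os.environ["CUDA_VISIBLE_DEVICES"]='0'
# Make tf.Variable create ResourceVariable objects; call this before creating any variable.
tf.compat.v1.enable_resource_variables()

with tf.Session() as sess:
    # With resource variables, an unspecified shape is accepted at initialization.
    v = tf.Variable(np.zeros(shape=[1,1]),shape=tf.TensorShape(None))
    sess.run(tf.global_variables_initializer())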
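If you would rather not change the global default, a minimal sketch (assuming TF 1.15, or TF 2.x through `tf.compat.v1`) is to request a resource variable for just this one variable with `use_resource=True`, then assign a value of a different shape in graph mode, mirroring the eager experiment in the question. The placeholder `new_value` and the ops `assign_op` and `read_op` are names introduced only for this illustration.

```python
import numpy as np
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    # use_resource=True requests a ResourceVariable for this variable only,
    # leaving the global default (RefVariable in TF 1.x graph mode) untouched.
    v = tf.compat.v1.Variable(
        np.zeros(shape=[1, 1]),
        shape=tf.TensorShape(None),
        use_resource=True,
    )
    # Placeholder with a fully unknown shape, so values of any shape can be fed.
    new_value = tf.compat.v1.placeholder(tf.float64, shape=None)
    # read_value=False makes assign return the assign op itself.
    assign_op = v.assign(new_value, read_value=False)
    read_op = v.read_value()
    init_op = tf.compat.v1.global_variables_initializer()

with tf.compat.v1.Session(graph=graph) as sess:
    sess.run(init_op)
    before = sess.run(read_op)
    # Assign a [2, 2] value to a variable that was initialized as [1, 1].
    sess.run(assign_op, feed_dict={new_value: np.ones([2, 2])})
    after = sess.run(read_op)

print(before.shape, after.shape)
```

With this option only the marked variable changes class, whereas `tf.compat.v1.enable_resource_variables()` changes the default for every variable created afterwards.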