tfcompile of tf.cond of constants errors

197 Views Asked by At

Using the following sample code to create a graph with cond:

from __future__ import absolute_import

import tensorflow as tf

from tensorflow.compiler.tf2xla.tf2xla_pb2 import Config, Feed, Fetch, TensorId
from tensorflow.core.framework.tensor_shape_pb2 import TensorShapeProto


def tf2xla_config_feed( feed ):
  """Build the tf2xla ``Feed`` proto for the tensor ``feed``.

  The feed id is the part of ``feed.name`` before the first ':'. The
  shape proto gets one ``Dim`` per entry of ``feed.shape``, sized by that
  entry's ``.value``.
  """
  node_name = feed.name.partition(':')[0]
  feed_id = TensorId(node_name=node_name)
  shape_proto = TensorShapeProto(
      dim=[TensorShapeProto.Dim(size=extent.value) for extent in feed.shape])
  return Feed(id=feed_id, shape=shape_proto)


def tf2xla_config_fetch( fetch ):
  """Build the tf2xla ``Fetch`` proto for the tensor ``fetch``.

  The fetch id is the part of ``fetch.name`` before the first ':'.
  """
  node_name = fetch.name.partition(':')[0]
  return Fetch(id=TensorId(node_name=node_name))


def tf2xla_config( feeds, fetches ):
  """Build the tf2xla ``Config`` proto for the given feeds and fetches.

  Args:
    feeds: iterable of tensors, each converted with ``tf2xla_config_feed``.
    fetches: iterable of tensors, each converted with
      ``tf2xla_config_fetch``.

  Returns:
    A ``Config`` whose ``feed`` and ``fetch`` fields hold the converted
    entries, in input order.
  """
  # Build concrete lists rather than using map(): on Python 3 map() is lazy,
  # which would defer the per-tensor conversion (and any error it raises)
  # into the Config constructor. Lists behave the same on Python 2 and 3.
  pb_feeds = [tf2xla_config_feed(tensor) for tensor in feeds]
  pb_fetches = [tf2xla_config_fetch(tensor) for tensor in fetches]
  return Config(feed=pb_feeds, fetch=pb_fetches)


# Graph input: a float64 placeholder of shape (2,) named 'a'.
a = tf.placeholder( tf.float64, shape = ( 2, ), name = 'a' )

# The two elements of the input; they are compared in the cond predicate.
a1 = a[ 0 ]
a2 = a[ 1 ]

# Constant tensors that the two cond branches return unchanged.
one = tf.constant( 1 )
two = tf.constant( 2 )

# Yields `one` when a1 < a2 and `two` otherwise.
res = tf.cond( a1 < a2, lambda: one, lambda: two )

# Serialize the graph that contains `res`; this file is passed to tfcompile
# via --graph.
with open( 'test_graph.pb', 'wb' ) as f:
  f.write( res.graph.as_graph_def().SerializeToString() )

# Serialize the tf2xla config (feed: a, fetch: res); this file is passed to
# tfcompile via --config.
with open( 'test_config.pb', 'wb' ) as f:
  f.write( tf2xla_config( [ a ], [ res ] ).SerializeToString() )

And compiling with:

tfcompile --graph=test_graph.pb --config=test_config.pb --entry_point=test_func --cpp_class=test --out_object=test_func.o --out_header=test.hpp

Results in the following error:

2017-11-29 20:40:26.725164: F tensorflow/compiler/aot/tfcompile_main.cc:140] Non-OK-status: status status: Unimplemented: Conversion from TensorFlow graph to XLA resulted in 1 constant results. The configuration of the output args (i.e. fetch ids) is probably wrong.

Is this error unwarranted, or am I doing something wrong?

0

There are 0 best solutions below