I am building a classification model using TensorFlow Hub, TensorFlow Estimators and tf.data.
My training input function returns a dataset, and the model_fn
is defined as follows:
def train_input_fn():
    return dataset_input_fn(DATASET_TRAIN_PATH)

def model_fn(features, labels, mode, params):
    logging.info("model_fn")
    # module is imported from tf-hub
    return head.create_estimator_spec(features, mode, ...)
This is very similar to Damien's code.
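Whether labels reach model_fn depends on the structure of each element the dataset yields. The sketch below is a simplified pure-Python imitation (not TensorFlow's actual code) of that step, assuming the TF 1.x Estimator behaviour in which an element that is a length-2 tuple is read as (features, labels) and anything else is treated as features only, with labels set to None. The function name `split_features_labels` and the sample elements are hypothetical.

```python
def split_features_labels(element):
    """Hypothetical imitation of how a TF 1.x Estimator reads one input element.

    A length-2 tuple/list is interpreted as (features, labels); anything else
    is treated as features only, and labels is None.
    """
    if isinstance(element, (list, tuple)):
        if len(element) != 2:
            raise ValueError("expected a (features, labels) pair")
        return element[0], element[1]
    return element, None


# Hypothetical elements a classification dataset might yield.
with_labels = ({"sentence": "great movie"}, 1)
features_only = {"sentence": "great movie"}

print(split_features_labels(with_labels))    # ({'sentence': 'great movie'}, 1)
print(split_features_labels(features_only))  # ({'sentence': 'great movie'}, None)
```

If the training dataset is built from (features, label) pairs, the first case applies, so labels will be non-None when model_fn is called.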
The environment is Python 2 on Google Cloud Datalab, and tf.__version__
is 1.12.
The error being raised says that model_fn
does not expect a labels parameter (which is probably produced by the tf.data
dataset). What should the signature of model_fn
be, given that input_fn
returns a dataset?
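On the signature question: a TF 1.x Estimator inspects the argument names of model_fn and passes only the ones it declares (features, labels, mode, params, config), failing if the input pipeline produced labels but model_fn has no `labels` argument. The sketch below imitates that dispatch in plain Python 3 (it uses `inspect.signature`, which Python 2 lacks); `call_model_fn` and the two toy model functions are hypothetical stand-ins, and the error text is paraphrased rather than copied from TensorFlow.

```python
import inspect


def call_model_fn(model_fn, features, labels, mode, params=None, config=None):
    """Hypothetical, simplified imitation of how a TF 1.x Estimator calls model_fn.

    Only arguments that model_fn declares by name are passed. If the input
    pipeline produced labels but model_fn has no `labels` parameter, fail early.
    """
    declared = inspect.signature(model_fn).parameters
    kwargs = {}
    if "labels" in declared:
        kwargs["labels"] = labels
    elif labels is not None:
        raise ValueError("model_fn has no 'labels' argument, but input_fn produced labels")
    # The remaining optional arguments are passed only when declared.
    for name, value in (("mode", mode), ("params", params), ("config", config)):
        if name in declared:
            kwargs[name] = value
    return model_fn(features=features, **kwargs)


# Hypothetical toy model functions standing in for a real model_fn.
def model_fn_with_labels(features, labels, mode, params):
    return ("spec", features, labels, mode)


def model_fn_without_labels(features, mode, params):
    return ("spec", features, mode)


print(call_model_fn(model_fn_with_labels, {"x": 1}, 0, "train"))
# ('spec', {'x': 1}, 0, 'train')

try:
    call_model_fn(model_fn_without_labels, {"x": 1}, 0, "train")
except ValueError as err:
    print(err)
```

Under this reading, a model_fn that declares `labels` (and presumably forwards it to the head's create_estimator_spec) is what allows a dataset of (features, labels) pairs to flow through without the error.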
Any ideas would be appreciated.
Many thanks,
eilalan