I am building a classification model using tensorflow-hub, tf.estimator and tf.data.
My train input function returns a dataset, and model_fn is defined as follows:
def train_input_fn():
    return dataset_input_fn(DATASET_TRAIN_PATH)

def model_fn(features, labels, mode, params):
    logging.info("model_fn")
    # module is imported from tf-hub
    return head.create_estimator_spec(features, mode, ...)
This is very similar to the code by Damien.
The environment is Python 2 on Google Cloud Datalab, with TensorFlow 1.12.
The error raised says that model_fn is not expecting a labels parameter (the labels are probably produced by the tf.data dataset). What should the signature of model_fn be, given that input_fn returns a dataset?
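To make the question concrete, here is a minimal plain-Python sketch of how I understand the contract between input_fn and model_fn. It makes no TensorFlow calls. The names call_model_fn, fn_args, toy_input_fn and model_fn_without_labels are hypothetical stand-ins, and the dispatch logic is my assumption about what the estimator does, not its actual code: each dataset element is a (features, labels) pair, and labels is passed on only if model_fn declares an argument with that name.

```python
def fn_args(fn):
    """Return the names of the positional arguments that fn declares."""
    code = fn.__code__
    return code.co_varnames[:code.co_argcount]


def call_model_fn(model_fn, features, labels, mode, params):
    """Hypothetical dispatch: pass only the optional arguments model_fn declares."""
    accepted = fn_args(model_fn)
    kwargs = {}
    if "labels" in accepted:
        kwargs["labels"] = labels
    elif labels is not None:
        # Assumed behaviour: labels supplied by input_fn but not accepted by model_fn.
        raise ValueError("model_fn does not take labels, but input_fn returns labels.")
    if "mode" in accepted:
        kwargs["mode"] = mode
    if "params" in accepted:
        kwargs["params"] = params
    return model_fn(features=features, **kwargs)


def toy_input_fn():
    # Stand-in for a dataset: every element is a (features, labels) pair.
    return [({"text": "good"}, 1), ({"text": "bad"}, 0)]


def model_fn(features, labels, mode, params):
    # Signature with all four arguments, as in the question.
    return {"features": features, "labels": labels, "mode": mode}


def model_fn_without_labels(features, mode):
    # A signature that would be rejected when the input yields labels.
    return {"features": features, "mode": mode}


features, labels = toy_input_fn()[0]
spec = call_model_fn(model_fn, features, labels, "train", {})
```

If this understanding is right, a model_fn that lists features, labels, mode and params should receive the labels from the dataset, while model_fn_without_labels would fail in the same call.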
Any ideas would be appreciated.
Many thanks,
eilalan