I'm trying to write a hook that will allow me to compute some global metrics (rather than batch-wise metrics). To prototype, I thought I'd get a simple hook up and running that would capture and remember true positives. It looks like this:
import tensorflow as tf


class TPHook(tf.train.SessionRunHook):
    def after_create_session(self, session, coord):
        print("Starting Hook")
        tp_name = 'metrics/f1_macro/TP'
        self.tp = []
        self.args = session.graph.get_operation_by_name(tp_name)
        print(f"Got Args: {self.args}")

    def before_run(self, run_context):
        print("Starting Before Run")
        return tf.train.SessionRunArgs(self.args)

    def after_run(self, run_context, run_values):
        print("After Run")
        print(f"Got Values: {run_values.results}")
However, run_values.results in after_run is always None. I tested this in both the training and evaluation phases. Am I misunderstanding something about how SessionRunHooks are supposed to work?
Possibly relevant information:
The model was built in Keras and converted to an Estimator with the keras.estimator.model_to_estimator() function. The model has been tested and works fine, and the op that I'm trying to retrieve in the hook is defined in this code block:
import tensorflow as tf
from tensorflow.keras import backend as K


def _f1_macro_vector(y_true, y_pred):
    """Computes the per-class F1 scores that make up the macro-averaged F1.

    Arguments:
        y_true {tf.Tensor} -- Ground-truth labels
        y_pred {tf.Tensor} -- Predicted labels

    Returns:
        tf.Tensor -- The per-class F1 scores
    """
    y_true = K.cast(y_true, tf.float64)
    y_pred = K.cast(y_pred, tf.float64)
    # Per-class counts; the ops are named so they can be looked up in the graph
    TP = tf.reduce_sum(y_true * K.round(y_pred), axis=0, name='TP')
    FN = tf.reduce_sum(y_true * (1 - K.round(y_pred)), axis=0, name='FN')
    FP = tf.reduce_sum((1 - y_true) * K.round(y_pred), axis=0, name='FP')
    prec = TP / (TP + FP)
    rec = TP / (TP + FN)
    # Convert NaNs (from 0/0) to zero
    prec = tf.where(tf.is_nan(prec), tf.zeros_like(prec), prec)
    rec = tf.where(tf.is_nan(rec), tf.zeros_like(rec), rec)
    f1 = 2 * (prec * rec) / (prec + rec)
    # Convert NaNs (from 0/0) to zero
    f1 = tf.where(tf.is_nan(f1), tf.zeros_like(f1), f1)
    return f1
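Since macro F1 is not additive over batches, averaging per-batch values is generally not the same as the global metric the hook is meant to produce. One way to get the global value is to sum the per-class TP, FP and FN counts over all batches and compute F1 once from the totals. The sketch below assumes numpy and reuses the NaN-to-zero handling of _f1_macro_vector on accumulated counts; the helper name macro_f1_from_counts and the toy counts are hypothetical.

```python
import numpy as np


def macro_f1_from_counts(tp, fp, fn):
    """Per-class F1 and its macro average from (accumulated) per-class counts.

    Hypothetical helper following _f1_macro_vector: NaNs from 0/0 become zero.
    """
    tp, fp, fn = (np.asarray(x, dtype=np.float64) for x in (tp, fp, fn))
    with np.errstate(divide="ignore", invalid="ignore"):
        prec = tp / (tp + fp)
        rec = tp / (tp + fn)
        prec = np.where(np.isnan(prec), 0.0, prec)
        rec = np.where(np.isnan(rec), 0.0, rec)
        f1 = 2 * (prec * rec) / (prec + rec)
    f1 = np.where(np.isnan(f1), 0.0, f1)
    return f1, float(f1.mean())


# Hypothetical per-class counts (two classes) from two batches.
batch_counts = [
    {"tp": [1, 0], "fp": [0, 1], "fn": [0, 0]},
    {"tp": [1, 1], "fp": [1, 0], "fn": [1, 0]},
]

# Global: sum the counts over all batches first, then compute F1 once.
totals = {k: np.sum([b[k] for b in batch_counts], axis=0) for k in ("tp", "fp", "fn")}
_, global_macro_f1 = macro_f1_from_counts(totals["tp"], totals["fp"], totals["fn"])

# Batch-wise: macro F1 per batch, then the mean over batches.
batchwise_macro_f1 = float(np.mean(
    [macro_f1_from_counts(b["tp"], b["fp"], b["fn"])[1] for b in batch_counts]
))

print(global_macro_f1, batchwise_macro_f1)  # ~0.6667 vs 0.625
```

With these counts the two numbers differ (about 0.667 globally versus 0.625 batch-wise), which is why the counts, not the per-batch F1, are what a hook would need to accumulate.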
In case anyone runs into the same problem, I found out how to restructure the program so that it works. The documentation does allow raw ops to be passed into SessionRunArgs, but fetching an op only runs it and returns None; to get a value back, you need to fetch an actual tensor (the op's output). This is pretty easy to accomplish: I just changed the after_create_session code to what's shown below, and it now runs successfully.
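The updated after_create_session code referred to above is not reproduced here. As a hedged sketch of that kind of change (not necessarily the author's exact code), the hook below looks up the op's output tensor instead of the op itself, assuming it is the first output of 'metrics/f1_macro/TP' (hence the ':0' suffix). It also sums the per-class values across runs so they could feed a global metric. It uses the tf.compat.v1 aliases, assuming TF 1.13+ or TF 2.x; run_demo builds a hypothetical toy graph in place of the real model.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style hook API (available in TF 1.13+ and TF 2.x)


class TPHook(tf1.train.SessionRunHook):
    """Sums the per-class TP tensor over every session.run call."""

    def __init__(self, tp_tensor_name="metrics/f1_macro/TP:0"):
        # Assumption: ':0' selects the first (only) output of the 'TP' op.
        self.tp_tensor_name = tp_tensor_name
        self.tp = None
        self.tp_total = None

    def begin(self):
        # Called once per session (e.g. per train()/evaluate() call): reset.
        self.tp_total = None

    def after_create_session(self, session, coord):
        # get_tensor_by_name returns a tf.Tensor, whose fetched value is a
        # numpy array. get_operation_by_name returns a tf.Operation, and
        # fetching an Operation only runs it, so the result is None.
        self.tp = session.graph.get_tensor_by_name(self.tp_tensor_name)

    def before_run(self, run_context):
        return tf1.train.SessionRunArgs(self.tp)

    def after_run(self, run_context, run_values):
        batch_tp = run_values.results  # one count per class for this batch
        if self.tp_total is None:
            self.tp_total = np.array(batch_tp, dtype=np.float64)
        else:
            self.tp_total += batch_tp


def run_demo():
    """Hypothetical toy graph standing in for the model: 2 classes, 2 batches."""
    with tf.Graph().as_default():
        y_true = tf1.placeholder(tf.float64, shape=[None, 2])
        y_pred = tf1.placeholder(tf.float64, shape=[None, 2])
        with tf1.name_scope("metrics"):
            with tf1.name_scope("f1_macro"):
                tp = tf.reduce_sum(y_true * tf.round(y_pred), axis=0, name="TP")
        hook = TPHook(tp.name)
        batches = [
            ([[1, 0], [1, 1]], [[0.9, 0.2], [0.8, 0.7]]),  # TP = [2, 1]
            ([[0, 1]], [[0.3, 0.6]]),                      # TP = [0, 1]
        ]
        with tf1.train.MonitoredSession(hooks=[hook]) as sess:
            for yt, yp in batches:
                sess.run(tp, feed_dict={y_true: yt, y_pred: yp})
    return hook.tp_total


print(run_demo())  # accumulated TP per class
```

With an Estimator, the hook would be passed through the hooks argument, e.g. estimator.evaluate(input_fn, hooks=[TPHook()]). Accumulating the FP and FN tensors the same way would give the totals needed for a global F1.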