Guiding TensorFlow Keras model training to achieve the best recall at precision 0.95 for binary classification


I am hoping to get some help on the titular topic. I have a database of medical data from patients with two similar pathologies, one severe and one much less so. I need to flag most of the former (≥95%) and leave out as many of the latter as possible.

Therefore, I want to create a binary classifier that reflects this. Searching the web (I'm not an expert), I put together the code below, replacing the metric in the example I found with RecallAtPrecision(0.95) in the compile step. Here is an abridged version:

import tensorflow as tf

model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Dense(10, input_dim=x_train.shape[1], activation='relu', kernel_initializer='he_normal'))
model.add(tf.keras.layers.Dense(1, activation='sigmoid'))

model.compile(optimizer='adam', loss=tf.keras.losses.BinaryCrossentropy(), metrics=[tf.keras.metrics.RecallAtPrecision(0.95)])

history = model.fit(x_train, y_train, validation_split=0.33, batch_size=16, epochs=EPOCHS)

However, it simply doesn't work, as it throws the following error:

AttributeError: module 'tensorflow_core.keras.metrics' has no attribute 'RecallAtPrecision'

I am at a loss as to why this happens, as I can clearly see it in the documentation. The code works if I use Recall(), Precision() or almost any other metric. Looking around some more, I am beginning to think I am missing something fundamental. Do any of you fine ladies and gentlemen have any pointers on how to solve this problem?
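One plausible cause, offered as an assumption rather than a diagnosis: the `tensorflow_core` module named in the traceback is typically seen with older TensorFlow releases (roughly 1.15 to 2.1), and an older release may not ship every metric listed in the current online documentation. A quick check along these lines reports the installed version and whether it exposes a given metric class; `has_keras_metric` is a hypothetical helper written for this sketch, not a TensorFlow API.

```python
import importlib


def has_keras_metric(name: str) -> bool:
    """Hypothetical helper: True if the installed tf.keras.metrics has a class called `name`."""
    try:
        tf = importlib.import_module("tensorflow")
        metrics = tf.keras.metrics
    except (ImportError, AttributeError):
        # TensorFlow (or its Keras submodule) is not available in this environment.
        return False
    return hasattr(metrics, name)


if __name__ == "__main__":
    try:
        print("TensorFlow", importlib.import_module("tensorflow").__version__)
    except ImportError:
        print("TensorFlow is not installed")
    for metric in ("Recall", "Precision", "RecallAtPrecision"):
        print(metric, has_keras_metric(metric))
```

If `Recall` comes back `True` but `RecallAtPrecision` comes back `False`, upgrading TensorFlow is a likely fix.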

Answer

To calculate precision and recall, you don't need Keras. If your predicted and actual labels are vectors of 0s and 1s (threshold the sigmoid outputs first, e.g. at 0.5), you can count true positives (TP), false positives (FP) and false negatives (FN) with tf.math.count_nonzero:

import tensorflow as tf

# predicted and actual: 0/1 tensors of the same numeric dtype
TP = tf.math.count_nonzero(predicted * actual)
FP = tf.math.count_nonzero(predicted * (actual - 1))
FN = tf.math.count_nonzero((predicted - 1) * actual)
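To see why the `- 1` trick works, here is the same arithmetic on a small made-up example, using NumPy's `np.count_nonzero` as a stand-in for the TensorFlow count so it runs without TensorFlow. Subtracting 1 maps a 0 to -1 and a 1 to 0, so `predicted * (actual - 1)` is nonzero exactly where the model predicts 1 and the label is 0 (a false positive), and `(predicted - 1) * actual` is nonzero exactly where the model predicts 0 and the label is 1 (a false negative). The labels, probabilities and the 0.5 cut-off are assumptions for illustration.

```python
import numpy as np

# Made-up labels (1 = severe, 0 = mild) and sigmoid outputs, for illustration only.
actual = np.array([1, 1, 1, 0, 0, 0, 1, 0])
probs = np.array([0.9, 0.7, 0.4, 0.6, 0.2, 0.1, 0.8, 0.3])

# Threshold the probabilities to hard 0/1 predictions (0.5 is an assumed cut-off).
predicted = (probs >= 0.5).astype(int)

TP = np.count_nonzero(predicted * actual)        # predicted 1, actual 1 -> product is 1
FP = np.count_nonzero(predicted * (actual - 1))  # predicted 1, actual 0 -> product is -1
FN = np.count_nonzero((predicted - 1) * actual)  # predicted 0, actual 1 -> product is -1

print(TP, FP, FN)  # prints: 3 1 1
```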

Your metrics are now simple to calculate:

precision = TP / (TP + FP)
recall = TP / (TP + FN)
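These formulas give precision and recall at one fixed threshold. As I understand the Keras documentation, `RecallAtPrecision(0.95)` instead reports the best recall over candidate thresholds whose precision is at least 0.95 (Keras evaluates a fixed grid of thresholds). A minimal NumPy sketch of that idea, sweeping every distinct score as a threshold; `recall_at_precision` is a hypothetical helper and the data are made up. Note also that the requirement in the question (flag at least 95% of the severe cases) constrains recall rather than precision, so maximizing precision subject to recall ≥ 0.95, which Keras offers as `PrecisionAtRecall`, may be the closer match; the same sweep works with the roles swapped.

```python
import numpy as np


def recall_at_precision(actual, scores, min_precision=0.95):
    """Hypothetical helper: best recall over thresholds whose precision >= min_precision.

    Returns (recall, threshold); threshold is None if no threshold reaches the target.
    """
    actual = np.asarray(actual).astype(bool)
    scores = np.asarray(scores, dtype=float)
    positives = int(actual.sum())
    best_recall, best_threshold = 0.0, None
    for t in np.unique(scores):
        predicted = scores >= t  # never empty, because t is one of the scores
        tp = np.count_nonzero(predicted & actual)
        fp = np.count_nonzero(predicted & ~actual)
        precision = tp / (tp + fp)
        recall = tp / positives if positives else 0.0
        # Keep the threshold giving the highest recall so far among those meeting the target.
        if precision >= min_precision and recall > best_recall:
            best_recall, best_threshold = recall, float(t)
    return best_recall, best_threshold


actual = [1, 1, 1, 0, 0, 0, 1, 0]
scores = [0.9, 0.7, 0.4, 0.6, 0.2, 0.1, 0.8, 0.3]
print(recall_at_precision(actual, scores, 0.95))  # prints: (0.75, 0.7)
```

With these made-up scores, no threshold reaches precision 0.95 while catching all four positives, so the best achievable recall is 0.75.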