How to get PR Curve in Object Detection API (Tensorflow)?


I have a single-class object detection model with a mAP of 0.87. I trained it using the model_main_tf2.py script provided with the TensorFlow Object Detection API.

Now I'm having a hard time setting the threshold when doing inference with this model. So after a little research, I understood that I need the Precision-Recall Curve so I can select the appropriate threshold depending on my application (precision-recall tradeoff). So far so good.
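To make concrete what I mean by picking a threshold from the PR curve, here is a minimal sketch. It assumes I already have, for every detection, its confidence score and a flag saying whether it matched a ground-truth box (for example at IoU 0.5), plus the total number of ground-truth boxes. The names `pr_curve` and `pick_threshold` are my own and are not part of the Object Detection API.

```python
import numpy as np


def pr_curve(scores, is_true_positive, num_ground_truth):
    """Precision and recall after each detection, highest score first."""
    scores = np.asarray(scores, dtype=float)
    is_tp = np.asarray(is_true_positive, dtype=bool)
    # Sort detections by descending confidence.
    order = np.argsort(-scores, kind="stable")
    scores, is_tp = scores[order], is_tp[order]
    tp = np.cumsum(is_tp)    # true positives kept at each cut-off
    fp = np.cumsum(~is_tp)   # false positives kept at each cut-off
    precision = tp / (tp + fp)
    recall = tp / num_ground_truth
    return scores, precision, recall


def pick_threshold(scores, precision, recall, min_precision):
    """Lowest score cut-off whose precision is still >= min_precision."""
    ok = np.flatnonzero(precision >= min_precision)
    if ok.size == 0:
        return None
    i = ok[-1]  # last qualifying point, which also has the highest recall
    return float(scores[i]), float(precision[i]), float(recall[i])
```

Calling `pick_threshold(*pr_curve(scores, flags, n_gt), min_precision=0.9)` would then give the score to use at inference time, if the application needs at least 90% precision.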

But how can I get the PR curve? After running validation, the only output I get in the console is the mAP and AR values (from the COCO detection metrics) and nothing else. Even TensorBoard shows no PR curve (the option is there, but it is empty).

I ran validation using the script provided in model_main_tf2.py, that is:

model_lib_v2.eval_continuously(
    pipeline_config_path=FLAGS.pipeline_config_path,
    model_dir=FLAGS.model_dir,
    train_steps=FLAGS.num_train_steps,
    sample_1_of_n_eval_examples=FLAGS.sample_1_of_n_eval_examples,
    sample_1_of_n_eval_on_train_examples=(
        FLAGS.sample_1_of_n_eval_on_train_examples
    ),
    checkpoint_dir=FLAGS.checkpoint_dir,
    wait_interval=300,
)

I've already found this PR Curve plugin, but the problem is that it's only for image classification, not object detection, which is what I need.

There is also this thread on GitHub, but the code proposed there seems to be for TensorFlow 1.

Hope anyone can point me in the right direction!

P.S1: To produce the mAP values, eval_continuously(...) must already have computed the PR curve, so the calculation has to be somewhere in the code, but I can't find it to print it out :S

P.S2: Maybe I should switch to YOLO, since the Object Detection API is deprecated now? :S
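On the P.S1 point: if the COCO metrics are computed through pycocotools' `COCOeval` (I believe they are, but I have not located where the Object Detection API keeps that object), the curve should already sit in `COCOeval.eval["precision"]` after `accumulate()`. That array has the layout `[IoU threshold, recall threshold, class, area range, max detections]`. Below is a rough sketch of slicing one curve out of it. The helper name `pr_from_coco_precision` is hypothetical, and the default indices assume the standard pycocotools parameters (IoU 0.50 first, area range "all" first, largest max-detections setting last).

```python
import numpy as np


def pr_from_coco_precision(precision, rec_thrs, iou_index=0, class_index=0,
                           area_index=0, max_dets_index=-1):
    """Slice one PR curve out of a COCOeval-style precision array.

    precision: array shaped [T, R, K, A, M] as described above.
    rec_thrs:  the R recall thresholds (COCOeval.params.recThrs).
    """
    precision = np.asarray(precision)
    p = precision[iou_index, :, class_index, area_index, max_dets_index]
    r = np.asarray(rec_thrs)
    # pycocotools fills entries with -1 when a class has no ground truth.
    valid = p > -1
    return r[valid], p[valid]
```

With a `COCOeval` instance called `coco_eval`, the call would be `pr_from_coco_precision(coco_eval.eval["precision"], coco_eval.params.recThrs)`. The `coco_eval.eval["scores"]` array has the same layout and holds the detection score at each recall level, which is what I would need to turn a point on the curve into an inference threshold.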
