How to make sense of TensorFlow.js object detection tensor output?


My motivation is to build a custom object detection web application. I downloaded a TF2 pretrained SSD ResNet101 model from the model zoo. My plan is that if this implementation works, I will train the model on my own data. I ran saved_model_cli show --dir saved_model --tag_set serve --signature_def serving_default to figure out the input and output nodes.

The given SavedModel SignatureDef contains the following input(s):
  inputs['input_tensor'] tensor_info:
      dtype: DT_UINT8
      shape: (1, -1, -1, 3)
      name: serving_default_input_tensor:0
The given SavedModel SignatureDef contains the following output(s):
  outputs['detection_anchor_indices'] tensor_info:
      dtype: DT_FLOAT
      shape: (1, 100)
      name: StatefulPartitionedCall:0
  outputs['detection_boxes'] tensor_info:
      dtype: DT_FLOAT
      shape: (1, 100, 4)
      name: StatefulPartitionedCall:1
  outputs['detection_classes'] tensor_info:
      dtype: DT_FLOAT
      shape: (1, 100)
      name: StatefulPartitionedCall:2
  outputs['detection_multiclass_scores'] tensor_info:
      dtype: DT_FLOAT
      shape: (1, 100, 91)
      name: StatefulPartitionedCall:3
  outputs['detection_scores'] tensor_info:
      dtype: DT_FLOAT
      shape: (1, 100)
      name: StatefulPartitionedCall:4
  outputs['num_detections'] tensor_info:
      dtype: DT_FLOAT
      shape: (1)
      name: StatefulPartitionedCall:5
  outputs['raw_detection_boxes'] tensor_info:
      dtype: DT_FLOAT
      shape: (1, 51150, 4)
      name: StatefulPartitionedCall:6
  outputs['raw_detection_scores'] tensor_info:
      dtype: DT_FLOAT
      shape: (1, 51150, 91)
      name: StatefulPartitionedCall:7
Method name is: tensorflow/serving/predict
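For display purposes, detection_boxes, detection_scores, detection_classes and num_detections are usually the only outputs needed. A minimal sketch of turning their flat arrays into readable objects might look like the following (plain JavaScript; the labelMap entries are hypothetical sample values, and the [ymin, xmin, ymax, xmax] normalized-coordinate box layout is assumed from the TF2 Object Detection API convention):

```javascript
// Sample label entries only, NOT the full COCO label map.
const labelMap = { 1: 'person', 18: 'dog' };

function decodeDetections(boxes, scores, classes, numDetections,
                          imgWidth, imgHeight, scoreThreshold = 0.5) {
    const results = [];
    for (let i = 0; i < numDetections; i++) {
        if (scores[i] < scoreThreshold) continue;
        // TF2 OD API boxes are [ymin, xmin, ymax, xmax], normalized to [0, 1]
        const [ymin, xmin, ymax, xmax] = boxes.slice(i * 4, i * 4 + 4);
        results.push({
            class: labelMap[classes[i]] || `id ${classes[i]}`,
            score: scores[i],
            // [x, y, width, height] in pixels
            bbox: [xmin * imgWidth, ymin * imgHeight,
                   (xmax - xmin) * imgWidth, (ymax - ymin) * imgHeight],
        });
    }
    return results;
}
```

The flat boxes array is what detection_boxes' dataSync() would return, so box i sits at offset i * 4.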

Then I converted the model to a TensorFlow.js model by running:

    tensorflowjs_converter --input_format=tf_saved_model --output_node_names='detection_anchor_indices,detection_boxes,detection_classes,detection_multiclass_scores,detection_scores,num_detections,raw_detection_boxes,raw_detection_scores' --saved_model_tags=serve --output_format=tfjs_graph_model saved_model js_model

Here is my JavaScript code (this goes inside Vue methods):

    loadTfModel: async function() {
        try {
            this.model = await tf.loadGraphModel(this.MODEL_URL);
        } catch (error) {
            console.log(error);
        }
    },
    predictImg: async function() {
        const imgData = document.getElementById('img');
        let tf_img = tf.browser.fromPixels(imgData);
        tf_img = tf_img.expandDims(0);
        const predictions = await this.model.executeAsync(tf_img);
        const data = [];
        for (let i = 0; i < predictions.length; i++) {
            data.push(predictions[i].dataSync());
        }
        console.log(data);
    }

The output is an array of eight typed arrays (screenshot omitted).

My question is: do these eight items in the array correspond to the eight defined output nodes? How do I make sense of this data, and how do I convert it into a human-readable format like the Python one?
Update 1: I have tried this answer and edited my predict method:

    predictImg: async function() {
        const imgData = document.getElementById('img');
        let tf_img = tf.browser.fromPixels(imgData);
        tf_img = tf_img.expandDims(0);
        const predictions = await this.model.executeAsync(tf_img, ['detection_classes']).then(predictions => {
            const data = predictions.dataSync()
            console.log('Predictions: ', data);
        })
    }

I ended up getting "Error: The output 'detection_classes' is not found in the graph". I would appreciate any help.

2 Answers


There is probably a mistake in the output node name specified in this.model.executeAsync(tf_img, ['detection_classes']). Additionally, there is no need to combine await with .then() on the same call: either await is used or .then() is used, not both.

The other option to get the detection_classes is to index the array of outputs. If the tensors come back in the same order as the SignatureDef above, detection_classes is the third output:

    predictions[2].dataSync()
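To avoid guessing indices, one option is to pair each returned tensor with the model's output names. A sketch, assuming model.outputNodes (a tf.GraphModel property) lists the names in the same order that executeAsync returns the tensors:

```javascript
// Build a name -> tensor lookup so code can read named['...']
// instead of a bare predictions[2].
function nameOutputs(outputNames, predictions) {
    const named = {};
    outputNames.forEach((name, i) => { named[name] = predictions[i]; });
    return named;
}

// Hypothetical usage inside predictImg:
// const named = nameOutputs(this.model.outputNodes, predictions);
// const classes = named['Identity_2:0'].dataSync(); // the actual name depends on your model.json
```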

I think you first need to check the web_model/model.json file and look up the names of the outputs there. Those are the names you need to use when choosing which outputs to fetch.
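As a sketch of what that lookup could look like, assuming model.json has a top-level "signature" section as written by recent versions of tensorflowjs_converter (older converters may not include it, and the sample fragment below is hypothetical):

```javascript
// Extract the output tensor names from a parsed model.json object.
function outputNamesFromModelJson(modelJson) {
    const outputs = (modelJson.signature && modelJson.signature.outputs) || {};
    return Object.keys(outputs);
}

// Hypothetical fragment of a converted model.json:
const sample = {
    signature: {
        outputs: {
            'Identity:0': { name: 'Identity:0' },
            'Identity_1:0': { name: 'Identity_1:0' },
        },
    },
};
// outputNamesFromModelJson(sample) -> ['Identity:0', 'Identity_1:0']
```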
