I want to implement a tf.Module for decoding box predictions and applying NonMaxSuppression that is convertible to tflite.
This implementation includes elements from here.
It also follows this guide on operation fusion.
This is my code:
def decode_predictions_lite(anchor_boxes: tf.Tensor, box_pred: tf.Tensor, cls_pred: tf.Tensor, variance: tf.Tensor, image_shape):
    """a concrete function that decodes box and class predictions and applies NonMaxSuppression
    Args:
        anchor_boxes: tf.Tensor of shape [1, N, 4] representing anchor_boxes in the 'center_yxhw' format
        box_pred: tf.Tensor of shape [b, N, 4] representing encoded box predictions from the model's output
        cls_pred: tf.Tensor of shape [b, N, num_classes] representing the class logits from the model's output
        variance: tf.Tensor of shape [4] representing the box variance that was used when encoding the boxes
        image_shape: tf.Tensor of shape [2] representing the height and width of the input image
    Returns:
        A tuple of the decoded boxes in absolute coordinates and the
        (boxes, classes, scores, num_detections) outputs of the dummy NMS function.
    Note: 
        - N is the number of anchor boxes. All input tensors are of dtype float32
        - b is the batch_size. For now, only a batch_size of 1 is supported.
    """
    scores = tf.sigmoid(cls_pred)
    boxes = box_pred * variance
    # Decode 'center_yxhw' predictions: the yx offsets are scaled by the anchor
    # sizes, the hw components are decoded with an exponential.
    decoded_boxes = tf.concat(
        [
            boxes[..., :2] * anchor_boxes[..., 2:] + anchor_boxes[..., :2],
            tf.math.exp(boxes[..., 2:]) * anchor_boxes[..., 2:]
        ], axis=-1
    )
    
    # Normalize anchor coordinates for TFLite's NMS operation.
    normalize_factor = tf.tile(image_shape, [2])
    anchor_boxes = anchor_boxes / normalize_factor
    anchor_boxes = tf.squeeze(anchor_boxes) # squeeze so the anchor_boxes are of shape (N, 4)
    
    # normalize box coordinates for TFLite's NMS operation
    decoded_boxes_rel = decoded_boxes / normalize_factor
    def get_implements_signature():
        implements_signature = ' '.join([
        'name: "%s"' % 'TFLite_Detection_PostProcess',
        'attr { key: "max_detections" value { i: %d } }' % 100,
        'attr { key: "max_classes_per_detection" value { i: %d } }' % 1,
        'attr { key: "detections_per_class" value { i: %d } }' % 5,
        'attr { key: "use_regular_nms" value { b: %s } }' % "false", # Lower
        'attr { key: "nms_score_threshold" value { f: %f } }' % 0.1,
        'attr { key: "nms_iou_threshold" value { f: %f } }' % 0.5,
        'attr { key: "y_scale" value { f: %f } }' % 1.0,
        'attr { key: "x_scale" value { f: %f } }' % 1.0,
        'attr { key: "h_scale" value { f: %f } }' % 1.0,
        'attr { key: "w_scale" value { f: %f } }' % 1.0,
        'attr { key: "num_classes" value { i: %d } }' % num_classes,
        ])
        return implements_signature
    
    @tf.function(experimental_implements=get_implements_signature())
    def dummy_postprocessing_nms(input_boxes, input_scores, input_anchors):
        # Placeholder constants; the converter is expected to replace this whole
        # function with the fused TFLite_Detection_PostProcess custom op.
        boxes = tf.constant(0.0, dtype=tf.float32, name='boxes')
        scores = tf.constant(0.0, dtype=tf.float32, name='scores')
        classes = tf.constant(0.0, shape=(1, 100), dtype=tf.float32, name='classes')
        num_detections = tf.constant(0.0, dtype=tf.float32, name='num_detections')
        return boxes, classes, scores, num_detections
    
    return decoded_boxes, dummy_postprocessing_nms(decoded_boxes_rel, scores, anchor_boxes)
class PredictionDecoderLite(tf.Module):
    def __init__(self):
        super().__init__()
    
    @tf.function(input_signature=[
        tf.TensorSpec(shape=[1, None, 4], dtype=tf.float32),  # anchor_boxes
        tf.TensorSpec(shape=[1, None, 4], dtype=tf.float32),  # box_pred
        tf.TensorSpec(shape=[1, None, None], dtype=tf.float32),  # cls_pred
        tf.TensorSpec(shape=[4], dtype=tf.float32),  # variance
        tf.TensorSpec(shape=[2], dtype=tf.float32)   # image_shape
    ])
    def decode_preds_lite(self, anchor_boxes, box_pred, cls_pred, variance, image_shape):
        return decode_predictions_lite(anchor_boxes, box_pred, cls_pred, variance, image_shape)
    
decoder_module = PredictionDecoderLite()
concrete_fn_with_nms = decoder_module.decode_preds_lite.get_concrete_function()
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_fn_with_nms], trackable_obj=decoder_module)
tflite_pred_decoder = converter.convert()
with open(tflite_postprocess_path, 'wb') as f:
    f.write(tflite_pred_decoder)
The issue is that the converter does not rewrite the dummy function to TFLite's custom NMS operation. I tested the TFLite module with the following code:
image_shape = [HEIGHT, WIDTH, 3]
variance = tf.constant([1.0, 1.0, 1.0, 1.0])
anchor_generator = kcv.models.RetinaNet.default_anchor_generator(bounding_box_format)
anchors = anchor_generator(image_shape=image_shape)
anchors = ops.concatenate([a for a in anchors.values()], axis=0)
anchors = tf.expand_dims(anchors, axis=0)
print("anchors: ", anchors.shape)   # anchors:  (1, 76725, 4)
print("boxes: ", predictions['box'].shape)  # boxes:  (1, 76725, 4)
print("scores: ", predictions['classification'].shape)  # scores:  (1, 76725, 4)
interpreter = tf.lite.Interpreter(tflite_postprocess_path)
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.resize_tensor_input(input_details[0]['index'], anchors.shape)
interpreter.resize_tensor_input(input_details[1]['index'], predictions['box'].shape)
interpreter.resize_tensor_input(input_details[2]['index'], predictions['classification'].shape)
interpreter.allocate_tensors()
interpreter.set_tensor(input_details[0]['index'], anchors.numpy())
interpreter.set_tensor(input_details[1]['index'], predictions['box'])
interpreter.set_tensor(input_details[2]['index'], predictions['classification'])
interpreter.set_tensor(input_details[3]['index'], variance.numpy())
interpreter.set_tensor(input_details[4]['index'], np.array(image_shape[:2], dtype='float32'))
interpreter.invoke()
# Retrieve outputs:
decoded_boxes = interpreter.get_tensor(output_details[0]['index'])
final_boxes = interpreter.get_tensor(output_details[1]['index'])
final_scores = interpreter.get_tensor(output_details[2]['index'])
final_classes = interpreter.get_tensor(output_details[3]['index'])
num_detections = interpreter.get_tensor(output_details[4]['index'])
I received zeros from my dummy function and no post-processed boxes. I don't know how to get more information out of the converter to see what is wrong. Does anyone have an idea how to proceed?
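For reference, one way to at least list the ops that ended up in the flatbuffer is tf.lite.experimental.Analyzer (available from TF 2.9 on, so this assumes a reasonably recent TF version):

import tensorflow as tf

# Print every operator in the converted flatbuffer; if the fusion worked, a
# custom op named TFLite_Detection_PostProcess should show up instead of the
# ops from the dummy function's body.
tf.lite.experimental.Analyzer.analyze(model_content=tflite_pred_decoder)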
So... After almost two days of debugging, I finally found a workaround. Instead of using tf.lite.TFLiteConverter.from_concrete_functions(), we first save the module and then use tf.lite.TFLiteConverter.from_saved_model(). In the TensorFlow Model Garden, they do it the same way (see here). There were also some other issues in the code above, so here is an updated version. You can directly use the box and class predictions from a keras_cv.models.RetinaNet that was converted to tflite. We can test the module the same way as before. Note: the order of the input and output tensors is a bit strange.
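A minimal sketch of the conversion part of this workaround, assuming the same PredictionDecoderLite module as above; the export directory name is arbitrary, and converter.allow_custom_ops = True is my addition (TFLite_Detection_PostProcess is registered as a custom op):

import tensorflow as tf

decoder_module = PredictionDecoderLite()
saved_model_dir = 'pred_decoder'  # arbitrary export directory
tf.saved_model.save(
    decoder_module,
    saved_model_dir,
    signatures=decoder_module.decode_preds_lite.get_concrete_function(),
)
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.allow_custom_ops = True  # keep the custom TFLite_Detection_PostProcess op
tflite_pred_decoder = converter.convert()
with open(tflite_postprocess_path, 'wb') as f:
    f.write(tflite_pred_decoder)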
It is worth noting that TFLite's custom NMS operation does the box decoding for you (see here). So far I have not found any documentation about this op; it would have made things much easier.
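For what it's worth, my reading of the kernel source (tensorflow/lite/kernels/detection_postprocess.cc) is that the op decodes each raw prediction against its anchor roughly as below, which means raw box encodings plus normalized anchors should be fed in. Take this as a hedged sketch rather than official documentation:

import tensorflow as tf

def decode_like_tflite_nms(box_pred, anchors, scales=(1.0, 1.0, 1.0, 1.0)):
    # Mirrors the decoding that TFLite_Detection_PostProcess seems to apply
    # internally; box_pred and anchors are in [y_center, x_center, h, w] order.
    ty, tx, th, tw = tf.unstack(box_pred, axis=-1)
    ya, xa, ha, wa = tf.unstack(anchors, axis=-1)
    y_scale, x_scale, h_scale, w_scale = scales
    ycenter = ty / y_scale * ha + ya
    xcenter = tx / x_scale * wa + xa
    h = tf.exp(th / h_scale) * ha
    w = tf.exp(tw / w_scale) * wa
    return tf.stack([ycenter, xcenter, h, w], axis=-1)

With the scale attributes all set to 1.0 as in the signature above, this reduces to the usual anchor-box decoding, which is why passing already-decoded boxes to the op would decode them a second time.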