Why does my tflite model predict well with Python interpreter, but very poorly when deployed in Android Studio?

I built a model with MobileNet for a specific training data set. On a test set, the Keras model (model.h5) reaches an accuracy of approximately 92%. I then converted the model to TFLite with the following code:

import tensorflow as tf

model = tf.keras.models.load_model('modelos TensorflowLite/MobileNet.h5')
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
with open("MobileNet.tflite", "wb") as f:
    f.write(tflite_model)

Running the TFLite model on the same test set with the TFLite interpreter in Python gives an accuracy very close to that of the Keras model, about 92%. Code used for a single inference with the interpreter:

from pathlib import Path

import cv2
import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="MobileNet.tflite")
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.allocate_tensors()

# Read and preprocess the image (cv2.imread returns BGR channel order)
file = Path("image.jpg")
img = cv2.imread(str(file))
new_img = cv2.resize(img, (300, 300))
new_img = new_img.astype(np.float32)
new_img /= 255.

# input_details[0]['index'] = index of the tensor that accepts the input
interpreter.set_tensor(input_details[0]['index'], np.expand_dims(new_img, axis=0))

# Run the interpreter's prediction
interpreter.invoke()

# output_details[0]['index'] = index of the tensor that provides the output
output_data = interpreter.get_tensor(output_details[0]['index'])

print("For file {}, the output is {}".format(file.stem, output_data))

The problem appears when I run the same test set in the Android app built in Android Studio. There, the accuracy is 39% using the same model converted to TFLite. Note that the model is not quantized. I also compared the results for a single image across the three setups, looking at the probability assigned to the correct class (the model has 3 classes). This image was classified correctly by the Keras and TFLite Python models, but not on Android:

probability           Keras model (.h5)   TFLite (Python interpreter)   TFLite (Android)
prob. correct class   9.6e-01             9.6e-01                       3.2e-6

My problem is not that accuracy drops when converting the .h5 model to .tflite. My problem is that the TFLite model works correctly in the Python interpreter but very badly when run in the Android app.
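
One way to narrow down a gap like this is to compare the float values each runtime actually feeds to the model for the same image. The sketch below is plain Python and only shows the comparison step. The helper names and the sample values are hypothetical, and it assumes the input tensors from both sides have already been dumped to flat lists of floats.

```python
from typing import Sequence, Tuple


def value_range(values: Sequence[float]) -> Tuple[float, float]:
    """Return (min, max) of a flattened input tensor."""
    return min(values), max(values)


def max_abs_diff(a: Sequence[float], b: Sequence[float]) -> float:
    """Largest element-wise absolute difference between two tensors."""
    if len(a) != len(b):
        raise ValueError("tensors must have the same number of elements")
    return max(abs(x - y) for x, y in zip(a, b))


# Made-up values for illustration only; real dumps would come from
# the Python interpreter input and the Android input buffer.
python_input = [0.0, 0.5, 1.0]
android_input = [-1.0, 0.0, 1.0]

print("python range: ", value_range(python_input))
print("android range:", value_range(android_input))
print("max abs diff: ", max_abs_diff(python_input, android_input))
```

If the two ranges or the element-wise values differ noticeably, the preprocessing on one side is not what the model saw in the other.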

Code that loads the image:

private TensorImage loadImage(Bitmap bitmap, int sensorOrientation) {
    // Loads bitmap into a TensorImage.
    inputImageBuffer.load(bitmap);

    int noOfRotations = sensorOrientation / 90;
    int cropSize = Math.min(bitmap.getWidth(), bitmap.getHeight());

    ImageProcessor imageProcessor = new ImageProcessor.Builder()
            .add(new ResizeWithCropOrPadOp(cropSize, cropSize))
            .add(new ResizeOp(imageResizeX, imageResizeY, ResizeOp.ResizeMethod.BILINEAR))
            .add(new Rot90Op(noOfRotations))
            .add(new NormalizeOp(IMAGE_MEAN, IMAGE_STD))
            .build();
    return imageProcessor.process(inputImageBuffer);
}
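
The two preprocessing pipelines in this question do not scale pixels the same way: the Python code divides by 255, while NormalizeOp(IMAGE_MEAN, IMAGE_STD) with both constants set to 127.5 computes (x - 127.5) / 127.5. The sketch below only reproduces that arithmetic in plain Python. The function names are made up for illustration, and whether this difference accounts for the accuracy gap is an assumption, since the training preprocessing is not shown.

```python
def scale_python_side(pixel: float) -> float:
    """Scaling used in the Python snippet: new_img /= 255."""
    return pixel / 255.0


def scale_normalize_op(pixel: float, mean: float = 127.5, std: float = 127.5) -> float:
    """Per-value arithmetic of NormalizeOp(mean, std): (x - mean) / std."""
    return (pixel - mean) / std


# Compare both scalings for a few 8-bit pixel values.
for p in (0, 64, 128, 255):
    print(p, round(scale_python_side(p), 3), round(scale_normalize_op(p), 3))
```

With these constants the Python side produces values in [0, 1] and the Android side produces values in [-1, 1]. If the model expects [0, 1], then a mean of 0 and a standard deviation of 255 would give the same scaling as the Python code. Two other differences may also be worth checking: cv2.imread returns BGR channel order while a Bitmap loaded into TensorImage is RGB, and the Android pipeline center-crops to a square before resizing while the Python code resizes the whole image.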

Code that performs the prediction:

inputImageBuffer = loadImage(bitmap, sensorOrientation);
tensorClassifier.run(inputImageBuffer.getBuffer(), probabilityImageBuffer.getBuffer().rewind());

Full classifier code (ImageClassifier.java):

import android.app.Activity;
import android.graphics.Bitmap;
import android.widget.Toast;

import org.tensorflow.lite.DataType;
import org.tensorflow.lite.Interpreter;
import org.tensorflow.lite.support.common.FileUtil;
import org.tensorflow.lite.support.common.TensorProcessor;
import org.tensorflow.lite.support.common.ops.NormalizeOp;
import org.tensorflow.lite.support.image.ImageProcessor;
import org.tensorflow.lite.support.image.TensorImage;
import org.tensorflow.lite.support.image.ops.ResizeOp;
import org.tensorflow.lite.support.image.ops.ResizeWithCropOrPadOp;
import org.tensorflow.lite.support.image.ops.Rot90Op;
import org.tensorflow.lite.support.label.TensorLabel;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;

import java.io.IOException;
import java.nio.MappedByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;

public class ImageClassifier {

    // Non-Quantized
    private static final float PROBABILITY_MEAN = 0.0f;
    private static final float PROBABILITY_STD = 1.0f;

    private static final float IMAGE_STD = 127.5f;
    private static final float IMAGE_MEAN = 127.5f;

    private static final int MAX_SIZE =3;

    /**
     * Image size along the x axis.
     */
    private final int imageResizeX;
    /**
     * Image size along the y axis.
     */
    private final int imageResizeY;

    /**
     * Labels corresponding to the output of the vision model.
     */
    private final List<String> labels;

    /**
     * An instance of the driver class to run model inference with Tensorflow Lite.
     */
    private final Interpreter tensorClassifier;
    /**
     * Input image TensorBuffer.
     */
    private TensorImage inputImageBuffer;
    /**
     * Output probability TensorBuffer.
     */
    private final TensorBuffer probabilityImageBuffer;
    /**
     * Processor that applies post-processing to the output probability.
     */
    private final TensorProcessor probabilityProcessor;

    /**
     * Creates a classifier
     *
     * @param activity the current activity
     * @throws IOException
     */
    public ImageClassifier(Activity activity) throws IOException {
        /*
         * The loaded TensorFlow Lite model.
         */
        MappedByteBuffer classifierModel = FileUtil.loadMappedFile(activity,
                "MobileNet.tflite");
        // Loads labels out from the label file.
        labels = FileUtil.loadLabels(activity, "labels_mobilenet.txt");

        tensorClassifier = new Interpreter(classifierModel, null);

        // Reads type and shape of input and output tensors, respectively. [START]
        int imageTensorIndex = 0; // input
        int probabilityTensorIndex = 0;// output

        int[] inputImageShape = tensorClassifier.getInputTensor(imageTensorIndex).shape();
        DataType inputDataType = tensorClassifier.getInputTensor(imageTensorIndex).dataType();

        int[] outputImageShape = tensorClassifier.getOutputTensor(probabilityTensorIndex).shape();
        DataType outputDataType = tensorClassifier.getOutputTensor(probabilityTensorIndex).dataType();

        imageResizeX = inputImageShape[2];
        imageResizeY = inputImageShape[1];


        // Creates the input tensor.
        inputImageBuffer = new TensorImage(inputDataType);

        // Creates the output tensor and its processor.
        probabilityImageBuffer = TensorBuffer.createFixedSize(outputImageShape, outputDataType);

        // Creates the post processor for the output probability.
        probabilityProcessor = new TensorProcessor.Builder().add(new NormalizeOp(PROBABILITY_MEAN, PROBABILITY_STD))
                .build();
    }

    /**
     * method runs the inference and returns the classification results
     *
     * @param bitmap            the bitmap of the image
     * @param sensorOrientation orientation of the camera
     * @return classification results
     */
    public List<Recognition> recognizeImage(final Bitmap bitmap, final int sensorOrientation) {
        // List of labels and probabilities for each class
        List<Recognition> recognitions = new ArrayList<>();

        inputImageBuffer = loadImage(bitmap, sensorOrientation);
        tensorClassifier.run(inputImageBuffer.getBuffer(), probabilityImageBuffer.getBuffer().rewind());

        // Gets the map of label and probability.
        Map<String, Float> labelledProbability = new TensorLabel(labels,
                probabilityProcessor.process(probabilityImageBuffer)).getMapWithFloatValue();

        int idLabel = 0;
        for (Map.Entry<String, Float> entry : labelledProbability.entrySet()) {
            recognitions.add(new Recognition(String.valueOf(idLabel), entry.getValue()));
            idLabel++;
        }        

        // List of probabilities for each class
        List<Float> probabilidades = new ArrayList<>();
        for (Map.Entry<String, Float> entry : labelledProbability.entrySet()) {
            probabilidades.add(entry.getValue());
        }

        Collections.sort(recognitions);

        return recognitions.subList(0, MAX_SIZE);
    }

    /**
     * loads the image into tensor input buffer and apply pre processing steps
     *
     * @param bitmap            the bit map to be loaded
     * @param sensorOrientation the sensor orientation
     * @return the image loaded tensor input buffer
     */
    private TensorImage loadImage(Bitmap bitmap, int sensorOrientation) {
        // Loads bitmap into a TensorImage.
        inputImageBuffer.load(bitmap);

        int noOfRotations = sensorOrientation / 90;
        int cropSize = Math.min(bitmap.getWidth(), bitmap.getHeight());

        // pre processing steps are applied here
        ImageProcessor imageProcessor = new ImageProcessor.Builder()
                .add(new ResizeWithCropOrPadOp(cropSize, cropSize))
                .add(new ResizeOp(imageResizeX, imageResizeY, ResizeOp.ResizeMethod.BILINEAR))
                .add(new Rot90Op(noOfRotations))
                .add(new NormalizeOp(IMAGE_MEAN, IMAGE_STD))
                .build();
        return imageProcessor.process(inputImageBuffer);
    }

    /**
     * An immutable result returned by a Classifier describing what was recognized.
     */
    public class Recognition implements Comparable {
        /**
         * Display name for the recognition.
         */
        private String name;
        /**
         * A sortable score for how good the recognition is relative to others. Higher should be better.
         */
        private float confidence;

        public Recognition() {
        }

        public Recognition(String name, float confidence) {
            this.name = name;
            this.confidence = confidence;
        }

        public String getName() {
            return name;
        }

        public void setName(String name) {
            this.name = name;
        }

        public float getConfidence() {
            return confidence;
        }

        public void setConfidence(float confidence) {
            this.confidence = confidence;
        }

        @Override
        public String toString() {
            return "Recognition{" +
                    "name='" + name + '\'' +
                    ", confidence=" + confidence +
                    '}';
        }

        @Override
        public int compareTo(Object o) {
            return Float.compare(((Recognition) o).confidence, this.confidence);
        }
    }


}