Getting garbage image while upscaling a image using esrgan.tflite in Android

48 Views Asked by At

I am trying to upscale a image using esrgan.tflite on Android app. I could read the input image and upscale the image. But the output of the image is garbage. I am sure something in post processing is going wrong, But I can't put my finger on top of it.

Input image: 320 * 240; 4x upscaling so output image is 1280 * 960

The output image is just a rgb dots with no meaning. Can someone help me in where I am going wrong?

    private static final int MODEL_INPUT_WIDTH = 320;
    private static final int MODEL_INPUT_HEIGHT = 240;
    private static final int MODEL_INPUT_CHANNELS = 3;

                ByteBuffer preprocessedInput = preprocessFrame(videoFrame);

                /**
                 * Output Image Size: 960 x 1280 x 3 pixels = 3,686,400 pixels
                 * Pixel Size (float32): 32 bits per float = 4 bytes per pixel
                 * Total Output Buffer Size: 3,686,400 pixels * 4 bytes/pixel = 14,745,600 bytes
                 * 1024 x 786
                 */
                // Run inference with the ESRGAN model
                ByteBuffer outputBuffer = ByteBuffer.allocateDirect(14745600);
                esrganInterpreter.run(preprocessedInput, outputBuffer);

                // Postprocess the output buffer to get the upscaled frame
                Bitmap upscaledFrame = postprocessFrame(outputBuffer);

    private ByteBuffer preprocessFrame(Bitmap bitmap) {
        int[] intValues = new int[MODEL_INPUT_WIDTH * MODEL_INPUT_HEIGHT];
        float[] floatValues = new float[MODEL_INPUT_WIDTH * MODEL_INPUT_HEIGHT * MODEL_INPUT_CHANNELS];

        // Get pixel values from the Bitmap
        bitmap.getPixels(intValues, 0, MODEL_INPUT_WIDTH, 0, 0, MODEL_INPUT_WIDTH, MODEL_INPUT_HEIGHT);

        // Normalize pixel values and store them in the input ByteBuffer
        int pixel = 0;
        for (int i = 0; i < MODEL_INPUT_HEIGHT; ++i) {
            for (int j = 0; j < MODEL_INPUT_WIDTH; ++j) {
                final int val = intValues[i * MODEL_INPUT_WIDTH + j]; // Calculate the correct index
                floatValues[pixel * MODEL_INPUT_CHANNELS] = ((val >> 16) & 0xFF) / 255.0f;
                floatValues[pixel * MODEL_INPUT_CHANNELS + 1] = ((val >> 8) & 0xFF) / 255.0f;
                floatValues[pixel * MODEL_INPUT_CHANNELS + 2] = (val & 0xFF) / 255.0f;
                pixel++;
            }
        }

        // Copy the normalized pixel values to the input ByteBuffer
        ByteBuffer buffer = ByteBuffer.allocateDirect(MODEL_INPUT_WIDTH * MODEL_INPUT_HEIGHT * MODEL_INPUT_CHANNELS * 4);
        buffer.rewind();
        for (int i = 0; i < floatValues.length; ++i) {
            buffer.putFloat(floatValues[i]);
        }

        return buffer;
    }


    private Bitmap postprocessFrame(ByteBuffer outputBuffer) {
        // Calculate the dimensions of the upscaled image
        int upscaledWidth = MODEL_INPUT_WIDTH * 4;
        int upscaledHeight = MODEL_INPUT_HEIGHT * 4;

        // Create a bitmap to hold the upscaled image
        Bitmap outputBitmap = Bitmap.createBitmap(upscaledWidth, upscaledHeight, Bitmap.Config.ARGB_8888);

        // Allocate array to hold pixel values
        int[] intValues = new int[upscaledWidth * upscaledHeight];

        // Convert output ByteBuffer to pixel values
        outputBuffer.rewind();
        byte[] byteValues = new byte[upscaledWidth * upscaledHeight * MODEL_INPUT_CHANNELS];
        outputBuffer.get(byteValues);

        // Convert byte pixel values to ARGB
        int index = 0;
        for (int y = 0; y < upscaledHeight; ++y) {
            for (int x = 0; x < upscaledWidth; ++x) {
                final int r = byteValues[index] & 0xff;
                final int g = byteValues[index + 1] & 0xff;
                final int b = byteValues[index + 2] & 0xff;
                intValues[y * upscaledWidth + x] = 0xff000000 | (r << 16) | (g << 8) | b;
                index += MODEL_INPUT_CHANNELS;
            }
        }

        // Set pixel values to the output Bitmap
        outputBitmap.setPixels(intValues, 0, upscaledWidth, 0, 0, upscaledWidth, upscaledHeight);

        return outputBitmap;
    }

0

There are 0 best solutions below