I am trying to upscale a image using esrgan.tflite on Android app. I could read the input image and upscale the image. But the output of the image is garbage. I am sure something in post processing is going wrong, But I can't put my finger on top of it.
Input image: 320 * 240; 4x upscaling so output image is 1280 * 960
The output image is just a rgb dots with no meaning. Can someone help me in where I am going wrong?
private static final int MODEL_INPUT_WIDTH = 320;
private static final int MODEL_INPUT_HEIGHT = 240;
private static final int MODEL_INPUT_CHANNELS = 3;
ByteBuffer preprocessedInput = preprocessFrame(videoFrame);
/**
* Output Image Size: 960 x 1280 x 3 pixels = 3,686,400 pixels
* Pixel Size (float32): 32 bits per float = 4 bytes per pixel
* Total Output Buffer Size: 3,686,400 pixels * 4 bytes/pixel = 14,745,600 bytes
* 1024 x 786
*/
// Run inference with the ESRGAN model
ByteBuffer outputBuffer = ByteBuffer.allocateDirect(14745600);
esrganInterpreter.run(preprocessedInput, outputBuffer);
// Postprocess the output buffer to get the upscaled frame
Bitmap upscaledFrame = postprocessFrame(outputBuffer);
private ByteBuffer preprocessFrame(Bitmap bitmap) {
int[] intValues = new int[MODEL_INPUT_WIDTH * MODEL_INPUT_HEIGHT];
float[] floatValues = new float[MODEL_INPUT_WIDTH * MODEL_INPUT_HEIGHT * MODEL_INPUT_CHANNELS];
// Get pixel values from the Bitmap
bitmap.getPixels(intValues, 0, MODEL_INPUT_WIDTH, 0, 0, MODEL_INPUT_WIDTH, MODEL_INPUT_HEIGHT);
// Normalize pixel values and store them in the input ByteBuffer
int pixel = 0;
for (int i = 0; i < MODEL_INPUT_HEIGHT; ++i) {
for (int j = 0; j < MODEL_INPUT_WIDTH; ++j) {
final int val = intValues[i * MODEL_INPUT_WIDTH + j]; // Calculate the correct index
floatValues[pixel * MODEL_INPUT_CHANNELS] = ((val >> 16) & 0xFF) / 255.0f;
floatValues[pixel * MODEL_INPUT_CHANNELS + 1] = ((val >> 8) & 0xFF) / 255.0f;
floatValues[pixel * MODEL_INPUT_CHANNELS + 2] = (val & 0xFF) / 255.0f;
pixel++;
}
}
// Copy the normalized pixel values to the input ByteBuffer
ByteBuffer buffer = ByteBuffer.allocateDirect(MODEL_INPUT_WIDTH * MODEL_INPUT_HEIGHT * MODEL_INPUT_CHANNELS * 4);
buffer.rewind();
for (int i = 0; i < floatValues.length; ++i) {
buffer.putFloat(floatValues[i]);
}
return buffer;
}
private Bitmap postprocessFrame(ByteBuffer outputBuffer) {
// Calculate the dimensions of the upscaled image
int upscaledWidth = MODEL_INPUT_WIDTH * 4;
int upscaledHeight = MODEL_INPUT_HEIGHT * 4;
// Create a bitmap to hold the upscaled image
Bitmap outputBitmap = Bitmap.createBitmap(upscaledWidth, upscaledHeight, Bitmap.Config.ARGB_8888);
// Allocate array to hold pixel values
int[] intValues = new int[upscaledWidth * upscaledHeight];
// Convert output ByteBuffer to pixel values
outputBuffer.rewind();
byte[] byteValues = new byte[upscaledWidth * upscaledHeight * MODEL_INPUT_CHANNELS];
outputBuffer.get(byteValues);
// Convert byte pixel values to ARGB
int index = 0;
for (int y = 0; y < upscaledHeight; ++y) {
for (int x = 0; x < upscaledWidth; ++x) {
final int r = byteValues[index] & 0xff;
final int g = byteValues[index + 1] & 0xff;
final int b = byteValues[index + 2] & 0xff;
intValues[y * upscaledWidth + x] = 0xff000000 | (r << 16) | (g << 8) | b;
index += MODEL_INPUT_CHANNELS;
}
}
// Set pixel values to the output Bitmap
outputBitmap.setPixels(intValues, 0, upscaledWidth, 0, 0, upscaledWidth, upscaledHeight);
return outputBitmap;
}