ML Kit face detector processing time grows with every image

I am writing an Android app that reads an RTSP stream and sends images to ML Kit's FaceDetector. I used the rtsp-client-android library and modified it so that it does not render images onto its custom RTSPSurfaceView, but instead puts them into a BlockingQueue in a VideoDecodeThread class.

package com.alexvas.rtsp.codec

import android.graphics.Bitmap
import android.graphics.BitmapFactory
import android.graphics.ImageFormat
import android.graphics.Rect
import android.graphics.YuvImage
import android.media.Image
import android.media.MediaCodec
import android.media.MediaCodec.OnFrameRenderedListener
import android.media.MediaCodecInfo
import android.media.MediaFormat
import android.util.Log
import android.util.Log.VERBOSE
import com.google.android.exoplayer2.util.Util
import com.google.mlkit.vision.common.InputImage
import java.io.ByteArrayOutputStream
import java.io.IOException
import java.nio.ByteBuffer
import java.util.concurrent.ArrayBlockingQueue
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicBoolean


class VideoDecodeThread (
        private val mimeType: String,
        private val width: Int,
        private val height: Int,
        private val videoFrameQueue: FrameQueue,
        private val onFrameRenderedListener: OnFrameRenderedListener) : Thread() {

    private var exitFlag: AtomicBoolean = AtomicBoolean(false)
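    // Bounded hand-off queue: decoded frames are offered here and consumed by the ML Kit image-processing thread.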
    var videoQueue : ArrayBlockingQueue<Bitmap> = ArrayBlockingQueue(60)
    lateinit var bitmap : Bitmap
    val options = BitmapFactory.Options().apply {
        // Set your desired options here
        inJustDecodeBounds = true
    }
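    // Note: `options` is never passed to the BitmapFactory.decodeByteArray() call below, so it currently has no effect.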

    fun stopAsync() {
        if (DEBUG) Log.v(TAG, "stopAsync()")
        exitFlag.set(true)
        // Wake up sleep() code
        interrupt()
    }




    private fun getDecoderSafeWidthHeight(decoder: MediaCodec): Pair<Int, Int> {
        val capabilities = decoder.codecInfo.getCapabilitiesForType(mimeType).videoCapabilities
        return if (capabilities.isSizeSupported(width, height)) {
            Pair(width, height)
        } else {
            val widthAlignment = capabilities.widthAlignment
            val heightAlignment = capabilities.heightAlignment
            Pair(
                Util.ceilDivide(width, widthAlignment) * widthAlignment,
                Util.ceilDivide(height, heightAlignment) * heightAlignment)
        }
    }

    override fun run() {
        if (DEBUG) Log.d(TAG, "$name started")

        try {
            val decoder = MediaCodec.createDecoderByType(mimeType)
            val widthHeight = getDecoderSafeWidthHeight(decoder)
            val format = MediaFormat.createVideoFormat(mimeType, widthHeight.first, widthHeight.second)
            //val codecImageFormat = getImageFormatFromCodecType(mime);

            decoder.setOnFrameRenderedListener(onFrameRenderedListener, null)

            if (DEBUG) Log.d(TAG, "Configuring surface ${widthHeight.first}x${widthHeight.second} w/ '$mimeType', max instances: ${decoder.codecInfo.getCapabilitiesForType(mimeType).maxSupportedInstances}")
            decoder.configure(format, null, null, 0)

            decoder.start()
            if (DEBUG) Log.d(TAG, "Started surface decoder")

            val bufferInfo = MediaCodec.BufferInfo()
            // Main loop
            while (!exitFlag.get()) {
                val inIndex: Int = decoder.dequeueInputBuffer(DEQUEUE_INPUT_TIMEOUT_US)
                if (inIndex >= 0) {
                    // fill inputBuffers[inputBufferIndex] with valid data
                    val byteBuffer: ByteBuffer? = decoder.getInputBuffer(inIndex)
                    byteBuffer?.rewind()
                    // Preventing BufferOverflowException
                    // if (length > byteBuffer.limit()) throw DecoderFatalException("Error")

                    val frame = videoFrameQueue.pop()
                    if (frame == null) {
                        Log.d(TAG, "Empty video frame")
                        // Release input buffer
                        decoder.queueInputBuffer(inIndex, 0, 0, 0L, 0)
                    } else {
                        byteBuffer?.put(frame.data, frame.offset, frame.length)
                        decoder.queueInputBuffer(inIndex, frame.offset, frame.length, frame.timestamp, 0)
                    }
                }
                if (exitFlag.get()) break
                when (val outIndex = decoder.dequeueOutputBuffer(bufferInfo, DEQUEUE_OUTPUT_BUFFER_TIMEOUT_US)) {
                    MediaCodec.INFO_OUTPUT_FORMAT_CHANGED -> Log.d(TAG, "Decoder format changed: ${decoder.outputFormat}")
                    MediaCodec.INFO_TRY_AGAIN_LATER -> if (DEBUG) Log.d(TAG, "No output from decoder available")
                    else -> {
                        if (outIndex >= 0) {
                            //val outputBuffer: ByteBuffer = decoder.getOutputBuffer(outIndex)!!
                            //val bufferFormat: MediaFormat = decoder.getOutputFormat(outIndex)
                            //
                            var image = decoder.getOutputImage(outIndex)


                            image?.let {
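                                // Convert the decoded YUV frame to NV21, compress it to an in-memory JPEG,
                                // then decode that JPEG back into a Bitmap for ML Kit.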

                                val yuvImage = YuvImage(
                                    YUV_420_888toNV21(image),
                                    ImageFormat.NV21,
                                    480,
                                    360,
                                    null
                                )

                                val stream = ByteArrayOutputStream()
                                yuvImage.compressToJpeg(Rect(0, 0, 480, 360), 80, stream)
                                bitmap = BitmapFactory.decodeByteArray(
                                    stream.toByteArray(),
                                    0,
                                    stream.size(),
                                )
                                try {
                                    stream.close()
                                } catch (e:IOException) {
                                    e.printStackTrace()
                                }

                                bitmap.let {
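                                    // If the queue is full, drop the oldest bitmap so the decoder is never blocked by a slow consumer.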
                                    if (!videoQueue.offer(bitmap)){
                                        videoQueue.poll()
                                        //videoQueue.offer(bitmap, 10, TimeUnit.MILLISECONDS)
                                        videoQueue.add(bitmap)
                                    }

                                }
                                image.close();
                            } ?: run {
                                Log.v("aaa", "image is null")
                            }

                             //NOTICE change that to just offer(buffer) if needed
                            decoder.releaseOutputBuffer(
                                outIndex,
                                //bufferInfo.size != 0 && !exitFlag.get()
                                false
                            )
                            Log.v("aaa", "image sent for processing")
                        }
                    }
                }

                // All decoded frames have been rendered, we can stop playing now
                if (bufferInfo.flags and MediaCodec.BUFFER_FLAG_END_OF_STREAM != 0) {
                    if (DEBUG) Log.d(TAG, "OutputBuffer BUFFER_FLAG_END_OF_STREAM")
                    break
                }
            }

            // Drain decoder
            val inIndex: Int = decoder.dequeueInputBuffer(DEQUEUE_INPUT_TIMEOUT_US)
            if (inIndex >= 0) {
                decoder.queueInputBuffer(inIndex, 0, 0, 0L, MediaCodec.BUFFER_FLAG_END_OF_STREAM)
            } else {
                Log.w(TAG, "Not able to signal end of stream")
            }

            decoder.stop()
            decoder.release()
            videoFrameQueue.clear()

        } catch (e: Exception) {
            Log.e(TAG, "$name stopped due to '${e.message}'")
            // While configuring stopAsync can be called and surface released. Just exit.
            if (!exitFlag.get()) e.printStackTrace()
            return
        }

        if (DEBUG) Log.d(TAG, "$name stopped")
    }



    private fun YUV_420_888toNV21(image: Image): ByteArray? {
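        // Copies the Y, U and V planes, downscales them from a hardcoded 1280x720 to 480x360
        // (nearest neighbour) and repacks the result as NV21 (interleaved V/U).
        // Note: row/pixel strides reported by Image.getPlanes() are not taken into account.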
        val nv21: ByteArray
        val yBuffer = image.planes[0].buffer
        val uBuffer = image.planes[1].buffer
        val vBuffer = image.planes[2].buffer
        val ySize = yBuffer.remaining()
        val uSize = uBuffer.remaining()
        val vSize = vBuffer.remaining()

        val yBytes = ByteArray(ySize)
        yBuffer.get(yBytes)
        val uBytes = ByteArray(uSize)
        uBuffer.get(uBytes)
        val vBytes = ByteArray(vSize)
        vBuffer.get(vBytes)

        // Downscale the Y, U, and V planes to the desired resolution
        var downscaledYBytes = downscaleYPlane(yBytes, 1280, 720, 480, 360)
        var downscaledUBytes = downscaleUVPlane(uBytes, 1280 / 2, 720 / 2, 480 / 2, 360 / 2)
        var downscaledVBytes = downscaleUVPlane(vBytes, 1280 / 2, 720 / 2, 480 / 2, 360 / 2)

        // Convert the downscaled YUV data to NV21 format
        nv21 = ByteArray(480 * 360 + (480 / 2) * (360 / 2) * 2)
        System.arraycopy(downscaledYBytes, 0, nv21, 0, downscaledYBytes.size)
        for (i in downscaledVBytes.indices) {
            nv21[downscaledYBytes.size + i * 2] = downscaledVBytes[i]
            nv21[downscaledYBytes.size + i * 2 + 1] = downscaledUBytes[i]
        }

        yBuffer.clear()
        uBuffer.clear()
        vBuffer.clear()


        return nv21
    }

    private fun downscaleYPlane(src: ByteArray, srcWidth: Int, srcHeight: Int,
                                dstWidth: Int, dstHeight: Int): ByteArray {
        val dst = ByteArray(dstWidth * dstHeight)

        for (y in 0 until dstHeight) {
            for (x in 0 until dstWidth) {
                val srcX = x * srcWidth / dstWidth
                val srcY = y * srcHeight / dstHeight
                dst[y * dstWidth + x] = src[srcY * srcWidth + srcX]
            }
        }

        return dst
    }

    private fun downscaleUVPlane(src: ByteArray, srcWidth: Int, srcHeight: Int,
                                 dstWidth: Int, dstHeight: Int): ByteArray {
        val dst = ByteArray(dstWidth * dstHeight)

        for (y in 0 until dstHeight) {
            for (x in 0 until dstWidth) {
                val srcX = x * srcWidth / dstWidth
                val srcY = y * srcHeight / dstHeight
                dst[y * dstWidth + x] = src[srcY * srcWidth + srcX]
            }
        }

        return dst
    }

    companion object {
        private val TAG: String = VideoDecodeThread::class.java.simpleName
        private const val DEBUG = false

        private val DEQUEUE_INPUT_TIMEOUT_US = TimeUnit.MILLISECONDS.toMicros(500)
        private val DEQUEUE_OUTPUT_BUFFER_TIMEOUT_US = TimeUnit.MILLISECONDS.toMicros(100)
    }

}

This is how bitmaps are taken and processed with ML Kit:

package com.barreloftea.driversupport.domain.imageprocessor.service;

import android.graphics.Bitmap;
import android.graphics.PointF;
import android.graphics.Rect;
import android.util.Log;

import com.barreloftea.driversupport.domain.imageprocessor.interfaces.VideoRepository;
import com.barreloftea.driversupport.domain.imageprocessor.utils.DrawContours;
import com.barreloftea.driversupport.domain.processor.Processor;
import com.barreloftea.driversupport.domain.processor.common.ImageBuffer;
import com.google.android.gms.tasks.Task;
import com.google.mlkit.vision.common.InputImage;
import com.google.mlkit.vision.face.Face;
import com.google.mlkit.vision.face.FaceContour;
import com.google.mlkit.vision.face.FaceDetection;
import com.google.mlkit.vision.face.FaceDetector;
import com.google.mlkit.vision.face.FaceDetectorOptions;

import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.atomic.AtomicBoolean;

public class ImageProcessor extends Thread {

    public static final String TAG = ImageProcessor.class.getSimpleName();

    private AtomicBoolean exitFlag = new AtomicBoolean(false);
    VideoRepository videoRepository;
    ArrayBlockingQueue<Bitmap> queue;
    ImageBuffer imageBuffer;
    Processor processor;

    public ImageProcessor(VideoRepository rep){
        videoRepository = rep;
        videoRepository.setParams("rtsp://192.168.0.1:554/livestream/12", "", "");
        videoRepository.prepare();
        imageBuffer = ImageBuffer.getInstance();
    }

    public void setProcessor(Processor processor) {
        this.processor = processor;
    }

    int eyeFlag;
    int mouthFlag;
    int noseFlag;
    int notBlinkFlag;
    static final int EYE_THRESH = 16;
    static final int MOUTH_THRESH = 18;
    static final int NO_BLINK_TH = 80;
    static final float ROUND = 0.6f;

    public float EOP_ = 0.3f;
    public float MOR_ = 0.5f;
    public float NL_ = 0.5f;
    private float lastEOP;
    private float lastMOR;
    private float lastNL;

    private DrawContours drawer = new DrawContours();

    private Bitmap bitmap;
    InputImage inputImage;
    private FaceDetectorOptions realTimeOpts = new FaceDetectorOptions.Builder()
            .setPerformanceMode(FaceDetectorOptions.PERFORMANCE_MODE_FAST)
            .setContourMode(FaceDetectorOptions.CONTOUR_MODE_ALL)
            .build();
    private FaceDetector detector = FaceDetection.getClient(realTimeOpts);

    public void stopAsync(){
        exitFlag.set(true);
        interrupt();
        Log.v(TAG, "camera thread is stopped");
    }


    @Override
    public void run() {
        queue = videoRepository.getVideoQueue();
        Log.v(TAG, "camera thread started");
        while(!exitFlag.get()){

            try {
                //byteBuffer = queue.take();
                bitmap = queue.take();
                Log.v(TAG, "image is taken from queue");
            } catch (InterruptedException e) {
                //throw new RuntimeException(e);
                Log.v(TAG, "no bitmap available in queue");
            }

            inputImage = InputImage.fromBitmap(bitmap, 0);


            long startTime = System.nanoTime();
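            // Note: detector.process() returns a Task immediately; the listeners below run
            // asynchronously (on the main thread by default), after this loop iteration has moved on.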

            Task<List<Face>> result =
                    detector.process(inputImage)
                            .addOnSuccessListener(
                                    faces -> {

                                        Log.v(TAG, faces.size() + " FACES WERE DETECTED");

                                        for (Face face : faces){
                                            Rect bounds = face.getBoundingBox();
                                            bitmap = drawer.drawRect(bitmap, bounds);


                                            float rotY = face.getHeadEulerAngleY();  // Head is rotated to the right rotY degrees
                                            float rotZ = face.getHeadEulerAngleZ(); //TODO rotY and rotZ are somehow always 0.0 and -0.0
                                            float rotX = face.getHeadEulerAngleX();

                                            List<PointF> leftEyeContour = face.getContour(FaceContour.LEFT_EYE).getPoints();
                                            bitmap = drawer.drawContours(bitmap, leftEyeContour);
                                            List<PointF> rightEyeContour = face.getContour(FaceContour.RIGHT_EYE).getPoints();
                                            bitmap = drawer.drawContours(bitmap, rightEyeContour);
                                            List<PointF> upperLipCon = face.getContour(FaceContour.UPPER_LIP_TOP).getPoints();
                                            bitmap = drawer.drawContours(bitmap, upperLipCon);
                                            List<PointF> lowerLipCon = face.getContour(FaceContour.LOWER_LIP_BOTTOM).getPoints();
                                            bitmap = drawer.drawContours(bitmap, lowerLipCon);
                                            List<PointF> noseCon = face.getContour(FaceContour.NOSE_BRIDGE).getPoints();
                                            bitmap = drawer.drawContours(bitmap, noseCon);


                                            float REOP = getOneEOP(rightEyeContour);
                                            float LEOP = getOneEOP(leftEyeContour);


                                            notBlinkFlag++;

                                            lastEOP = (LEOP+REOP)/2;

                                            Log.v(TAG, "last eop is" + lastEOP);

                                            if ((LEOP+REOP)/2 < EOP_) {
                                                eyeFlag++;
                                                notBlinkFlag = 0;
                                                Log.v(null, "you blinked");
                                            }
                                            else {
                                                eyeFlag = 0;
                                            }

                                            if (eyeFlag>=EYE_THRESH){
                                                processor.setCamState(Processor.SLEEPING);
                                                Log.v(null, "REASON closed eyes");
                                            }
                                            if (notBlinkFlag > NO_BLINK_TH){
                                                processor.setCamState(Processor.DROWSY);
                                                Log.v(null, "REASON always open eyes");
                                            }



                                            if (eyeFlag < EYE_THRESH && mouthFlag < MOUTH_THRESH && noseFlag < EYE_THRESH) {
                                                processor.setBandState(Processor.AWAKE);
                                            }


                                            Log.v(null, rotY + " roty");
                                            Log.v(null, rotZ + " rotz");
                                            Log.v(null, rotX + " rotx");

                                            long endTime = System.nanoTime();
                                            long timePassed = endTime - startTime;
                                            Log.v(null, "Execution time in milliseconds: " + timePassed / 1000000);
                                        }

                                    })
                            .addOnFailureListener(
                                    e -> Log.v(TAG, "IMAGE PROCESSING FAILED"+ Arrays.toString(e.getStackTrace())+ e.getMessage()))
                            .addOnCompleteListener(
                                    task -> {
                                        Log.v(TAG, "IMGAE IS PROCESSED SUCCESSFULLY");
                                        //imageBuffer.setFrame(bitmap);
                                    }
                            );
                    inputImage = null;
        }
    }

    // getOneEOP() and the remaining helper methods of the class are not shown here.
}

Expected behaviour: each image is processed within 500-600 milliseconds. This is measured by the "Execution time" log written in the ImageProcessor code.

Actual behaviour: the execution time grows with every iteration. Sometimes it reaches 61222 milliseconds before either I or the system kill the app for excessive memory usage. Normally the app consumes around 139 MB of RAM, sometimes 450 MB, and sometimes it reaches 800+ MB.

What I tried: I assumed it happens because some image object is not cleared before reuse and keeps accumulating data from previous frames. However, bitmap.getByteCount() in ImageProcessor shows the same size every time. I also tried reusing all the bitmaps and streams in the code, with no result. Android Profiler shows that byte[] objects consume the most memory, but I have no idea how to track which code allocates those byte arrays.
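To narrow this down, here is a minimal diagnostic sketch (the helper is hypothetical and the name is mine, not part of the project) that could be called before each frame is submitted, to log how many bitmaps are waiting in the queue and how much Java and native heap is in use:

    // Hypothetical diagnostic helper: logs queue backlog and heap usage for each frame.
    private void logFrameStats(ArrayBlockingQueue<Bitmap> queue) {
        Runtime rt = Runtime.getRuntime();
        long javaHeapUsedMb = (rt.totalMemory() - rt.freeMemory()) / (1024 * 1024);
        long nativeHeapMb = android.os.Debug.getNativeHeapAllocatedSize() / (1024 * 1024);
        Log.v(TAG, "queue depth=" + queue.size()
                + ", javaHeapUsed=" + javaHeapUsedMb + "MB"
                + ", nativeHeapAllocated=" + nativeHeapMb + "MB");
    }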

This is part of the logcat output:

2023-06-15 08:52:09.386   374-374   a.driversuppor          com.barreloftea.driversupport        V  Execution time in milliseconds: 21818
2023-06-15 08:52:09.398   374-2088  skia                    com.barreloftea.driversupport        D  onFlyCompress
2023-06-15 08:52:09.399   374-2088  skia                    com.barreloftea.driversupport        D  Yuv420SpToJpeg [yPlanar:29] [vuPlanar:126] [WxH:480x360]
2023-06-15 08:52:09.405   374-2091  aaa                     com.barreloftea.driversupport        V  image is taken from queue
2023-06-15 08:52:09.405   374-2088  aaa                     com.barreloftea.driversupport        V  image sent for processing
2023-06-15 08:52:09.417   374-374   aaa                     com.barreloftea.driversupport        V  1 FACES WERE DETECTED
2023-06-15 08:52:09.423   374-2088  skia                    com.barreloftea.driversupport        D  onFlyCompress
2023-06-15 08:52:09.423   374-2088  skia                    com.barreloftea.driversupport        D  Yuv420SpToJpeg [yPlanar:29] [vuPlanar:126] [WxH:480x360]
2023-06-15 08:52:09.429   374-2091  aaa                     com.barreloftea.driversupport        V  image is taken from queue
2023-06-15 08:52:09.430   374-2088  aaa                     com.barreloftea.driversupport        V  image sent for processing
2023-06-15 08:52:09.433   374-374   a.driversuppor          com.barreloftea.driversupport        V  0.0 roty
2023-06-15 08:52:09.433   374-374   a.driversuppor          com.barreloftea.driversupport        V  -0.0 rotz
2023-06-15 08:52:09.433   374-374   a.driversuppor          com.barreloftea.driversupport        V  0.0 rotx
2023-06-15 08:52:09.434   374-374   aaa                     com.barreloftea.driversupport        V  IMGAE IS PROCESSED SUCCESSFULLY
2023-06-15 08:52:09.434   374-374   a.driversuppor          com.barreloftea.driversupport        V  Execution time in milliseconds: 21839

What could be the problem, and how can I optimize the code to decrease the execution time (and, consequently, the memory usage)?
