Optimize Android Java loop that copies an array with some elements zeroed, according to a FloatBuffer mask array


I have the function below in Android; it is a bit slow and is a bottleneck in my real-time app. I'd like to make it faster somehow, maybe on the GPU or by any other means. What are some ways to make it faster and more efficient, e.g. using parallelization?

private int[] getArray(ByteBuffer byteBuffer) {
    int[] array = new int[width * height];

    FloatBuffer floatBuffer = byteBuffer.asFloatBuffer();
    for (int i = 0; i < width * height; i++) {
        float probability = 1 - floatBuffer.get();
        if (probability > 0.9) {
            array[i] = originalBuffer[i];
        }
    }
    return array;
}

The context is a segmentation task: an ML model returns a mask in a ByteBuffer, and I pass it to this function so the background can be turned pink.


2 Answers

Answer from Cloverleaf:

(Too long for a comment) But what you can easily do: Replace

float probability = 1 - floatBuffer.get();
if (probability > 0.9) {

by

if (floatBuffer.get() < 0.1) {

And: I honestly don't know whether the compiler emits machine code that recalculates width*height on every loop pass for the termination check (I hope/think not). But give it a try: add the lines

int product = width*height;
int[] array = new int[product];

and then set

for (int i = 0; i < product; i++) {

In this way, you have eliminated all arithmetic operations in the loop (apart from the indispensable i++ and <).
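Putting both changes together, the whole function might look like this (a sketch assuming width, height, and originalBuffer are fields, as in the question):

    private int[] getArray(ByteBuffer byteBuffer) {
        int product = width * height;  // hoisted out of the loop
        int[] array = new int[product];

        FloatBuffer floatBuffer = byteBuffer.asFloatBuffer();
        for (int i = 0; i < product; i++) {
            // replaces "1 - x > 0.9" from the original, without the subtraction
            if (floatBuffer.get() < 0.1) {
                array[i] = originalBuffer[i];
            }
        }
        return array;
    }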

Answer from Peter Cordes:

This will very likely compile / JIT to branchy asm because your loop only stores to array[i] on some iterations. Branchless asm is probably faster if the pattern isn't predictable, and that also enables auto-vectorization if the compiler / JIT is smart enough, like GCC and clang do for a C or C++ version (Godbolt). (Like most SIMD ISAs, AArch64 NEON has SIMD float compares to produce a mask of 0 / -1 elements, which works great to bitwise AND with another vector, zeroing elements or not.)

Compilers are generally reluctant to invent writes because that can be a thread-safety problem (for arrays that any other thread might have a reference to). And inventing reads from originalBuffer could go past the end of the array: it would be valid to call your function with an originalBuffer of length 100, but a FloatBuffer such that if (probability > 0.9) was false for all elements past that.

Using

    int tmp_orig = originalBuffer[i];
    array[i] = (floatBuffer.get() < 0.1f) ? tmp_orig : 0;

fixes both of these problems.

1-x > 0.9 is a slow way to write x < 0.1. (And a compiler won't do that translation for you, since neither 0.9 nor 0.1 is exactly representable as float or double.) Speaking of float vs. double: 0.9 and 0.1 have type double, which forces the compiler to convert each float to double for the compare (an extra fcvt instruction). Use x < 0.1f to get a plain float comparison.
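Concretely, for some float x (hypothetical snippet, just to show the literal types):

    float x = floatBuffer.get();
    if (x < 0.1)  { /* ... */ }  // 0.1 is a double literal: x is widened to double first (extra fcvt)
    if (x < 0.1f) { /* ... */ }  // 0.1f is a float literal: compares directly as float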


Godbolt has Android Java tools, including the optimizing dex2oat compiler (whatever that is), so I could look at AArch64 asm for Java source, if I could get something to compile: on any error, the compiler output is basically the --help output, nothing that tells you what it didn't like about the program.

I wrote a version that uses float[] instead of a ByteBuffer for simplicity, since I was able to get that to compile. floatBuffer.get() or byteBuffer.getFloat() should be equivalent to mask[i] for iterating through the ByteBuffer, so it's hopefully a drop-in replacement.

class Example {
    // Simplified example that compiles without any extra stuff.
    // the same thing with a ByteBuffer and .getFloat() hopefully works the same way
    static int[] getArray_non_fb (float[] mask, int[] originalBuffer, int len) {
        int[] array = new int[len];

        for (int i = 0; i < len ; i++) {
            int tmp_orig = originalBuffer[i];  // unconditional read
            array[i] = (mask[i] < 0.1f) ? tmp_orig : 0;  // unconditional write
            // bytebuffer.getFloat() instead of mask[i] hopefully compiles the same
        }
        return array;
    }
}
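For reference, the ByteBuffer version from the question would look like this with the same transformation (untested sketch; it assumes the buffer's position starts at 0 and that originalBuffer holds at least width * height elements):

    private int[] getArray(ByteBuffer byteBuffer) {
        int len = width * height;
        int[] array = new int[len];

        FloatBuffer floatBuffer = byteBuffer.asFloatBuffer();
        for (int i = 0; i < len; i++) {
            int tmp_orig = originalBuffer[i];                      // unconditional read
            array[i] = (floatBuffer.get() < 0.1f) ? tmp_orig : 0;  // unconditional write
        }
        return array;
    }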

This compiles branchlessly (Godbolt); note the csel (AArch64 conditional select, like x86 cmov) instead of the b.hi you'd get for the original, and with no fcvt or fsub.

    0x0000109c    cmp w1, w24                 # w1 is the loop counter i
    0x000010a0    b.ge #+0x30 (addr 0x10d0)
    0x000010a4    add w2, w22, #0xc (12)      # re-generate a pointer to mask
    0x000010a8    add w3, w23, #0xc (12)      # same for originalBuffer
    0x000010ac    ldr s1, [x2, x1, lsl #2]    # mask[i] load float
    0x000010b0    ldr w2, [x3, x1, lsl #2]    # int tmp_orig = originalBuffer[i]
    0x000010b4    add w3, w0, #0xc (12)       # generate pointer to output array[]
    0x000010b8    fcmp s1, s0                  # fcmp mask[i], 0.1f
    0x000010bc    csel w2, w2, wzr, lo         # conditionally zero tmp_orig
    0x000010c0    str w2, [x3, x1, lsl #2]     # array[i] = ternary result
    0x000010c4    add w1, w1, #0x1 (1)         # ++i
    0x000010c8    ldr x21, [x21]             # WTF?  Linked list pointer chasing creates a latency bottleneck?
     StackMap[3]   native_pc=0x10cc, dex_pc=0x4, register_mask=0xc00001, stack_mask=0b
    0x000010cc    b #-0x30 (addr 0x109c)

This appears to be the inner loop: it is a loop, and it's the only part of the asm that contains an fcmp. It seems pretty horrible, with the condition at the top and an unconditional branch at the bottom, unlike an idiomatic asm loop, and with redundant integer adds that could have been hoisted out of the loop.

This isn't auto-vectorized. And the ldr x21, [x21] limits the loop to at best 1 iteration per 3 to 4 clocks (load-use latency) if that's really there in the asm that would actually run. No idea what that's for; perhaps so another thread can change a memory location to make this thread fault and have the JVM regain control in this thread? Seems like a very expensive way to do that if an instruction like this appears in all loops.

A core like Cortex-A710 (chipsandcheese) could be doing about 4 elements per clock cycle with ASIMD (like clang's AArch64 loop from the C++ version) instead of 1 every 2 (or 1 per 4 with the ldr x21, [x21] bottleneck), so an 8x speedup over this branchless asm is theoretically possible. Or more with SVE, if Android enables that. But I have no clue how to get a JVM to do better, if this is what would actually run on an Android device.

If branch prediction was a problem, this version might run a few times faster than your original, especially if data was hot in some level of cache.

Or maybe not, on an out-of-order exec CPU that can evaluate the branch condition several iterations ahead of the ldr x21, [x21] dependency chain and recover from a branch miss without losing any progress on that actual bottleneck. (See Avoid stalling pipeline by calculating conditional early.)

Again, IDK exactly what that load instruction is doing, but it's a disaster for performance on high-end AArch64 CPUs with wide pipelines (many instructions per clock) and load-use latency higher than 2.

It's also disappointing (but not unexpected) that the code didn't auto-vectorize with ASIMD instructions the way C++ compilers do (see the first paragraph). Perhaps there's another JIT optimization pass and this asm isn't what actually executes?

These source transformations make things as easy as possible for a compiler to optimize, but you still need the compiler to do its share of the work. Or you need to manually vectorize with SIMD extensions, if any are available in Java for Android.
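For illustration only: on a desktop JDK, the incubator Vector API (jdk.incubator.vector, run with --add-modules jdk.incubator.vector) can express this manual vectorization; as far as I know, Android's ART does not ship it, so this is a sketch under that assumption, not something you can run on-device:

    import jdk.incubator.vector.*;

    class VectorizedExample {
        // float and int are both 32-bit, so the preferred species have
        // the same lane count and the compare mask can be cast across types
        static final VectorSpecies<Float> FSP = FloatVector.SPECIES_PREFERRED;
        static final VectorSpecies<Integer> ISP = IntVector.SPECIES_PREFERRED;

        static int[] getArray_vector(float[] mask, int[] originalBuffer, int len) {
            int[] array = new int[len];
            int i = 0;
            for (; i < FSP.loopBound(len); i += FSP.length()) {
                VectorMask<Integer> keep =
                    FloatVector.fromArray(FSP, mask, i).lt(0.1f).cast(ISP);
                IntVector orig = IntVector.fromArray(ISP, originalBuffer, i);
                // keep orig where mask[i] < 0.1f, zero elsewhere
                IntVector.zero(ISP).blend(orig, keep).intoArray(array, i);
            }
            for (; i < len; i++) {  // scalar tail for the leftover elements
                array[i] = (mask[i] < 0.1f) ? originalBuffer[i] : 0;
            }
            return array;
        }
    }

Whether this actually beats the scalar loop still depends on the JIT, but it makes the compare-and-blend structure explicit instead of hoping for auto-vectorization.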