Pass function as static class for fast numerics in Java


I want to do some numerical computations in Java, and to keep the operations really modular I want to pass functions as parameters to other functions. From what I found, this is normally done in Java with a class that wraps the function. I don't really need to instantiate these classes (there is no data inside them), and I want the code to be as fast as possible (I read somewhere that final static methods are inlined by the JIT compiler). So I wrote something like this:

public static class Function2 {
  public static float eval(float a, float b){ return Float.NaN; }  
}

public static class FAdd extends Function2 {
  public static float eval(float a, float b){ return a+b; }  
}

public static class Fmult extends Function2 {
  public static float eval(float a, float b){ return a*b; }  
}

void arrayOp( float [] a, float [] b, float [] out, Function2 func ){
  for (int i=0; i<a.length; i++){     out[i] = func.eval( a[i], b[i] );   }
}

float [] a,b, out;

void setup(){
  println( FAdd.eval(10,20) );
  arrayOp( a,b, out, FAdd );   // compile error: "Cannot find anything like FAdd"
}

However, when I try to pass FAdd to arrayOp, I get the error "Cannot find anything like FAdd", even though println( FAdd.eval(10,20) ) works fine. So it seems that, for some reason, it is simply impossible to pass a static class as a parameter.

What would you recommend for this task? I actually want FAdd to be something like a macro, and arrayOp to be polymorphic (behaving differently depending on which macro I pass in). Ideally this would be resolved at compile time rather than at runtime, to improve numerical speed. The compiled result should be the same as if I had written:

void arrayAdd( float [] a, float [] b, float [] out ){
  for (int i=0; i<a.length; i++){     out[i] = a[i]  + b[i];    }
}
void arrayMult( float [] a, float [] b, float [] out ){
  for (int i=0; i<a.length; i++){     out[i] = a[i] * b[i];   }
} 

Have you considered using enums?

import java.util.Arrays;

private void test() {
  test(3.0f, 4.0f, F.Add);
  test(3.0f, 4.0f, F.Sub);
  test(3.0f, 4.0f, F.Mul);
  test(3.0f, 4.0f, F.Div);
  float[] a = {1f, 2f, 3f, 4f, 5f};
  float[] b = {4f, 9f, 16f, 25f, 36f};
  test(a, b, F.Add);
  test(a, b, F.Sub);
  test(a, b, F.Mul);
  test(a, b, F.Div);
}

private void test(float[] a, float[] b, F f) {
  System.out.println(Arrays.toString(a) + " " + f + " " + Arrays.toString(b) + " = " + Arrays.toString(F.f(a, b, f)));
}

private void test(float a, float b, F f) {
  System.out.println(a + " " + f + " " + b + " = " + f.f(a, b));
}

public enum F {
  Add {
    @Override
    public float f(float x, float y) {
      return x + y;
    }

    @Override
    public String toString() {
      return "+";
    }
  },
  Sub {
    @Override
    public float f(float x, float y) {
      return x - y;
    }

    @Override
    public String toString() {
      return "-";
    }
  },
  Mul {
    @Override
    public float f(float x, float y) {
      return x * y;
    }

    @Override
    public String toString() {
      return "*";
    }
  },
  Div {
    @Override
    public float f(float x, float y) {
      return x / y;
    }

    @Override
    public String toString() {
      return "/";
    }
  };

  // Evaluate to a new array.
  static float[] f(float[] x, float[] y, F f) {
    float[] c = new float[x.length];
    for (int i = 0; i < x.length; i++) {
      c[i] = f.f(x[i], y[i]);
    }
    return c;
  }

  // All must have an f(x,y) method.
  public abstract float f(float x, float y);

  // Also offer a toString - defaults to the enum name.  
  @Override
  public String toString() {
    return this.name();
  }
}

Prints:

3.0 + 4.0 = 7.0
3.0 - 4.0 = -1.0
3.0 * 4.0 = 12.0
3.0 / 4.0 = 0.75
[1.0, 2.0, 3.0, 4.0, 5.0] + [4.0, 9.0, 16.0, 25.0, 36.0] = [5.0, 11.0, 19.0, 29.0, 41.0]
[1.0, 2.0, 3.0, 4.0, 5.0] - [4.0, 9.0, 16.0, 25.0, 36.0] = [-3.0, -7.0, -13.0, -21.0, -31.0]
[1.0, 2.0, 3.0, 4.0, 5.0] * [4.0, 9.0, 16.0, 25.0, 36.0] = [4.0, 18.0, 48.0, 100.0, 180.0]
[1.0, 2.0, 3.0, 4.0, 5.0] / [4.0, 9.0, 16.0, 25.0, 36.0] = [0.25, 0.22222222, 0.1875, 0.16, 0.1388889]

What you want to achieve is really the functionality of an anonymous function or lambda expression, which is specified in JSR 335 (Lambda Expressions for the Java Programming Language) and was added in Java 8. Before Java 8, the closest equivalent is an anonymous inner class. The Stack Overflow question "What's the nearest substitute for a function pointer in Java?" may also help you.
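For reference, here is a minimal sketch of what this looks like with Java 8 lambdas. FloatBinaryOp and LambdaOps are hypothetical names assumed for this example (java.util.function has no float-specialized binary operator, so a small interface is declared); Float::sum is a standard library method.

```java
import java.util.Arrays;

public class LambdaOps {

    // Hypothetical functional interface: one abstract method, (float, float) -> float.
    @FunctionalInterface
    interface FloatBinaryOp {
        float apply(float x, float y);
    }

    // Applies op element-wise: out[i] = op(a[i], b[i]).
    static void arrayOp(float[] a, float[] b, float[] out, FloatBinaryOp op) {
        for (int i = 0; i < a.length; i++) {
            out[i] = op.apply(a[i], b[i]);
        }
    }

    public static void main(String[] args) {
        float[] a = {1f, 2f, 3f};
        float[] b = {4f, 5f, 6f};
        float[] out = new float[a.length];

        arrayOp(a, b, out, Float::sum);         // method reference
        System.out.println(Arrays.toString(out)); // [5.0, 7.0, 9.0]

        arrayOp(a, b, out, (x, y) -> x * y);    // lambda expression
        System.out.println(Arrays.toString(out)); // [4.0, 10.0, 18.0]
    }
}
```

Any lambda or method reference with a matching (float, float) -> float shape can be passed. Whether the JIT inlines the call inside the loop depends on the JVM and on how many different operations reach that call site.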


You are actually mixing up instances and classes in your implementation. When you have a method declared like this:

void arrayOp( float [] a, float [] b, float [] out, Function2 func ){
   for (int i=0; i<a.length; i++){     out[i] = func.eval( a[i], b[i] );   }
}

You are basically saying that you expect an instance of class Function2, not a class. Also, this statement is not valid Java, because FAdd is the name of a class, not a value:

arrayOp( a,b, out, FAdd );

So let's say you want to pass the class itself to a method; then your declaration of arrayOp would look something like:

void arrayOp( float [] a, float [] b, float [] out, Class<? extends Function2> func ){

And when you call this method, you pass the parameter like this:

arrayOp( a,b, out, FAdd.class );

But static methods cannot be overridden via inheritance (a subclass's static method only hides the parent's), so you need a completely different approach to achieve your goal. That said, @OldCurmudgeon has presented a really nice solution to your problem; consider using that.
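To make the Class-parameter idea concrete, here is a sketch that calls the static eval through reflection. This is only an illustration, not something the answer above recommends: reflection boxes every argument, bypasses compile-time checks, and is much slower than a direct call. ClassParamOps is a hypothetical name, and FAdd/FMult here are stand-ins with public static eval methods that do not extend Function2.

```java
import java.lang.reflect.Method;
import java.util.Arrays;

public class ClassParamOps {

    // Stand-ins for the question's classes: only a static eval, no inheritance.
    public static class FAdd {
        public static float eval(float a, float b) { return a + b; }
    }

    public static class FMult {
        public static float eval(float a, float b) { return a * b; }
    }

    // Receives the Class object (e.g. FAdd.class) and looks up its static eval.
    static void arrayOp(float[] a, float[] b, float[] out, Class<?> func)
            throws ReflectiveOperationException {
        Method eval = func.getMethod("eval", float.class, float.class);
        for (int i = 0; i < a.length; i++) {
            // null receiver because eval is static; the boxed Float result is unboxed
            out[i] = (Float) eval.invoke(null, a[i], b[i]);
        }
    }

    public static void main(String[] args) throws ReflectiveOperationException {
        float[] a = {1f, 2f, 3f};
        float[] b = {4f, 5f, 6f};
        float[] out = new float[a.length];
        arrayOp(a, b, out, FAdd.class);
        System.out.println(Arrays.toString(out)); // [5.0, 7.0, 9.0]
    }
}
```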


You're making a big assumption that the fastest code can only come from a final static method. You are most likely wrong, and you should focus on structuring the code properly and then measuring its performance.

One approach is to use enums, as described above. What I'd suggest is an interface with the eval function; you can then pass in an implementation of that interface.

The JVM will optimize that code appropriately; for example, HotSpot's JIT typically inlines a call site that only ever sees one implementation.
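A minimal sketch of the interface approach, assuming Function2 is turned into an interface and the operations into small stateless classes (InterfaceOps, ADD and MULT are hypothetical names). Since the classes hold no data, one shared instance of each is enough, which addresses the concern about instantiating them.

```java
import java.util.Arrays;

public class InterfaceOps {

    // Any operation that can combine two floats.
    interface Function2 {
        float eval(float a, float b);
    }

    // Stateless implementations; final so they cannot be subclassed further.
    static final class FAdd implements Function2 {
        public float eval(float a, float b) { return a + b; }
    }

    static final class FMult implements Function2 {
        public float eval(float a, float b) { return a * b; }
    }

    // One shared instance per operation is enough because they carry no state.
    static final Function2 ADD = new FAdd();
    static final Function2 MULT = new FMult();

    static void arrayOp(float[] a, float[] b, float[] out, Function2 func) {
        for (int i = 0; i < a.length; i++) {
            out[i] = func.eval(a[i], b[i]);
        }
    }

    public static void main(String[] args) {
        float[] a = {1f, 2f, 3f};
        float[] b = {4f, 5f, 6f};
        float[] out = new float[a.length];
        arrayOp(a, b, out, MULT);
        System.out.println(Arrays.toString(out)); // [4.0, 10.0, 18.0]
    }
}
```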


Static methods cannot be overridden, but you can achieve this with an anonymous class:

public static class Function2 {
    public float eval(float a, float b){ return Float.NaN; }  
}

arrayOp(a, b, out, new Function2() {
    public float eval(float a, float b){
        return FAdd.eval(a, b);
    }});

Note that eval() is declared non-static in Function2.


I ran some tests, and it seems that there is really no need to try to optimize this on modern machines.

Machine 1 (my older home computer): 32-bit Windows XP, Intel Pentium 3 (I'm not sure about the Java version). For both float multiplication and float addition, the static version is more than 2x faster:

static  100000000 [ops]  406.0 [ms]  4.06 [ns/op]
dynamic 100000000 [ops]  1188.0 [ms]  11.88 [ns/op]

but for float sqrt the difference is already very small:

static  100000000 [ops]  922.0 [ms]  9.22 [ns/op]
dynamic 100000000 [ops]  1172.0 [ms]  11.72 [ns/op]

Machine 2 (my computer at work): 64-bit Ubuntu 12.04 LTS, Intel Core5, Java version "1.6.0_12-ea" (Java(TM) SE Runtime Environment build 1.6.0_12-ea-b02, Java HotSpot(TM) 64-Bit Server VM build 11.2-b01, mixed mode). The results are much better (for float addition):

static  1000000000 [ops]  1747.0 [ms]  1.747 [ns/op]
dynamic 1000000000 [ops]  1750.0 [ms]  1.75 [ns/op]

So I think the processor or the JIT is already clever enough that there is no need to optimize this kind of function passing.

NOTE: "static" means the solution without function passing (I just inline the operation manually into the loop); "dynamic" means the solution where I pass the function as an object instance (not a static class). It seems that the JIT recognizes that there is no dynamic data inside the class and resolves the call during JIT compilation anyway.

So my dynamic solution is simply:

public class Function2 {
  public float eval(float a, float b){ return Float.NaN; }  
}

public class FAdd extends Function2 {
  public float eval(float a, float b){ return a+b; }
}

public class FMult extends Function2 {
  public float eval(float a, float b){ return a*b; }  
}

public void arrayOp( float [] a, float [] b, float [] out, Function2 func ){
  for (int i=0; i<a.length; i++){     out[i] = func.eval( a[i], b[i] );   }
}

final int m = 100;
final int n = 10000000;
float t1,t2;
float [] a,b, out;
a = new float[n];   b = new float[n];   out = new float[n];
t1 = millis();
Function2 func = new FMult(); 
for (int i=0;i<m;i++) arrayOp( a,b, out, func );
t2 = millis();
println( " dynamic " +(n*m)+" [ops]  "+(t2-t1)+" [ms]  "+ 1000000*((t2-t1)/(n*m))+" [ns/op] " );
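The timing code above relies on Processing's millis() and top-level statements. A rough plain-Java port might look like the sketch below, assuming System.nanoTime() as the clock and comparing the hand-inlined loop ("static") with the passed-object loop ("dynamic"). Bench is a hypothetical class name, and n and m are chosen here so that n*m is still 10^9 operations while the arrays stay small. Treat the numbers as indicative only: there is no separate warm-up phase, and a dedicated harness such as JMH gives more trustworthy results.

```java
import java.util.Arrays;

public class Bench {

    static abstract class Function2 {
        abstract float eval(float a, float b);
    }

    static final class FMult extends Function2 {
        float eval(float a, float b) { return a * b; }
    }

    // "dynamic": the operation is passed in as an object.
    static void arrayOp(float[] a, float[] b, float[] out, Function2 func) {
        for (int i = 0; i < a.length; i++) out[i] = func.eval(a[i], b[i]);
    }

    // "static": the operation is written directly into the loop.
    static void arrayMult(float[] a, float[] b, float[] out) {
        for (int i = 0; i < a.length; i++) out[i] = a[i] * b[i];
    }

    public static void main(String[] args) {
        final int m = 1000;
        final int n = 1_000_000;
        float[] a = new float[n], b = new float[n], out = new float[n];
        Arrays.fill(a, 1.5f);
        Arrays.fill(b, 2.0f);
        Function2 func = new FMult();

        long t0 = System.nanoTime();
        for (int i = 0; i < m; i++) arrayMult(a, b, out);
        long t1 = System.nanoTime();
        for (int i = 0; i < m; i++) arrayOp(a, b, out, func);
        long t2 = System.nanoTime();

        double ops = (double) n * m;
        // nanoTime differences are in ns: divide by 1e6 for ms, by ops for ns/op
        System.out.printf("static  %.0f [ops]  %.1f [ms]  %.3f [ns/op]%n", ops, (t1 - t0) / 1e6, (t1 - t0) / ops);
        System.out.printf("dynamic %.0f [ops]  %.1f [ms]  %.3f [ns/op]%n", ops, (t2 - t1) / 1e6, (t2 - t1) / ops);
    }
}
```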