Fast single dispatch to get around multiple dispatch at runtime


When type inference falters (::Any in @code_warntype printout), my understanding is that function calls are dynamically dispatched. In other words, at run-time, the arguments' types are checked to find the specialization (MethodInstance) for the concrete argument types. Needing to do this at run-time instead of compile-time incurs performance costs.

(EDIT: originally, I said "multiple dispatch finds the fitting method" between the type-checking and specialization-finding, but I don't actually know if this part happens at runtime. It seems that it only needs to happen if no valid specialization exists and one needs to be compiled.)
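Here is a minimal sketch of the situation described above. The names `g` and `call_all` are hypothetical and do not come from the question. The vector's element type is `Any`, so `x` is inferred as `Any` inside the loop, and each `g(x)` call has to be resolved at run time from the concrete type of `x`. The standard reflection function `which` reports the method that dispatch selects for a given tuple of argument types. The sketch does not show the run-time lookup machinery itself.

```julia
# Two methods of a hypothetical function `g`, selected by the argument's type.
g(::Integer) = "Integer method"
g(::AbstractFloat) = "AbstractFloat method"

# Element type is Any: the compiler cannot know the type of each element.
xs = Any[1, 2.0, 3]

function call_all(xs)
    out = String[]
    for x in xs
        # x is inferred as Any here (as @code_warntype would report), so this
        # call is resolved at run time from the concrete type typeof(x).
        push!(out, g(x))
    end
    return out
end

results = call_all(xs)

# `which` returns the Method that dispatch picks for these argument types.
m = which(g, (typeof(xs[2]),))   # the g(::AbstractFloat) method
```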

In cases where only one argument's concrete type needs to be checked, is it possible to do a faster dynamic single dispatch instead, like in some sort of lookup table of specializations? I just can't find a way to access and call MethodInstances as if they were functions.
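The "lookup table" idea can be sketched by hand. This example assumes a hypothetical `Dict` named `BAZ_TABLE` that is keyed only on the first argument's concrete type, plus a hypothetical wrapper `baz_table`. It is not a way to call Julia's internal `MethodInstance`s directly, and it is not guaranteed to be faster. The values have the abstract type `Function`, so calling the retrieved entry is itself a dynamic call.

```julia
struct Foo end
struct Bar end

# Hypothetical hand-written "single dispatch" table: one entry per concrete
# type of the first argument, each a closure over the second argument.
# (In this sketch any non-Integer second argument gets the float result.)
const BAZ_TABLE = Dict{DataType,Function}(
    Foo => n -> n isa Integer ? 1 : 1.0,
    Bar => n -> n isa Integer ? 1im : 1.0im,
)

# Look up the entry using only typeof(first), then call it on `second`.
# Caveat: the retrieved value is only known to be a Function, so this
# call is still dynamically dispatched.
baz_table(first, second) = BAZ_TABLE[typeof(first)](second)
```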

When it comes to altering dispatch or specialization, I thought of invoke and @nospecialize. invoke looks like it might skip right to a specified method, but checking multiple argument types and specialization must still happen. @nospecialize doesn't skip any part of the dispatch process, just results in different specializations.
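For reference, `invoke` takes the function, a `Tuple` type naming the method signature to use, and then the arguments. The small sketch below adds a third, less specific `baz` method only so there is something to choose between. It shows method selection only. It says nothing about whether `invoke` reduces run-time cost, and the arguments must still be instances of the given signature's types.

```julia
struct Foo end

baz(::Foo, ::Integer) = 1
baz(::Foo, ::AbstractFloat) = 1.0
# Hypothetical extra method, added only for this illustration.
baz(::Foo, ::Real) = "Real method"

# Ordinary call: multiple dispatch picks the most specific match, baz(::Foo, ::Integer).
a = baz(Foo(), 1)

# invoke: use the method matching Tuple{Foo, Real}, even though a more
# specific method exists for the actual argument types (Foo, Int).
b = invoke(baz, Tuple{Foo, Real}, Foo(), 1)
```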

EDIT: A minimal example with comments that hopefully describe what I'm talking about.

struct Foo end
struct Bar end

#   want to dispatch only on 1st argument
#          still want to specialize on 2nd argument
baz(::Foo, ::Integer) = 1
baz(::Foo, ::AbstractFloat) = 1.0
baz(::Bar, ::Integer) = 1im
baz(::Bar, ::AbstractFloat) = 1.0im

x = Any[Foo(), Bar(), Foo()]

# run test1(x, 1) or test1(x, 1.0)
function test1(x, second)
  #   first::Any in @code_warntype printout
  for first in x
    # first::Any requires dynamic dispatch of baz
    println(baz(first, second))
    # Is it possible to only dispatch -baz- on -first- given
    # the concrete types of the other arguments -second-?
  end
end

Answer by cbk:

The easiest way to do what you ask is simply not to dispatch on the second argument (by not giving it a type annotation specific enough to affect dispatch), and instead to handle its possible types with an if statement inside the function. For example:

struct Foo end
struct Bar end

# Note the lack of a type annotation on the second argument.
# We could also write `baz(::Foo, n::Number)` for the same effect in this case,
# but type annotations give no performance benefit in Julia if you're not
# dispatching on them anyway.
function baz(::Foo, n) 
    if isa(n, Integer)
        1
    elseif isa(n, AbstractFloat)
        1.0
    else
        error("unsupported type")
    end
end

function baz(::Bar, n)
    if isa(n, Integer)
        1im
    elseif isa(n, AbstractFloat)
        1.0im
    else
        error("unsupported type")
    end
end
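The `isa` checks in these methods should not cost anything at run time. Julia still compiles a separate specialization of `baz(::Foo, n)` for each concrete type of `n`, and in each one the check is a compile-time constant. The sketch below checks this with `Base.return_types`, an unexported reflection utility. Its exact output could in principle vary between Julia versions. If the branches are folded away, each specialization should infer a single concrete return type.

```julia
struct Foo end

function baz(::Foo, n)
    if isa(n, Integer)
        1
    elseif isa(n, AbstractFloat)
        1.0
    else
        error("unsupported type")
    end
end

# Run type inference for specific argument types. With n::Int only the
# first branch is reachable; with n::Float64 only the second.
rt_int = Base.return_types(baz, (Foo, Int))
rt_float = Base.return_types(baz, (Foo, Float64))
```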

Now this does what you want:

julia> x = Any[Foo(), Bar(), Foo()]
3-element Vector{Any}:
 Foo()
 Bar()
 Foo()

julia> test1(x, 1)
1
0 + 1im
1

julia> test1(x, 1.0)
1.0
0.0 + 1.0im
1.0

Since this effectively picks out by hand just two cases to handle out of all the types that could be specialized on, I could imagine scenarios where this technique has performance benefits (though, as usual in Julia, it would generally be even better to find and eliminate the source of the type instability in the first place, if at all possible).

However, in the context of the question as written, it is critically important to point out that even though we have eliminated dispatch on the second argument, these baz functions may still perform poorly if the first argument (i.e., the one you are dispatching on) is type-unstable. That is the case in the question as written, because x is an Array{Any}.

Instead, try to use an array with at least some constraint on its element type. For example:

julia> function test2(x, second)
           s = 1+1im
           for first in x
               s += baz(first, second)
           end
           s
       end
test2 (generic function with 1 method)

julia> using BenchmarkTools

julia> x = Any[Foo(), Bar(), Foo()];

julia> @benchmark test2($x, 1)
BenchmarkTools.Trial: 10000 samples with 998 evaluations.
 Range (min … max):  13.845 ns … 71.554 ns  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     13.869 ns              ┊ GC (median):    0.00%
 Time  (mean ± σ):   15.397 ns ±  3.821 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

  █▅  ▃ ▄  ▄      ▄       ▄                                 ▃ ▁
  ██▇▆█▇██▄█▇▇▄▃▁▁██▁▃▃▁▁▃██▃▁▃▁▁▄▃▃▃▆▆▅▆▆▅▅▄▁▁▄▃▃▃▁▃▁▄▁▁▃▄▄█ █
  13.8 ns      Histogram: log(frequency) by time      30.2 ns <

 Memory estimate: 0 bytes, allocs estimate: 0.

julia> x = Union{Foo,Bar}[Foo(), Bar(), Foo()];

julia> @benchmark test2($x, 1)
BenchmarkTools.Trial: 10000 samples with 1000 evaluations.
 Range (min … max):  4.654 ns … 62.311 ns  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     4.707 ns              ┊ GC (median):    0.00%
 Time  (mean ± σ):   5.471 ns ±  1.714 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

  █▂▂▃▄ ▃  ▄▁    ▄▂      ▅▁                               ▁▄ ▁
  ███████▁▁██▁▁▁▁██▁▁▁▁▁▁██▁▁▁▄▁▃▁▃▁▁▁▁▃▁▁▁▁▃▁▃▃▁▁▁▁▃▁▁▁▁▁██ █
  4.65 ns      Histogram: log(frequency) by time     10.2 ns <

 Memory estimate: 0 bytes, allocs estimate: 0.
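The gap between these two benchmarks is consistent with the compiler's union splitting. When the element type is a small `Union` such as `Union{Foo,Bar}`, the compiler can emit one type check per member and a statically resolved call in each branch. Below is a sketch of the same idea written by hand for an `Any` array. `test3` is a hypothetical name, and the code uses the question's original four `baz` methods; the answer's `if`-based versions should behave the same way. Whether this helps in a real case should be measured, for example with `@benchmark` as above.

```julia
struct Foo end
struct Bar end

baz(::Foo, ::Integer) = 1
baz(::Foo, ::AbstractFloat) = 1.0
baz(::Bar, ::Integer) = 1im
baz(::Bar, ::AbstractFloat) = 1.0im

function test3(x, second)
    s = 1 + 1im
    for first in x
        # Manual union splitting: inside each branch the compiler knows the
        # concrete type of `first`, so the baz call can be resolved statically.
        if first isa Foo
            s += baz(first, second)
        elseif first isa Bar
            s += baz(first, second)
        else
            # Any other type falls back to ordinary dynamic dispatch.
            s += baz(first, second)
        end
    end
    return s
end
```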