Parallel simd MUCH slower than serial simd in Julia


Summary: Scroll down for a reproducible example that should run from scratch in Julia if you have the packages named in the using lines. (Note: the ODE has a complex, reusable structure, specified in a Gist that the script downloads and includes.)

Background: I have to repeatedly solve a large system of ODEs for different initial-condition vectors. The example below has 127 states/ODEs, but it could easily be 1000-2000. I will have to run these solves hundreds to thousands of times for inference, so speed is essential.

The Puzzle: The short version: for the serial functions, the @simd version is much faster than the plain, non-@simd version. But for the parallel versions, the @simd version is much slower -- and, worse, its answer, sum_of_solutions, is variable and wrong.

I start Julia with JULIA_NUM_THREADS=auto julia, which on my machine gives 8 threads for 8 CPU threads. I then make sure I never have more than 8 jobs spawned at once.
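It may also be worth double-checking thread counts inside the session (BLAS runs its own thread pool, which can compete for cores; this is a sanity check, not something the numbers below depend on):

using LinearAlgebra
println("Julia threads: ", Base.Threads.nthreads())               # 8 here
println("BLAS threads:  ", LinearAlgebra.BLAS.get_num_threads())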

The different calculation times (each tuple below is (runtime in seconds, sum_of_solutions)):

# Output is (runtime, sum_of_solutions)
serial_with_plain_v7(tspan, p_Ds_v7, solve_results1; number_of_solves=number_of_solves)
serial_with_plain_v7(tspan, p_Ds_v7, solve_results1; number_of_solves=number_of_solves)
serial_with_plain_v7(tspan, p_Ds_v7, solve_results1; number_of_solves=number_of_solves)
# (duration, sum_of_solutions)
# (1.1, 8.731365050398926)
# (0.878, 8.731365050398926)
# (0.898, 8.731365050398926)

serial_with_simd_v7(tspan, p_Ds_v7, solve_results1; number_of_solves=number_of_solves)
serial_with_simd_v7(tspan, p_Ds_v7, solve_results1; number_of_solves=number_of_solves)
serial_with_simd_v7(tspan, p_Ds_v7, solve_results1; number_of_solves=number_of_solves)
# (duration, sum_of_solutions)
# (0.046, 8.731365050398928)
# (0.042, 8.731365050398928)
# (0.046, 8.731365050398928)

parallel_with_plain_v5(tspan, p_Ds_v7, solve_results2; number_of_solves=number_of_solves)
# Faster than serial plain version
# (duration, sum_of_solutions)
# (0.351, 8.731365050398926)
# (0.343, 8.731365050398926)
# (0.366, 8.731365050398926)

parallel_with_simd_v7(tspan, p_Ds_v7, solve_results2; number_of_solves=number_of_solves)
# Dramatically slower than serial simd version, plus wrong sum_of_solutions
# (duration, sum_of_solutions)
# (136.966, 9.61313614002137)
# (141.843, 9.616688089683372)

As you can see, serial @simd gets the calculation down to 0.046 seconds, and parallel plain is about 2.5x faster than serial plain, but combining parallelization with the @simd function gives runtimes of ~140 seconds, with variable and wrong answers to boot! Literally the only difference between the two parallelizing functions is that one calls core_op_plain and the other calls core_op_simd for the core ODE-solving operation.

It seems like @simd and @spawn must be conflicting somehow? The parallelizing function is set up to never employ more than the number of CPU threads available (8 on my machine).

I am still learning Julia, so there is a chance that some small change could isolate the @simd calculations and prevent conflicts across threads (if that is what is happening). Any help is very much appreciated!
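One hypothesis: if the @simd ODE function writes into scratch arrays stored inside p_Ds_v7 (I have not verified this; the function body is in the Gist), then every spawned task shares those buffers, which could explain both the slowdown (threads contending for the same memory) and the variable, wrong sums (data races). A quick diagnostic, using names from the reproducible example below, would be to hand each task its own deep copy of the parameter object:

# Hypothetical diagnostic: give each spawned task a private copy of the
# parameter object so no two threads share mutable scratch space.
# (deepcopy is expensive; this is a test for races, not a final fix.)
p_private = deepcopy(p_Ds_v7)
tsk = Base.Threads.@spawn core_op_simd(u .+ 0.0, tspan, p_private)
sol = fetch(tsk)

If per-task copies make the parallel @simd runs fast and deterministic again, shared mutable state would be the culprit.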

PS: Reproducible example. The code below should run from scratch in any Julia session with multiple threads. My versioninfo() etc. follows:

versioninfo()
notes="""
My setup is:
Julia Version 1.7.3
Commit 742b9abb4d (2022-05-06 12:58 UTC)
Platform Info:
  OS: macOS (x86_64-apple-darwin21.4.0)
  CPU: Intel(R) Xeon(R) CPU E5-2697 v2 @ 2.70GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-12.0.1 (ORCJIT, ivybridge)
"""


# Startup notes
notes="""
# "If $JULIA_NUM_THREADS is set to auto, then the number of threads will be set to the number of CPU threads."
JULIA_NUM_THREADS=auto julia --startup-file=no
Threads.nthreads(): 8 # Number of CPU threads
"""


using LinearAlgebra         # for "I" in: Matrix{Float64}(I, 2, 2)
                            # https://www.reddit.com/r/Julia/comments/9cfosj/identity_matrix_in_julia_v10/
using Sundials              # for CVODE_BDF
using Statistics            # for mean()
using DataFrames            # for e.g. DataFrame()
using Dates                 # for e.g. DateTime, Dates.now()
using DifferentialEquations # for ODEProblem
using BenchmarkTools        # for @benchmark
using Distributed           # for workers


# Check that you have multiple threads
numthreads = Base.Threads.nthreads()

# Download & include the pre-saved model structure/rates (all precalculated for speed; 1.8 MB)
#include("/GitHub/BioGeoJulia.jl/test/model_p_object.jl")
url = "https://gist.githubusercontent.com/nmatzke/ed99ab8f5047794eb25e1fdbd5c43b37/raw/b3e6ddff784bd3521d089642092ba1e3830699c0/model_p_object.jl"
download(url,  "model_p_object.jl")
include("model_p_object.jl")

# Load the ODE functions
url = "https://gist.githubusercontent.com/nmatzke/f116258c78bd43ab7a448f07c4290516/raw/24a210261fd2e090b8ed27bc64a59a1ff9ec62cd/simd_vs_spawn_setup_v2.jl"
download(url,  "simd_vs_spawn_setup_v2.jl")
include("simd_vs_spawn_setup_v2.jl")

#include("/GitHub/BioGeoJulia.jl/test/simd_vs_spawn_setup_v2.jl")
#include("/GitHub/BioGeoJulia.jl/test/simd_vs_spawn_setup_v3.jl")

# Load the pre-saved model structure/rates (all precalculated for speed; 1.8 MB)
p_Es_v5 = load_ps_127();



# Set up output object
numstates = 127
number_of_solves = 10

solve_results1 = Array{Float64, 2}(undef, number_of_solves, numstates)
solve_results1 .= 0.0
solve_results2 = Array{Float64, 2}(undef, number_of_solves, numstates)
solve_results2 .= 0.0
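(The undef-allocate-then-zero pattern above can also be written in one step, with identical contents:)

solve_results1 = zeros(number_of_solves, numstates)
solve_results2 = zeros(number_of_solves, numstates)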
length(solve_results1)      # 1270 elements (10 x 127)
size(solve_results1)        # (10, 127)
sum(sum.(solve_results1))   # 0.0 -- everything starts zeroed


# Precalculate the Es for use in the Ds
Es_tspan = (0.0, 60.0)
prob_Es_v7 = DifferentialEquations.ODEProblem(Es_v7_simd_sums, p_Es_v5.uE, Es_tspan, p_Es_v5);
sol_Es_v7 = solve(prob_Es_v7, CVODE_BDF(linear_solver=:GMRES), save_everystep=true, 
abstol=1e-12, reltol=1e-9);

p_Ds_v7 = (n=p_Es_v5.n, params=p_Es_v5.params, p_indices=p_Es_v5.p_indices, p_TFs=p_Es_v5.p_TFs, uE=p_Es_v5.uE, terms=p_Es_v5.terms, sol_Es_v5=sol_Es_v7);


# Set up ODE inputs
u = zeros(numstates)    # initial condition: all states 0.0...
u[2] = 1.0              # ...except state 2
du = similar(u)
du .= 0.0
p = p_Ds_v7;
t = 1.0

# ODE functions to integrate (single-step; ODE solvers will run this many many times)
@time Ds_v5_tmp(du,u,p,t)
@time Ds_v5_tmp(du,u,p,t)
@time Ds_v7_simd_sums(du,u,p,t)
@time Ds_v7_simd_sums(du,u,p,t)

#@btime Ds_v5_tmp(du,u,p,t)
# 7.819 ms (15847 allocations: 1.09 MiB)

#@btime Ds_v7_simd_sums(du,u,p,t)
# 155.858 μs (3075 allocations: 68.66 KiB)
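(Per right-hand-side call, then, the @simd version is roughly 50x faster and allocates roughly 16x less memory than the plain version.)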



tspan = (0.0, 1.0)
prob_Ds_v7 = DifferentialEquations.ODEProblem(Ds_v7_simd_sums, p_Ds_v7.uE, tspan, p_Ds_v7);

sol_Ds_v7 = solve(prob_Ds_v7, CVODE_BDF(linear_solver=:GMRES), save_everystep=false, abstol=1e-12, reltol=1e-9);

# This is the core operation; plain version (no @simd)
function core_op_plain(u, tspan, p_Ds_v7)
    prob_Ds_v5 = DifferentialEquations.ODEProblem(Ds_v5_tmp, u .+ 0.0, tspan, p_Ds_v7);  # u .+ 0.0 copies u

    sol_Ds_v5 = solve(prob_Ds_v5, CVODE_BDF(linear_solver=:GMRES), save_everystep=false, abstol=1e-12, reltol=1e-9);
    return sol_Ds_v5
end


# This is the core operation; @simd version
function core_op_simd(u, tspan, p_Ds_v7)
    prob_Ds_v7 = DifferentialEquations.ODEProblem(Ds_v7_simd_sums, u .+ 0.0, tspan, p_Ds_v7);  # u .+ 0.0 copies u

    sol_Ds_v7 = solve(prob_Ds_v7, CVODE_BDF(linear_solver=:GMRES), save_everystep=false, abstol=1e-12, reltol=1e-9);
    return sol_Ds_v7
end
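Aside: both core ops build a fresh ODEProblem on every call. I have not benchmarked it, but SciML's remake is the usual way to swap in a new initial condition without reconstructing the problem (core_op_simd_remake is a hypothetical name):

# Sketch (untested here): reuse one ODEProblem template via remake()
prob_template = DifferentialEquations.ODEProblem(Ds_v7_simd_sums, u .+ 0.0, tspan, p_Ds_v7)

function core_op_simd_remake(u, tspan, p_Ds_v7)
    prob = DifferentialEquations.remake(prob_template; u0=u .+ 0.0, tspan=tspan, p=p_Ds_v7)
    return solve(prob, CVODE_BDF(linear_solver=:GMRES), save_everystep=false,
                 abstol=1e-12, reltol=1e-9)
end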

@time core_op_plain(u, tspan, p_Ds_v7);
@time core_op_plain(u, tspan, p_Ds_v7);
@time core_op_simd(u, tspan, p_Ds_v7);
@time core_op_simd(u, tspan, p_Ds_v7);


function serial_with_plain_v7(tspan, p_Ds_v7, solve_results1; number_of_solves=10)
    start_time = Dates.now()
    for i in 1:number_of_solves
        # Reset row i, then set the ith state to 1.0 (the initial condition)
        solve_results1[i,:] .= 0.0
        solve_results1[i,i] = 1.0

        sol_Ds_v7 = core_op_plain(solve_results1[i,:], tspan, p_Ds_v7)
        solve_results1[i,:] .= sol_Ds_v7.u[end]   # final state of the solve
    #   print("\n")
    #   print(round.(sol_Ds_v7[length(sol_Ds_v7)], digits=3))
    end
    
    end_time = Dates.now()
    duration = (end_time - start_time).value / 1000.0
    sum_of_solutions = sum(sum.(solve_results1))
    return (duration, sum_of_solutions)
end


function serial_with_simd_v7(tspan, p_Ds_v7, solve_results1; number_of_solves=10)
    start_time = Dates.now()
    for i in 1:number_of_solves
        # Reset row i, then set the ith state to 1.0 (the initial condition)
        solve_results1[i,:] .= 0.0
        solve_results1[i,i] = 1.0

        sol_Ds_v7 = core_op_simd(solve_results1[i,:], tspan, p_Ds_v7)
        solve_results1[i,:] .= sol_Ds_v7.u[end]   # final state of the solve
    #   print("\n")
    #   print(round.(sol_Ds_v7[length(sol_Ds_v7)], digits=3))
    end
    
    end_time = Dates.now()
    duration = (end_time - start_time).value / 1000.0
    sum_of_solutions = sum(sum.(solve_results1))
    return (duration, sum_of_solutions)
end
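Aside: the Dates.now() arithmetic works, but Base's @elapsed macro returns wall-clock seconds directly; an equivalent way to time any of these calls:

# Alternative timing (same idea as the Dates.now() bookkeeping above):
duration = @elapsed serial_with_simd_v7(tspan, p_Ds_v7, solve_results1;
                                        number_of_solves=number_of_solves)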

# Output is (runtime, sum_of_solutions)
serial_with_plain_v7(tspan, p_Ds_v7, solve_results1; number_of_solves=number_of_solves)
serial_with_plain_v7(tspan, p_Ds_v7, solve_results1; number_of_solves=number_of_solves)
serial_with_plain_v7(tspan, p_Ds_v7, solve_results1; number_of_solves=number_of_solves)
# (duration, sum_of_solutions)
# (1.1, 8.731365050398926)
# (0.878, 8.731365050398926)
# (0.898, 8.731365050398926)

serial_with_simd_v7(tspan, p_Ds_v7, solve_results1; number_of_solves=number_of_solves)
serial_with_simd_v7(tspan, p_Ds_v7, solve_results1; number_of_solves=number_of_solves)
serial_with_simd_v7(tspan, p_Ds_v7, solve_results1; number_of_solves=number_of_solves)
# (duration, sum_of_solutions)
# (0.046, 8.731365050398928)
# (0.042, 8.731365050398928)
# (0.046, 8.731365050398928)
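(The difference in the last digit between the plain and @simd sums, ...926 vs ...928, is just floating-point reassociation from @simd and is expected. The parallel @simd discrepancy below, 9.61 vs 8.73, is of a different magnitude entirely.)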

# (Note: the parallel functions below use Base.Threads, not Distributed)

function parallel_with_plain_v5(tspan, p_Ds_v7, solve_results2; number_of_solves=10)
    start_time = Dates.now()
    number_of_threads = Base.Threads.nthreads()
    curr_numthreads = Base.Threads.nthreads()
        
    # Individual ODE solutions will occur over different timeperiods,
    # initial values, and parameters.  We'd just like to load up the 
    # cores for the first jobs in the list, then add jobs as earlier
    # jobs finish.
    tasks = Any[]
    tasks_started_TF = Bool[]
    tasks_fetched_TF = Bool[]
    task_numbers = Any[]
    task_inc = 0
    are_we_done = false
    current_running_tasks = Any[]
    
    # List the tasks
    for i in 1:number_of_solves
        # Temporary u
        solve_results2[i,:] .= 0.0
        
        # Change the ith state from 0.0 to 1.0
        solve_results2[i,i] = 1.0

        task_inc = task_inc + 1
        push!(tasks_started_TF, false) # Add a "false" to tasks_started_TF
        push!(tasks_fetched_TF, false) # Add a "false" to tasks_fetched_TF
        push!(task_numbers, task_inc)
    end
    
    # Total number of tasks
    num_tasks = length(tasks_fetched_TF)

    iteration_number = 0
    while(are_we_done == false)
        iteration_number = iteration_number+1
        
        # Launch tasks when thread (core) is available
        for j in 1:num_tasks
            if (tasks_fetched_TF[j] == false)
                if (tasks_started_TF[j] == false) && (curr_numthreads > 0)
                    # Start a task
                    push!(tasks, Base.Threads.@spawn core_op_plain(solve_results2[j,:], tspan, p_Ds_v7));
                    curr_numthreads = curr_numthreads-1;
                    tasks_started_TF[j] = true;
                    push!(current_running_tasks, task_numbers[j])
                end
            end
        end
        
        # Check for finished tasks
        tasks_to_check_TF = tasks_started_TF .& .!tasks_fetched_TF  # started but not yet fetched
        if any(tasks_to_check_TF)
            for k in 1:sum(tasks_to_check_TF)
                if (tasks_fetched_TF[current_running_tasks[k]] == false)
                    if (istaskstarted(tasks[k]) == true) && (istaskdone(tasks[k]) == true)
                        sol_Ds_v7 = fetch(tasks[k]);
                        solve_results2[current_running_tasks[k],:] .= sol_Ds_v7.u[end]
                        tasks_fetched_TF[current_running_tasks[k]] = true
                        current_tasknum = current_running_tasks[k]
                        deleteat!(tasks, k)
                        deleteat!(current_running_tasks, k)
                        curr_numthreads = curr_numthreads+1;
                        print("\nFinished task #")
                        print(current_tasknum)
                        print(", current task k=")
                        print(k)
                        break # break out of this loop, since you have modified current_running_tasks
                    end
                end
            end
        end

        are_we_done = sum(tasks_fetched_TF) == length(tasks_fetched_TF)
        # Test for concluding the while loop
        are_we_done && break
    end # END while(are_we_done == false)

    end_time = Dates.now()
    duration = (end_time - start_time).value / 1000.0
    sum_of_solutions = sum(sum.(solve_results2))
    print("\n")
    return (duration, sum_of_solutions)
end


function parallel_with_simd_v7(tspan, p_Ds_v7, solve_results2; number_of_solves=10)
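    # Identical to parallel_with_plain_v5, except that the spawned call is core_op_simd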
    start_time = Dates.now()
    number_of_threads = Base.Threads.nthreads()
    curr_numthreads = Base.Threads.nthreads()
        
    # Individual ODE solutions will occur over different timeperiods,
    # initial values, and parameters.  We'd just like to load up the 
    # cores for the first jobs in the list, then add jobs as earlier
    # jobs finish.
    tasks = Any[]
    tasks_started_TF = Bool[]
    tasks_fetched_TF = Bool[]
    task_numbers = Any[]
    task_inc = 0
    are_we_done = false
    current_running_tasks = Any[]
    
    # List the tasks
    for i in 1:number_of_solves
        # Temporary u
        solve_results2[i,:] .= 0.0
        
        # Change the ith state from 0.0 to 1.0
        solve_results2[i,i] = 1.0

        task_inc = task_inc + 1
        push!(tasks_started_TF, false) # Add a "false" to tasks_started_TF
        push!(tasks_fetched_TF, false) # Add a "false" to tasks_fetched_TF
        push!(task_numbers, task_inc)
    end
    
    # Total number of tasks
    num_tasks = length(tasks_fetched_TF)

    iteration_number = 0
    while(are_we_done == false)
        iteration_number = iteration_number+1
        
        # Launch tasks when thread (core) is available
        for j in 1:num_tasks
            if (tasks_fetched_TF[j] == false)
                if (tasks_started_TF[j] == false) && (curr_numthreads > 0)
                    # Start a task
                    push!(tasks, Base.Threads.@spawn core_op_simd(solve_results2[j,:], tspan, p_Ds_v7))
                    curr_numthreads = curr_numthreads-1;
                    tasks_started_TF[j] = true;
                    push!(current_running_tasks, task_numbers[j])
                end
            end
        end
        
        # Check for finished tasks
        tasks_to_check_TF = tasks_started_TF .& .!tasks_fetched_TF  # started but not yet fetched
        if any(tasks_to_check_TF)
            for k in 1:sum(tasks_to_check_TF)
                if (tasks_fetched_TF[current_running_tasks[k]] == false)
                    if (istaskstarted(tasks[k]) == true) && (istaskdone(tasks[k]) == true)
                        sol_Ds_v7 = fetch(tasks[k]);
                        solve_results2[current_running_tasks[k],:] .= sol_Ds_v7.u[end]
                        tasks_fetched_TF[current_running_tasks[k]] = true
                        current_tasknum = current_running_tasks[k]
                        deleteat!(tasks, k)
                        deleteat!(current_running_tasks, k)
                        curr_numthreads = curr_numthreads+1;
                        print("\nFinished task #")
                        print(current_tasknum)
                        print(", current task k=")
                        print(k)
                        break # break out of this loop, since you have modified current_running_tasks
                    end
                end
            end
        end

        are_we_done = sum(tasks_fetched_TF) == length(tasks_fetched_TF)
        # Test for concluding the while loop
        are_we_done && break
    end # END while(are_we_done == false)

    end_time = Dates.now()
    duration = (end_time - start_time).value / 1000.0
    sum_of_solutions = sum(sum.(solve_results2))
    print("\n")
    return (duration, sum_of_solutions)
end
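For comparison, here is a sketch of a scheduler-driven version of the same launch/collect pattern (parallel_solve is a hypothetical name; I have not benchmarked this). Julia's scheduler already multiplexes tasks onto nthreads() threads, so the manual curr_numthreads bookkeeping and polling loop should not be needed:

# Sketch: let @sync/@spawn do the bookkeeping; each task writes only its own row
function parallel_solve(core_op, tspan, p, solve_results; number_of_solves=10)
    start_time = Dates.now()
    @sync for i in 1:number_of_solves
        Base.Threads.@spawn begin
            u0 = zeros(size(solve_results, 2))   # fresh initial condition per task
            u0[i] = 1.0
            sol = core_op(u0, tspan, p)
            solve_results[i, :] .= sol.u[end]
        end
    end
    duration = (Dates.now() - start_time).value / 1000.0
    return (duration, sum(solve_results))
end

# Usage, e.g.:
# parallel_solve(core_op_simd, tspan, p_Ds_v7, solve_results2; number_of_solves=number_of_solves)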

tspan = (0.0, 1.0)
parallel_with_plain_v5(tspan, p_Ds_v7, solve_results2; number_of_solves=number_of_solves)
# Faster than serial plain version
# (duration, sum_of_solutions)
# (0.351, 8.731365050398926)
# (0.343, 8.731365050398926)
# (0.366, 8.731365050398926)

parallel_with_simd_v7(tspan, p_Ds_v7, solve_results2; number_of_solves=number_of_solves)
# Dramatically slower than serial simd version
# (duration, sum_of_solutions)
# (136.966, 9.61313614002137)
# (141.843, 9.616688089683372)



Thanks again, Nick
