Parallelisation of first-to-survive-wins loop

I have a problem which, when simplified:

  1. has a loop which samples new points
  2. evaluates them with a complex/slow function
  3. accepts them if the value is above an ever-increasing threshold.

Here is example code for illustration:

from numpy.random import uniform
from time import sleep

def userfunction(x):
    # do something complicated
    # but computation always takes roughly the same time
    sleep(1) # comment this out if too slow
    xnew = uniform() # in reality, a non-trivial function of x
    y = -0.5 * xnew**2
    return xnew, y

x0, cur = userfunction([])
x = [x0] # a sequence of points

while cur < -2e-16:
    # this should be parallelised

    # search for a new point higher than a threshold
    x1, next = userfunction(x)
    if next <= cur:
        # throw away (this branch is taken 99% of the time)
        pass
    else:
        cur = next
        print(cur)
        x.append(x1) # note that userfunction depends on x

print(x)

I want to parallelise this (e.g. across a cluster), but the problem is that I need to terminate the other workers when a successful point has been found, or at least inform them of the new x (if they manage to get above the new threshold with an older x, the result is still acceptable). As long as no point has been successful, the workers need to keep trying.
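
To make the desired behaviour concrete, here is a single-machine sketch using multiprocessing shared state (illustration only, not the cluster solution I am asking for; the threshold is relaxed to -1e-4 and the slow function is replaced by a cheap stand-in so the demo finishes quickly):

import multiprocessing as mp
from numpy.random import uniform

def userfunction(x):
    # cheap stand-in for the slow function above
    xnew = uniform()
    return xnew, -0.5 * xnew**2

def worker(cur, xlist, lock):
    # keep sampling until the shared threshold passes the stopping value
    while cur.value < -1e-4:
        x1, y = userfunction(list(xlist))
        with lock:
            if y > cur.value:
                # survived against the current threshold: publish it
                cur.value = y
                xlist.append(x1)

if __name__ == '__main__':
    manager = mp.Manager()
    x0, y0 = userfunction([])
    xlist = manager.list([x0])       # shared sequence of accepted points
    cur = manager.Value('d', y0)     # shared current threshold
    lock = manager.Lock()
    procs = [mp.Process(target=worker, args=(cur, xlist, lock)) for _ in range(4)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    print(list(xlist), cur.value)

Workers are never interrupted here; they simply see the updated threshold on their next iteration, which is the behaviour I would also accept on a cluster.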

I am looking for tools/frameworks which can handle this type of problem, in any scientific programming language (C, C++, Python, Julia, etc., no Fortran please).

Can this be solved with MPI semi-elegantly? I don't understand how I can inform/interrupt/update workers with MPI.

Update: added code comments to clarify that most tries are unsuccessful and do not influence the variable that userfunction depends on.

There are 2 answers below.

Answer 1

If userfunction() does not take too long, here is an option that qualifies as "MPI semi-elegantly".

To keep things simple, let's assume rank 0 is only an orchestrator and does not compute anything.

On rank 0:

cur = -inf   # any value below the stopping threshold, so the loop is entered
x = []
while cur < -2e-16:
    # wait for any worker to report a surviving point
    MPI_Recv(buf=cur+x1, src=MPI_ANY_SOURCE)
    x.append(x1)
    # publish the new threshold and sequence to all workers
    MPI_Ibcast(buf=cur+x, root=0, request=req)
    MPI_Wait(request=req)

On rank != 0:

x0, cur = userfunction([])
x = [x0] # a sequence of points

# post a non-blocking broadcast once; it completes when rank 0 publishes an update
MPI_Ibcast(buf=newcur+newx, root=0, request=req)
while cur < -2e-16:
    # search for a new point higher than the threshold
    x1, next = userfunction(x)
    if next <= cur:
        # throw away (this branch is taken 99% of the time),
        # but check whether rank 0 has published a better point in the meantime
        MPI_Test(request=req, flag=found)
        if found:
            cur, x = newcur, newx
            # re-arm the broadcast for the next update
            MPI_Ibcast(buf=newcur+newx, root=0, request=req)
    else:
        cur = next
        # report the surviving point and wait for rank 0 to rebroadcast it
        MPI_Send(buf=cur+x1, dest=0)
        MPI_Wait(request=req)
        cur, x = newcur, newx
        MPI_Ibcast(buf=newcur+newx, root=0, request=req)

Extra logic is needed to correctly handle:

- rank 0 doing computation too
- several ranks finding the solution at the same time (subsequent messages must then be consumed by rank 0)

Strictly speaking, a task is not "interrupted" when a solution is found on another task. Instead, each task periodically checks whether the solution was found by another task, so there is a delay between the time a solution is found somewhere and the time all tasks stop looking for solutions. But if userfunction() does not take "too long", this looks very acceptable to me.
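
For concreteness, here is a rough mpi4py sketch of this polling pattern (simplified: it accepts the first reported point only, and the extra logic mentioned above for simultaneous reports is left out):

import numpy
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

if rank == 0:
    # orchestrator: wait for the first surviving point, then publish it
    value = comm.recv(source=MPI.ANY_SOURCE, tag=1)
    buf = numpy.array([value])
    comm.Ibcast(buf).Wait()
    print('accepted', value)
else:
    result = numpy.empty(1)
    req = comm.Ibcast(result)        # completes once rank 0 publishes a result
    cur = -1e-3                      # toy threshold: most draws are rejected
    while not req.Test():            # poll instead of being interrupted
        xnew = numpy.random.uniform()
        y = -0.5 * xnew**2
        if y > cur:
            cur = y
            comm.send(cur, dest=0, tag=1)
            req.Wait()               # pick up the broadcast, then stop
            break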

Answer 2

I solved it roughly with the following code.

This transmits only curmax at the moment, but the point sequence could be sent as well with a second broadcast and tag (a rough sketch of that is given after the code).

import numpy
import time

from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()

import logging
logging.basicConfig(filename='mpitest%d.log' % rank,level=logging.DEBUG)
logFormatter = logging.Formatter("[%(name)s %(levelname)s]: %(message)s")
consoleHandler = logging.StreamHandler()
consoleHandler.setFormatter(logFormatter)
consoleHandler.setLevel(logging.INFO)
logging.getLogger().addHandler(consoleHandler)

log = logging.getLogger(__name__)

if rank == 0:
    curmax = numpy.random.random()
    seq = [curmax]
    log.info('%d broadcasting starting value %f...' % (rank, curmax))
    comm.Ibcast(numpy.array([curmax]))

    was_updated = False
    while True:
        # check if news available
        status = MPI.Status()
        a_avail = comm.iprobe(source=MPI.ANY_SOURCE, tag=12, status=status)
        if a_avail:
            sugg = comm.recv(source=status.Get_source(), tag=12)
            log.info('%d received new limit from %d: %s' % (rank, status.Get_source(), sugg))
            if sugg < curmax:
                curmax = sugg
                seq.append(curmax)
                log.info('%d updating to %s' % (rank, curmax))
                was_updated = True
            else:
                # ignore
                pass
        # check if next message is already waiting:
        if comm.iprobe(source=MPI.ANY_SOURCE, tag=12):
            # consume it first before broadcasting outdated info
            continue

        if was_updated:
            log.info('%d broadcasting new limit %f...' % (rank, curmax))
            comm.Ibcast(numpy.array([curmax]))
            was_updated = False
        else:
            # no message waiting for us and no broadcast done, so pause
            time.sleep(0.1)
        print()

    print(seq, rank)  # note: unreachable, since the loop above never exits
else:
    log.info('%d waiting for root to send us starting value...' % (rank))
    nextmax = numpy.empty(1, dtype=float)
    comm.Ibcast(nextmax).Wait()

    amax = float(nextmax[0])
    numpy.random.seed(rank)
    update_req = comm.Ibcast(nextmax)
    while True:
        a = numpy.random.uniform()
        if a < amax:
            log.info('%d found new: %s, sending to root' % (rank, a))
            amax = a
            comm.isend(a, dest=0, tag=12)
        s = update_req.Get_status()
        #log.info('%d bcast status: %s' % (rank, s))
        if s:
            update_req.Wait()
            log.info('%d receiving new limit from root, %s' % (rank, nextmax))
            amax = float(nextmax[0])
            update_req = comm.Ibcast(nextmax)
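
Assuming the script is saved as mpitest.py, it can be started with something like mpiexec -n 4 python mpitest.py; each rank then writes its own mpitest<rank>.log. Both the rank-0 loop and the worker loops are while True, so the run has to be stopped by hand once the sequence is good enough.

As for sending the point sequence as well, here is a rough sketch of the idea (using a blocking Bcast for simplicity, whereas the code above uses Ibcast): broadcast the current length first, then the contents.

import numpy
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

if rank == 0:
    x = numpy.random.random(5)        # stand-in for the accepted points so far
    n = numpy.array([x.size], dtype='i')
else:
    x = None
    n = numpy.empty(1, dtype='i')

comm.Bcast(n, root=0)                 # step 1: how many points
if rank != 0:
    x = numpy.empty(n[0], dtype=float)
comm.Bcast(x, root=0)                 # step 2: the points themselves
print(rank, x)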