I have a function that parallelizes another function via a multiprocessing pool, where the parallelized function takes a dictionary as input. I would expect the code below to simply print the numbers from 0 to 32. However, the result shows that many numbers are printed more than once.
Does anybody have an idea why?
    import multiprocessing as mp
    import numpy as np
    import functools

    def test(name, t_dict):
        t_dict['a'] = name
        return t_dict

    def mp_func(func, iterator, **kwargs):
        f_args = functools.partial(func, **kwargs)
        pool = mp.Pool(mp.cpu_count())
        res = pool.map(f_args, iterator)
        pool.close()
        return res

    mod = dict()
    m = 33
    res = mp_func(func=test, iterator=np.arange(m), t_dict=mod)
    for di in res:
        print(di['a'])
The problem is that t_dict is passed as part of the partial function f_args. Partial functions are instances of <class 'functools.partial'>. When you create the partial, it gets a reference to test and to the empty dictionary in mod. Every time you call f_args, that one dictionary on the partial object is modified. This is easier to spot with a list in a single process.

When you pool.map(f_args, iterator), f_args is pickled and sent to each subprocess to be the worker. So each subprocess has a unique copy of the dictionary, and that copy is updated for every iterated value the subprocess happens to get.

For efficiency, multiprocessing chunks the data. That is, each subprocess is handed a list of iterated values that it processes into a list of responses to return as a group. But since each response references the same single dict, when the chunk is returned to the parent all of the responses only hold the final value set. If 0, 1, 2 were processed, the return is 2, 2, 2.

The solution will depend on your data. It's expensive to pass data back and forth between the pool processes and the parent, so ideally the data is generated completely in the worker. In this case, ditch the partial and have the worker create the dict. It's likely your situation is more complicated than this.