multiprocessing.Pool と同等の処理を MPI (mpi4py) で行うために、次のコードを使用しています。まだ十分にはテストしていませんが、問題なく動作しているようです:
from functools import partial
# Bind every fixed parameter up front so that each worker only needs to be
# handed its own run index.
function = partial(...)  # Store all fixed parameters this way if needed

if use_MPI:
    # mpi4py: spawn one worker process per run and talk to the workers over
    # the intercommunicator returned by Spawn.
    arguments = list(range(num_runs))  # one item per spawned worker
    run_data = None  # the parent itself contributes nothing to the gather
    comm = MPI.COMM_SELF.Spawn(sys.executable, args=['MPI_slave.py'], maxprocs=num_runs)  # Init
    comm.bcast(function, root=MPI.ROOT)  # Equal for all processes
    comm.scatter(arguments, root=MPI.ROOT)  # Different for each process
    comm.Barrier()  # Wait for everything to finish...
    run_data = comm.gather(run_data, root=MPI.ROOT)  # And gather everything up
    # MPI_slave.py finishes with comm.Disconnect(); the call is collective
    # over the intercommunicator, so the parent has to take part in it too.
    comm.Disconnect()
else:
    # multiprocessing
    p = Pool(multiprocessing.cpu_count())
    try:
        run_data = p.map(function, range(num_runs))
    finally:
        # Always release the pool's worker processes, even if map() raised.
        p.close()
        p.join()
次に、ワーカー側のコードを別のファイル 'MPI_slave.py' に用意します:
from mpi4py import MPI
# import the function you actually pass to this file here!!!
# Intercommunicator back to the process that spawned this script.
comm = MPI.COMM_SELF.Get_parent()
# rank and size are only referenced by the commented-out debug print in
# runSlaveRun; they are kept as module-level names for that purpose.
rank = comm.Get_rank()
size = comm.Get_size()
def runSlaveRun():
    """Run a single unit of work handed out by the parent process.

    Using the module-level ``comm``, this receives the broadcast callable
    and this process's scattered argument (both with ``root=0``), calls the
    callable on that argument, waits at the barrier, contributes the result
    to the gather and finally disconnects from the parent.

    Returns:
        None. The result is only delivered through ``comm.gather``.
    """
    # print("Process {}/{} reporting for duty!".format(rank, size))
    # This side only receives, so None is passed as the placeholder payload.
    worker_fn = comm.bcast(None, root=0)
    work_item = comm.scatter(None, root=0)
    outcome = worker_fn(work_item)
    comm.Barrier()  # matches the Barrier() the parent calls before gathering
    comm.gather(outcome, root=0)
    comm.Disconnect()
if __name__ == '__main__':
    # Run the worker protocol only when this file is executed as a script,
    # not when it is imported.
    runSlaveRun()