
I'd like to use multiprocessing for a resource-heavy computation in code I'm writing, as shown in this watered-down example:

import numpy as np
import multiprocessing as multiproc

def function(r, phi, z, params):
    """returns an array of the timepoints and the corresponding values 
       (demanding computation in actual code, with iFFT and stuff)"""
    times = np.array([1.,2.,3.])
    tdependent_vals = r + z * times + phi
    return np.array([times, tdependent_vals])

def calculate_func(rmax, zmax, phi, param):
    rvals = np.linspace(0,rmax,5)
    zvals = np.linspace(0,zmax,5)
    for r in rvals:
        func_at_r = lambda z: function(r, phi, z, param)[1]
        with multiproc.Pool(2) as pool:
             fieldvals = np.array([*pool.map(func_at_r, zvals)])
             print(fieldvals) #for test, it's actually saved in a numpy array

calculate_func(3.,4.,5.,6.)

If I run this, it fails with

AttributeError: Can't pickle local object 'calculate_func.<locals>.<lambda>'

I think the reason is that, according to the documentation, only functions defined at the top level of a module can be pickled, so my lambda defined inside calculate_func can't be. But I don't see how I could make it a standalone function, at least without polluting the module with a bunch of top-level variables: the parameters are unknown before calculate_func is called, and they change at each iteration over rvals. This whole multiprocessing thing is very new to me, and I couldn't come up with an alternative. What would be the simplest working way to parallelize the loop over rvals and zvals?
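For instance, here is a toy check of that constraint (not my actual code): a module-level function pickles fine, while the same thing written as a lambda inside a function does not:

import pickle

def top_level(z):
    return 2 * z

def check():
    local = lambda z: 2 * z
    pickle.dumps(top_level)  # fine: a module-level function is pickled by reference
    pickle.dumps(local)      # AttributeError: Can't pickle local object 'check.<locals>.<lambda>'

check()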

Note: I used this answer as a starting point.

Neinstein

1 Answer


This probably isn't the best answer for this, but it's an answer, so please no hate :)

You can write a top-level wrapper function that can be serialized and have it execute the actual function from its marshalled code... It's a bit like function inception, but I solved a similar problem in my code this way.

Here is a brief example:

import base64
import marshal
import types
import multiprocessing as mp

def wrapper(arg_list):
    # Rebuild the function from its marshalled code object and call it.
    # Only the code object travels, so this won't carry closures or defaults.
    func_str, args = arg_list
    code = marshal.loads(base64.b64decode(func_str))
    func = types.FunctionType(code, globals(), "wrapped_func")
    return func(*args)

def run_func(func, *args):
    # Serialize the function's code object so the pool workers can rebuild it.
    func_str = base64.b64encode(marshal.dumps(func.__code__, 0))
    arg_list = [func_str, args]
    with mp.Pool(2) as pool:
        results = pool.map(wrapper, [arg_list])
    return results
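
A rough usage sketch of the above, assuming a module-level function with no closed-over variables (the reconstruction only carries the code object, not a closure); the names here are made up for illustration:

def example(r, phi, z, param):
    return r + phi + z + param

if __name__ == "__main__":
    print(run_func(example, 1., 2., 3., 4.))  # -> [10.0]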
sehafoc