Using numba.jit to speed up right-hand-side calculations for odeint from scipy.integrate works fine:

import numpy as np
from scipy.integrate import ode, odeint
from numba import jit

@jit
def rhs(t, X):
    return 1

X = odeint(rhs, 0, np.linspace(0, 1, 11))
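A side note on sharing one rhs between the two solvers: odeint calls its function as func(y, t) (state first), while ode calls f(t, y) (time first). With the constant rhs above the order is irrelevant, but it matters as soon as the rhs depends on t or on the state. A minimal sketch with the hypothetical equation dy/dt = t and y(0) = 0, whose exact solution gives y(1) = 0.5; numba is not involved here:

```python
import numpy as np
from scipy.integrate import ode, odeint

# odeint expects func(y, t): state first, then time
def rhs_odeint(y, t):
    return [t]  # dy/dt = t

# ode expects f(t, y): time first, then state
def rhs_ode(t, y):
    return [t]  # same equation, arguments in the other order

ts = np.linspace(0, 1, 11)
y_odeint = odeint(rhs_odeint, 0.0, ts)  # one row per time in ts

solver = ode(rhs_ode)
solver.set_initial_value(0.0, 0.0)
for t_out in ts[1:]:  # integrate to each output time in ts
    solver.integrate(t_out)
y_ode = solver.y

print(y_odeint[-1, 0], y_ode[0])  # both approximately 0.5
```

Stepping to the exact times in ts also sidesteps a floating-point detail of the while loop used below: adding 0.1 ten times gives 0.9999999999999999 rather than 1.0, so a `solver.t < 1` condition can trigger one extra step past t = 1.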

However, using integrate.ode like this:

solver = ode(rhs)
solver.set_initial_value(0, 0)
while solver.successful() and solver.t < 1:
    solver.integrate(solver.t + 0.1)

produces the following error when rhs is decorated with @jit:

capi_return is NULL
Call-back cb_f_in_dvode__user__routines failed.
Traceback (most recent call last):
  File "sandbox/numba_cubic.py", line 15, in <module>
    solver.integrate(solver.t + 0.1)
  File "/home/pgermann/Software/anaconda3/lib/python3.4/site-packages/scipy/integrate/_ode.py", line 393, in integrate
    self.f_params, self.jac_params)
  File "/home/pgermann/Software/anaconda3/lib/python3.4/site-packages/scipy/integrate/_ode.py", line 848, in run
    y1, t, istate = self.runner(*args)
TypeError: not enough arguments: expected 2, got 1

Any ideas on how to overcome this?

germannp

2 Answers

I do not know the reason or a solution; however, in this case Theano helped a lot to speed up the calculation. Theano essentially compiles NumPy expressions, so it only helps when you can write the rhs as an expression of multi-dimensional arrays (whereas jit also handles for loops and the like). It also knows some algebra and uses it to optimize the calculation.
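To illustrate the distinction between the two styles, here is a sketch of the same hypothetical rhs (a chain where each component is coupled to its neighbours) written once with an explicit loop, which numba.jit compiles well, and once as a whole-array NumPy expression, which is the form a tool like Theano works with. Theano itself is not shown, and the no-op fallback for jit is an assumption added only so the snippet runs without numba installed:

```python
import numpy as np

try:
    from numba import jit
except ImportError:
    # assumption: without numba, fall back to a no-op decorator so the sketch still runs
    def jit(*args, **kwargs):
        return lambda f: f

# Hypothetical rhs: dX[i]/dt = X[i-1] - 2*X[i] + X[i+1], with zeros beyond both ends.

def rhs_loop(t, X):
    # element-by-element loop: the style numba.jit handles
    n = X.shape[0]
    dX = np.empty(n)
    for i in range(n):
        left = X[i - 1] if i > 0 else 0.0
        right = X[i + 1] if i < n - 1 else 0.0
        dX[i] = left - 2.0 * X[i] + right
    return dX

def rhs_vectorized(t, X):
    # whole-array expression: pad with zeros, then combine shifted views
    padded = np.concatenate(([0.0], X, [0.0]))
    return padded[:-2] - 2.0 * X + padded[2:]

rhs_loop_jit = jit(nopython=True)(rhs_loop)

X = np.array([1.0, 2.0, 3.0])
print(rhs_loop_jit(0.0, X))    # values 0, 0, -4
print(rhs_vectorized(0.0, X))  # values 0, 0, -4
```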

Besides, Theano can compile for the GPU (which was my reason to try numba.jit in the first place). However, because of the overhead, using the GPU turned out to improve performance only for huge systems (maybe one million equations).

germannp

You can use a wrapper function, but I don't think it will improve performance for small rhs functions.

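from scipy.integrate import ode
from numba import jit

@jit(nopython=True)
def rhs(t, X):
    return 1

# plain Python wrapper: ode receives an ordinary function that calls the jitted rhs
def wrapper(t, X):
    return rhs(t, X)

solver = ode(wrapper)
solver.set_initial_value(0, 0)
while solver.successful() and solver.t < 1:
    solver.integrate(solver.t + 0.1)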
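A slightly fuller sketch of the same wrapper pattern, assuming a vector-valued state and a hypothetical rhs with an explicit loop (the kind of body where jit can actually pay off). The no-op fallback for jit is an assumption added only so the snippet runs without numba installed:

```python
import numpy as np
from scipy.integrate import ode

try:
    from numba import jit
except ImportError:
    # assumption: without numba, fall back to a no-op decorator so the sketch still runs
    def jit(*args, **kwargs):
        return lambda f: f

@jit(nopython=True)
def rhs(t, X):
    # hypothetical rhs: independent decays dX[i]/dt = -(i + 1) * X[i]
    dX = np.empty(X.shape[0])
    for i in range(X.shape[0]):
        dX[i] = -(i + 1.0) * X[i]
    return dX

def wrapper(t, X):
    # ordinary Python function handed to ode; it forwards to the jitted rhs
    return rhs(t, X)

solver = ode(wrapper)
solver.set_initial_value(np.ones(3), 0.0)
for t_out in np.linspace(0.1, 1.0, 10):
    solver.integrate(t_out)
    if not solver.successful():
        break

exact = np.exp(-np.arange(1.0, 4.0))  # X[i](1) = exp(-(i + 1)) for X(0) = 1
print(solver.t, solver.y, exact)
```

Each rhs evaluation still costs one Python-level call to wrapper, so any speedup comes only from the work done inside rhs; for a trivial rhs such as `return 1`, that call overhead dominates.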
TheIdealis