def f1(x):
    for i in range(1, 100):
        x *= 2
        x /= 3.14159
        x *= i**.25
    return x

def f2(x):
    for i in range(1, 100):
        x *= 2 / 3.14159 * i**.25
    return x

Both functions compute exactly the same result, but f1 takes 3x longer to do so, even with @numba.njit. Can Python be made to recognize the equivalence during compilation, just as it optimizes in other ways visible with dis, e.g. by throwing out unused assignments?

Note: I'm aware that floating-point arithmetic is order-sensitive, so the two functions may produce slightly different outputs; but if anything, more separate in-place edits to array values are less accurate, so it'd be a 2-in-1 optimization.
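For reference, CPython already folds constant subexpressions at compile time (so the 2 / 3.14159 in f2 becomes a single constant), but it does not reassociate expressions involving a variable. Assuming CPython 3.x, this can be seen by disassembling a small illustrative helper (g below is hypothetical, just for inspection):

import dis

def g(x, i):
    return x * (2 / 3.14159 * i**.25)

dis.dis(g)  # the folded value of 2 / 3.14159 typically shows up as a single LOAD_CONST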


x = np.random.randn(10000, 1000)
%timeit f1(x.copy())        # 2.68 s ± 50.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit f2(x.copy())        # 894 ms ± 36.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit njit(f1)(x.copy())  # 2.59 s ± 65.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit njit(f2)(x.copy())  # 901 ms ± 41.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
OverLordGoldDragon

2 Answers


Using numba.jit is probably the best optimization you will get for the moment for this kind of function. You may also want to try PyPy and do some benchmark comparisons.

However, I want to point out why the two functions are not equivalent, and why you should not expect f1 to be reduced to f2.

The order of operations for f1 is as follows:

x1 = (x * 2)            # First binary operation
x2 = (x1 / 3.14159)     # Second binary operation
x3 = x2 * (i ** 0.25)   # Third and fourth binary operations

# Order: Multiplication, division, exponent, multiplication

Which is not the same as for f2:

x *= ((2 / 3.14159) * (i ** 0.25))
#  ^     ^          ^     ^
#  |     |          |     |
#  4     1          3     2

# Order: Division, exponent, multiplication, multiplication

Since floating-point arithmetic is not associative, those may not yield the same result. For that reason, it would be wrong for a compiler or interpreter to perform the optimization you expected, unless it is explicitly allowed to trade floating-point accuracy for speed.
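A minimal illustration with plain Python floats (IEEE-754 doubles): changing the grouping changes the rounded result, even though the operations are mathematically equivalent.

a, b, c = 0.1, 0.2, 0.3
print((a + b) + c)  # 0.6000000000000001
print(a + (b + c))  # 0.6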

I am not aware of a Python tool which is meant to do this specific kind of optimization.
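One partial exception: numba.njit accepts a fastmath flag, which relaxes strict IEEE-754 semantics and gives LLVM permission to reassociate and contract floating-point operations. Whether it actually rewrites this particular loop is not guaranteed (the benchmark in the other answer suggests it may not), so this is only a sketch of the option:

import numba

@numba.njit(fastmath=True)  # lets LLVM reorder/reassociate float ops; results may differ slightly
def f1_fast(x):
    for i in range(1, 100):
        x *= 2
        x /= 3.14159
        x *= i**.25
    return x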

Olivier Melançon
  • I'm aware regarding associativity, but if anything _more_ modifications to array values are _less_ accurate, so it'd be a two-in-one optimization. – OverLordGoldDragon Nov 03 '20 at 20:26
  • @OverLordGoldDragon I am not sure I understand what you mean by modifications to array values being less accurate – Olivier Melançon Nov 03 '20 at 20:28
  • For example, roundoff error accrues over repeated e.g. addition, whereas adding all at once involves only one roundoff. Typical issue in machine learning. – OverLordGoldDragon Nov 03 '20 at 20:29
  • 1
    @OverLordGoldDragon I see. Then back to you problem: if you benchmarked f2 as being faster, why not just use it instead of f1? This is not the kind of optimization I think you should expect a Python compiler to do today... hopefully in the future we will find some ways. – Olivier Melançon Nov 03 '20 at 20:34
  • Readability problems; long line plus in scientific computing context, introduces illogical flow of ideas. Basically like an ugly vectorization vs intuitive but slow for-loop. Nothing one can't compensate with enough inline comments, but would be better if the compiler took care of it instead. – OverLordGoldDragon Nov 03 '20 at 20:41
  • [Quick example](https://i.stack.imgur.com/6XYN7.png), but it can get much worse – OverLordGoldDragon Nov 03 '20 at 20:48
  • 1
    @OverLordGoldDragon I see. The problem is even ahrder in that case because psihfn could dynamically be assigned another function. This makes static analysis terribly complexe in Python. – Olivier Melançon Nov 03 '20 at 21:15
  • Right, not the best example – OverLordGoldDragon Nov 03 '20 at 21:53

You probably can't do it with the JIT. I have tried the fastmath and nogil kwargs specified in the API: https://numba.pydata.org/numba-doc/latest/reference/jit-compilation.html

f0 is still slightly slower than f1 after getting rid of overflow and denormal numbers (see the plot produced by the script below).

from timeit import default_timer as timer
import numpy as np
import matplotlib.pyplot as plt
import numba as nb


def f0(x):
    for i in range(1, 1000):
        x *= 3.000001
        x /= 3
    return x


def f1(x):
    for i in range(1, 1000):
        x *= 3.000001 / 3
    return x


def timing(f, **kwarg):
    # Compile f with the given numba kwargs, then time batches of repeated calls on a shared array.
    x = np.ones(1000, dtype=np.float32)
    times = []
    n_iter = list(range(100, 1000, 100))
    f2 = nb.njit(f, **kwarg)
    for i in n_iter:
        print(i)
        s = timer()
        for j in range(i):
            f2(x)
        e = timer()
        times.append(e - s)
    print(x)
    m, b = np.polyfit(n_iter, times, 1)  # slope and intercept of a linear fit of time vs. iteration count
    return times, m, b, n_iter


def main():
    results = []
    for fastmath in [True, False]:
        for i, f in enumerate([f0, f1]):
            kwarg = {
                "fastmath": fastmath,
                "nogil": True
            }
            r1, m, b, n_iter = timing(f, **kwarg)
            label = "f%d with %s" % (i, kwarg)
            plt.plot(n_iter, r1, label=label)
            results.append((m, b, label))
    for m, b, label in results:
        print(m * 1e5, b, label)
    plt.legend(loc="upper left")
    plt.xlabel("n iterations")
    plt.ylabel("timing")
    plt.show()
    plt.close()


if __name__ == '__main__':
    main()
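
If you want to confirm whether fastmath changed the generated code rather than relying on timings alone, the compiled dispatcher exposes inspect_llvm(). This reuses f0 and the imports from the script above; the exact IR you get depends on your Numba/LLVM versions, so treat it as a rough sketch:

jitted = nb.njit(f0, fastmath=True)
jitted(np.ones(1000, dtype=np.float32))  # trigger compilation for this signature
for sig, ir in jitted.inspect_llvm().items():
    print(sig)
    # fast-math shows up as 'fast' flags on fmul/fdiv instructions in the IR
    print("\n".join(line for line in ir.splitlines()
                    if "fmul" in line or "fdiv" in line))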

Crawl Cycle