2

ここでは、いくつかの関数を受け入れて境界のセットを統合するために作成した Simpson 統合コードを使用した簡単な演習を行います。

import numpy as np
def simps(f, a, b, N):
    #N should be even
    dx = (b - a) / N
    x = np.linspace(a, b, N + 1)
    y = f(x)
    w = np.ones_like(y)
    w[2:-1:2] = 2.
    w[1::2]   = 4.
    S = dx / 3 * np.einsum("i...,i...",w,y)
    return S

def funcN(x):
    return np.stack([x**(i/10) * np.exp(-x) for i in range(200)],axis=1)

a = np.arange(0,10,0.1)
b = a+0.05

私は CPU デバイスを使用しており、Int(f_i, a_j,b_j) i:0-199 および j:0-99 に対応する 200 x 100 の数値配列を取得します。

%timeit simps(funcN,a,b, 512)

ループあたり 1.13 秒 ± 27.4 ミリ秒 (7 回の実行の平均値 ± 標準偏差、各ループ 1 回)

次の JAX/JIT バージョンを検討してください。

import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from functools import partial
from jax.config import config
config.update("jax_enable_x64", True)   #numpy by default is in double precision

@partial(jit, static_argnums=(0,3))
def jax_simps(f, a,b, N):
    dx = (b - a) / N
    x = jnp.linspace(a, b, N + 1)
    y = f(x)
    w = jnp.ones_like(y)
    w = w.at[2:-1:2].set(2.)
    w = w.at[1::2].set(4.)
    S = dx / 3. * jnp.einsum('i...,i...',w,y)
    return S

@jit
def jax_funcN(x):
    return jnp.stack([x**(i/10) * jnp.exp(-x) for i in range(200)],axis=1)

ja = jnp.arange(0,10,0.1)
jb = ja+0.05

#warm up
jax_simps(jax_funcN,ja,jb, 512).block_until_ready() 

%timeit jax_simps(jax_funcN,ja,jb, 512).block_until_ready() 

2 つのコード (純粋な Numpy と JAX/JIT) が同じ結果をもたらすことを確認しました。最大相対誤差は 8.10^-16 のオーダーです。

今、私はループごとに次のタイミング933ミリ秒±51.4ミリ秒を得ました(7回の実行の平均±標準偏差、それぞれ1ループ)

これは純粋な Numpy に非常に近いものです。たまたま非常に効率的な純粋な Numpy コードを作成したことがありますか? または、間違った方法で JAX/JIT をコーディングしましたか?

(注: Google コラボ K80 GPU を使用すると、JAX/JIT のタイミングがループごとに 7.19 ミリ秒に低下し、純粋な Numpy が 1 秒/ループのレベルに維持されます)

4

1 に答える 1