1

で行列をべき乗する 2 つの方法がありjnp = jax.numpyます。簡単なもの:

jnp.exp(-X/reg)

さらに、いくつかの追加アクションがあります。

def exp_reg(X, reg):
    K = jnp.empty_like(X)
    K = jnp.divide(X, -reg)
    return jnp.exp(K)

しかし、私がそれらをテストしたとき:

%timeit jnp.exp(-X/reg).block_until_ready()
%timeit exp_reg(X, reg).block_until_ready()

The second approach turned to outperform, despite having superficially some additional overhead. I've run a %timeit with a matrix of size 2000 x 2000:

7.85 ms ± 567 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
5.19 ms ± 52.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Why it may be the case?

4

1 に答える 1

1

ここでの違いは、操作の順序です。

ではjnp.exp(-X/reg)、 のすべてのエントリを否定Xし、結果の各エントリを で割っていregます。これは、配列に対する 2 つのパスXです。

exp_regあなたは否定しreg(おそらくスカラー値ですか?)X、結果で割っています。それは 1 パス オーバーXです。

が大きい場合X、 を複数回通過するため、最初のアプローチは 2 番目のアプローチよりもわずかに遅くなると予想されXます。

さいわい、JAX を使用しているjitため、コードをコンパイルできます。その場合、XLA は通常、これらのような同等の操作順序で最適化できます。実際、2 つの関数については、コンパイルによって不一致が解消されます。

from jax import jit
import jax.numpy as jnp
import numpy as np

def exp_reg1(X, reg):
  return jnp.exp(-X/reg)

def exp_reg2(X, reg):
  K = jnp.divide(X, -reg)
  return jnp.exp(K)

X = jnp.array(np.random.rand(1000, 1000))
reg = 2.0

%timeit exp_reg1(X, reg)
# 100 loops, best of 3: 3.17 ms per loop
%timeit exp_reg2(X, reg)
# 100 loops, best of 3: 2.2 ms per loop

# Trigger compilation
jit(exp_reg1)(X, reg)
jit(exp_reg2)(X, reg)

%timeit jit(exp_reg1)(X, reg)
# 1000 loops, best of 3: 1.92 ms per loop
%timeit jit(exp_reg2)(X, reg)
# 100 loops, best of 3: 1.84 ms per loop

K(補足:操作の結果を同じ名前の変数に代入する前に、空の配列を事前に割り当てる理由はありません)。

于 2020-11-04T18:26:08.270 に答える