で行列をべき乗する 2 つの方法がありjnp = jax.numpy
ます。簡単なもの:
jnp.exp(-X/reg)
さらに、いくつかの追加アクションがあります。
def exp_reg(X, reg):
K = jnp.empty_like(X)
K = jnp.divide(X, -reg)
return jnp.exp(K)
しかし、私がそれらをテストしたとき:
%timeit jnp.exp(-X/reg).block_until_ready()
%timeit exp_reg(X, reg).block_until_ready()
The second approach turned to outperform, despite having superficially some additional overhead. I've run a %timeit
with a matrix of size 2000 x 2000:
7.85 ms ± 567 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
5.19 ms ± 52.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Why it may be the case?