大きな配列に対して SVD 圧縮を実行しようとすると、Jax で理解できない動作に遭遇しました。サンプルコードは次のとおりです。
@jit
def jax_compress(L):
U, S, _ = jsc.linalg.svd(L,
full_matrices = False,
lapack_driver = 'gesvd',
check_finite=False,
overwrite_a=True)
maxS=jnp.max(S)
chi = jnp.sum(S/maxS>1E-1)
return chi, jnp.asarray(U)
このコード スニペットを考慮すると、Jax/jit は SciPy よりもパフォーマンスが大幅に向上しますが、最終的には U の次元を減らしたいと考えています。
def jax_process(A):
chi, U = jax_compress(A)
return U[:,0:chi]
このステップは、次の比較でわかるように、計算時間の点で信じられないほどコストがかかり、SciPy の同等物よりも高くなります。
sc_compress
上記の jax コードにsc_process
相当する SciPy です。ご覧のとおり、SciPy で配列をスライスするコストはほとんどかかりませんが、ヒット関数の出力に適用すると非常にコストがかかります。誰かがこの行動について何らかの洞察を持っていますか?