2

大きな配列に対して SVD 圧縮を実行しようとすると、Jax で理解できない動作に遭遇しました。サンプルコードは次のとおりです。

@jit 
def jax_compress(L):
    U, S, _ = jsc.linalg.svd(L, 
    full_matrices = False,
    lapack_driver = 'gesvd',
    check_finite=False,
    overwrite_a=True)

    maxS=jnp.max(S)
    chi = jnp.sum(S/maxS>1E-1)

    return chi, jnp.asarray(U)

このコード スニペットを考慮すると、Jax/jit は SciPy よりもパフォーマンスが大幅に向上しますが、最終的には U の次元を減らしたいと考えています。

def jax_process(A):

    chi, U = jax_compress(A)
    
    return U[:,0:chi]

このステップは、次の比較でわかるように、計算時間の点で信じられないほどコストがかかり、SciPy の同等物よりも高くなります。

jax と scipy のベンチマーク

sc_compress上記の jax コードにsc_process相当する SciPy です。ご覧のとおり、SciPy で配列をスライスするコストはほとんどかかりませんが、ヒット関数の出力に適用すると非常にコストがかかります。誰かがこの行動について何らかの洞察を持っていますか?

4

2 に答える 2