これは、この質問のフォローアップです。
einsum を使用して (大幅な速度向上を達成するために) 助けを求めていたところ、すばらしい答えが得られました。
また、使用するための提案も得ましたnumba
。私は行って両方を試してみましたが、ある時点以降は速度の増加numba
がはるかに優れているようです.
では、メモリの問題を起こさずに高速化するにはどうすればよいでしょうか?
これは、この質問のフォローアップです。
einsum を使用して (大幅な速度向上を達成するために) 助けを求めていたところ、すばらしい答えが得られました。
また、使用するための提案も得ましたnumba
。私は行って両方を試してみましたが、ある時点以降は速度の増加numba
がはるかに優れているようです.
では、メモリの問題を起こさずに高速化するにはどうすればよいでしょうか?
以下のソリューションは、合計の単純な合計を実行する 3 つの異なる方法と、2 乗の合計を実行する 4 つの異なる方法を示しています。
sum of sums 3 つの方法 - for ループ、JIT for ループ、einsum (いずれもメモリの問題は発生しません)
合計の二乗和 4 つの方法 - for ループ、JIT for ループ、拡張 einsum、中間 einsum
ここでは、最初の 3 つはメモリの問題に遭遇せず、for ループと拡張された einsum は速度の問題に遭遇します。これにより、JIT ソリューションが最良のように見えます。
import numpy as np
import time
from numba import jit
def fun1(Fu, Fv, Fx, Fy, P, B):
Nu = Fu.shape[0]
Nv = Fv.shape[0]
Nx = Fx.shape[0]
Ny = Fy.shape[0]
Nk = Fu.shape[1]
Nl = Fv.shape[1]
I1 = np.zeros([Nu, Nv])
for iu in range(Nu):
for iv in range(Nv):
for ix in range(Nx):
for iy in range(Ny):
S = 0.
for ik in range(Nk):
for il in range(Nl):
S += Fu[iu,ik]*Fv[iv,il]*Fx[ix,ik]*Fy[iy,il]*P[ix,iy]*B[ik,il]
I1[iu, iv] += S
return I1
def fun2(Fu, Fv, Fx, Fy, P, B):
Nu = Fu.shape[0]
Nv = Fv.shape[0]
Nx = Fx.shape[0]
Ny = Fy.shape[0]
Nk = Fu.shape[1]
Nl = Fv.shape[1]
I2 = np.zeros([Nu, Nv])
for iu in range(Nu):
for iv in range(Nv):
for ix in range(Nx):
for iy in range(Ny):
S = 0.
for ik in range(Nk):
for il in range(Nl):
S += Fu[iu,ik]*Fv[iv,il]*Fx[ix,ik]*Fy[iy,il]*P[ix,iy]*B[ik,il]
I2[iu, iv] += S**2.
return I2
if __name__ == '__main__':
Nx = 30
Ny = 40
Nk = 50
Nl = 60
Nu = 70
Nv = 8
Fx = np.random.rand(Nx, Nk)
Fy = np.random.rand(Ny, Nl)
Fu = np.random.rand(Nu, Nk)
Fv = np.random.rand(Nv, Nl)
P = np.random.rand(Nx, Ny)
B = np.random.rand(Nk, Nl)
fjit1 = jit(fun1)
fjit2 = jit(fun2)
# For loop - becomes too slow so commented out
# t = time.time()
# I1 = fun1(Fu, Fv, Fx, Fy, P, B)
# print 'fun1 :', time.time() - t
# JIT compiled for loop - After a certain point beats einsum
t = time.time()
I1jit = fjit1(Fu, Fv, Fx, Fy, P, B)
print 'jit1 :', time.time() - t
# einsum great solution when no squaring is needed
t = time.time()
I1_ = np.einsum('uk, vl, xk, yl, xy, kl->uv', Fu, Fv, Fx, Fy, P, B)
print '1 einsum:', time.time() - t
# For loop - becomes too slow so commented out
# t = time.time()
# I2 = fun2(Fu, Fv, Fx, Fy, P, B)
# print 'fun2 :', time.time() - t
# JIT compiled for loop - After a certain point beats einsum
t = time.time()
I2jit = fjit2(Fu, Fv, Fx, Fy, P, B)
print 'jit2 :', time.time() - t
# Expanded einsum - As the size increases becomes very very slow
# t = time.time()
# I2_ = np.einsum('uk,vl,xk,yl,um,vn,xm,yn,kl,mn,xy->uv', Fu,Fv,Fx,Fy,Fu,Fv,Fx,Fy,B,B,P**2)
# print '2 einsum:', time.time() - t
# Intermediate einsum - As the sizes increase memory can become an issue
t = time.time()
temp = np.einsum('uk, vl, xk, yl, xy, kl->uvxy', Fu, Fv, Fx, Fy, P, B)
I2__ = np.einsum('uvxy->uv', np.square(temp))
print '2 einsum:', time.time() - t
# print 'I1 == I1_ :', np.allclose(I1, I1_)
print 'I1_ == Ijit1_:', np.allclose(I1_, I1jit)
# print 'I2 == I2_ :', np.allclose(I2, I2_)
print 'I2_ == Ijit2_:', np.allclose(I2__, I2jit)
コメント: この回答を自由に編集/改善してください。これを並行させることに関して、誰かが何か提案をしてくれたらうれしいです。
最初に 1 つのインデックスを合計してから、乗算を続行できます。numexpr を最後の乗算およびリダクション演算に投入したバージョンも試しましたが、あまり役に立たないようです。
def fun3(Fu, Fv, Fx, Fy, P, B):
P = P[None, None, ...]
Fu = Fu[:, None, None, None, :]
Fx = Fx[None, None, :, None, :]
Fv = Fv[:, None, None, :]
Fy = Fy[None, :, None, :]
B = B[None, None, ...]
return np.sum((P*np.sum(Fu*Fx*np.sum(Fv*Fy*B, axis=-1)[None, :, None, :, :], axis=-1))**2, axis=(2, 3))
私のコンピューターでははるかに高速です。
jit2 : 7.06 秒
fun3: 0.144 秒
編集:マイナーな改善 - 最初に乗算してから 2 乗します。
Edit2: それぞれが最も得意とすること (numexpr - 乗算、numpy - ドット/テンソルドット、合計) を活用することで、fun3 を 20 回以上改善できます。
def fun4(Fu, Fv, Fx, Fy, P, B):
P = P[None, None, ...]
Fu = Fu[:, None, :]
Fx = Fx[None, ...]
Fy = Fy[:, None, :]
B = B[None, ...]
s = ne.evaluate('Fu*Fx')
r = np.tensordot(Fv, ne.evaluate('Fy*B'), axes=(1, 2))
I = np.tensordot(s, r, axes=(2, 2)).swapaxes(1, 2)
r = ne.evaluate('(P*I)**2')
r = np.sum(r, axis=(2, 3))
return r
fun4: 0.007 秒
さらに、fun8 は (スマート tensordot により) メモリをそれほど消費しないため、より大きな配列を乗算して複数のコアを使用することができます。