一般的なルールは、ジョブに正しいアルゴリズムを使用することです。これは、畳み込みカーネルがデータに比べて短い場合を除き、FFT ベースの畳み込みです (短いとは、おおよそ log2(n) 未満を意味します。ここで、n はデータの長さです)。 .
あなたの場合、2 つの等しいサイズのデータセットを畳み込んでいるので、おそらく FFT ベースの畳み込みを検討する必要があります。
明らかに、scipy.signal.fftconvolve
この点でタッチが不足しています。より高速な FFT アルゴリズムを使用すると、独自の畳み込みルーチンをローリングすることで、はるかに優れた結果を得ることができます (fftconvolve が変換サイズを 2 の累乗に強制することは役に立ちません。そうしないと、モンキー パッチが適用される可能性があります)。
次のコードは、 FFTWのラッパーであるpyfftwを使用し、カスタム畳み込みクラスを作成します。CustomFFTConvolution
class CustomFFTConvolution(object):
def __init__(self, A, B, threads=1):
shape = (np.array(A.shape) + np.array(B.shape))-1
if np.iscomplexobj(A) and np.iscomplexobj(B):
self.fft_A_obj = pyfftw.builders.fftn(
A, s=shape, threads=threads)
self.fft_B_obj = pyfftw.builders.fftn(
B, s=shape, threads=threads)
self.ifft_obj = pyfftw.builders.ifftn(
self.fft_A_obj.get_output_array(), s=shape,
threads=threads)
else:
self.fft_A_obj = pyfftw.builders.rfftn(
A, s=shape, threads=threads)
self.fft_B_obj = pyfftw.builders.rfftn(
B, s=shape, threads=threads)
self.ifft_obj = pyfftw.builders.irfftn(
self.fft_A_obj.get_output_array(), s=shape,
threads=threads)
def __call__(self, A, B):
fft_padded_A = self.fft_A_obj(A)
fft_padded_B = self.fft_B_obj(B)
return self.ifft_obj(fft_padded_A * fft_padded_B)
これは次のように使用されます。
custom_fft_conv = CustomFFTConvolution(A, B)
C = custom_fft_conv(A, B) # This can contain different values to during construction
threads
クラスを構築するときにオプションの引数を使用します。クラスを作成する目的は、事前に変換を計画する FFTW の機能を活用することです。
以下の完全なデモ コードは、タイミングなどに対する @Kelsey の回答を単純に拡張したものです。
スピードアップは、numba ソリューションとバニラの fftconvolve ソリューションの両方でかなりのものです。n = 33 の場合、両方よりも約 40 ~ 45 倍高速です。
from timeit import Timer
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import fftconvolve
from numba import jit, double
import pyfftw
# Original code
def custom_convolution(A, B):
dimA = A.shape[0]
dimB = B.shape[0]
dimC = dimA + dimB
C = np.zeros((dimC, dimC, dimC))
for x1 in range(dimA):
for x2 in range(dimB):
for y1 in range(dimA):
for y2 in range(dimB):
for z1 in range(dimA):
for z2 in range(dimB):
x = x1 + x2
y = y1 + y2
z = z1 + z2
C[x, y, z] += A[x1, y1, z1] * B[x2, y2, z2]
return C
# Numba'ing the function with the JIT compiler
numba_convolution = jit(double[:, :, :](double[:, :, :],
double[:, :, :]))(custom_convolution)
def fft_convolution(A, B):
return fftconvolve(A, B, mode='full')
class CustomFFTConvolution(object):
def __init__(self, A, B, threads=1):
shape = (np.array(A.shape) + np.array(B.shape))-1
if np.iscomplexobj(A) and np.iscomplexobj(B):
self.fft_A_obj = pyfftw.builders.fftn(
A, s=shape, threads=threads)
self.fft_B_obj = pyfftw.builders.fftn(
B, s=shape, threads=threads)
self.ifft_obj = pyfftw.builders.ifftn(
self.fft_A_obj.get_output_array(), s=shape,
threads=threads)
else:
self.fft_A_obj = pyfftw.builders.rfftn(
A, s=shape, threads=threads)
self.fft_B_obj = pyfftw.builders.rfftn(
B, s=shape, threads=threads)
self.ifft_obj = pyfftw.builders.irfftn(
self.fft_A_obj.get_output_array(), s=shape,
threads=threads)
def __call__(self, A, B):
fft_padded_A = self.fft_A_obj(A)
fft_padded_B = self.fft_B_obj(B)
return self.ifft_obj(fft_padded_A * fft_padded_B)
def run_test():
reps = 10
nt, ft, cft, cft2 = [], [], [], []
x = range(2, 34)
for N in x:
print N
A = np.random.rand(N, N, N)
B = np.random.rand(N, N, N)
custom_fft_conv = CustomFFTConvolution(A, B)
custom_fft_conv_nthreads = CustomFFTConvolution(A, B, threads=2)
C1 = numba_convolution(A, B)
C2 = fft_convolution(A, B)
C3 = custom_fft_conv(A, B)
C4 = custom_fft_conv_nthreads(A, B)
assert np.allclose(C1[:-1, :-1, :-1], C2)
assert np.allclose(C1[:-1, :-1, :-1], C3)
assert np.allclose(C1[:-1, :-1, :-1], C4)
t = Timer(lambda: numba_convolution(A, B))
nt.append(t.timeit(number=reps))
t = Timer(lambda: fft_convolution(A, B))
ft.append(t.timeit(number=reps))
t = Timer(lambda: custom_fft_conv(A, B))
cft.append(t.timeit(number=reps))
t = Timer(lambda: custom_fft_conv_nthreads(A, B))
cft2.append(t.timeit(number=reps))
plt.plot(x, ft, label='scipy.signal.fftconvolve')
plt.plot(x, nt, label='custom numba convolve')
plt.plot(x, cft, label='custom pyfftw convolve')
plt.plot(x, cft2, label='custom pyfftw convolve with threading')
plt.legend()
plt.show()
if __name__ == '__main__':
run_test()
編集: 最近の scipy は、常に 2 の累乗の長さにパディングするわけではないため、出力が pyFFTW ケースに近くなります。