4

2 つの numpy ブール配列 (aおよびb) があります。それらの要素のいくつが等しいかを見つける必要があります。現在、私はそうしていますlen(a) - (a ^ b).sum()が、私が理解しているように、xor操作はまったく新しいnumpy配列を作成します。不要な一時配列を作成せずに、この望ましい動作を効率的に実装するにはどうすればよいですか?

numexpr を使用してみましたが、うまく動作しません。True が 1 で False が 0 であるという概念をサポートしていないため、 を使用する必要ne.evaluate("sum(where(a==b, 1, 0))")があり、約 2 倍の時間がかかります。

編集:これらの配列の 1 つが実際には異なるサイズの別の配列へのビューであり、両方の配列が不変であると見なされるべきであることに言及するのを忘れていました。どちらの配列も 2 次元で、サイズは 25x40 前後になる傾向があります。

はい、これは私のプログラムのボトルネックであり、最適化する価値があります。

4

4 に答える 4

2

私のマシンでは、これはより高速です:

(a == b).sum()

余分なストレージを使用したくない場合は、numba を使用することをお勧めします。あまり詳しくありませんが、これはうまくいくようです。Cython にブール値の NumPy 配列を取得させるのに問題が発生しました。

from numba import autojit
def pysumeq(a, b):
    tot = 0
    for i in xrange(a.shape[0]):
        for j in xrange(a.shape[1]):
            if a[i,j] == b[i,j]:
                tot += 1
    return tot
# make numba version
nbsumeq = autojit(pysumeq)
A = (rand(10,10)<.5)
B = (rand(10,10)<.5)
# do a simple dry run to get it to compile
# for this specific use case
nbsumeq(A, B)

numba をお持ちでない場合は、@ user2357112 の回答を使用することをお勧めします

編集: Cython バージョンが動作するようになりました.pyx。ファイルは次のとおりです。私はこれで行きます。

from numpy cimport ndarray as ar
cimport numpy as np
cimport cython

@cython.boundscheck(False)
@cython.wraparound(False)
def cysumeq(ar[np.uint8_t,ndim=2,cast=True] a, ar[np.uint8_t,ndim=2,cast=True] b):
    cdef int i, j, h=a.shape[0], w=a.shape[1], tot=0
    for i in xrange(h):
        for j in xrange(w):
            if a[i,j] == b[i,j]:
                tot += 1
    return tot
于 2013-07-31T04:08:12.290 に答える