たとえば、 is の形状と is の形状の 2 つの ndarrayがtrain_dataset
あります。(10000, 28, 28)
val_dateset
(2000, 28, 28)
反復を使用する以外に、numpy 配列関数を使用して 2 つの ndarray 間の重複を見つける効率的な方法はありますか?
たとえば、 is の形状と is の形状の 2 つの ndarrayがtrain_dataset
あります。(10000, 28, 28)
val_dateset
(2000, 28, 28)
反復を使用する以外に、numpy 配列関数を使用して 2 つの ndarray 間の重複を見つける効率的な方法はありますか?
ここでの Jaime の優れた回答から学んだ 1 つのトリックnp.void
は、入力配列の各行を 1 つの要素として表示するために dtypeを使用することです。これにより、それらを 1D 配列として扱うことができ、それをnp.in1d
または他のset ルーチンの 1 つに渡すことができます。
import numpy as np
def find_overlap(A, B):
if not A.dtype == B.dtype:
raise TypeError("A and B must have the same dtype")
if not A.shape[1:] == B.shape[1:]:
raise ValueError("the shapes of A and B must be identical apart from "
"the row dimension")
# reshape A and B to 2D arrays. force a copy if neccessary in order to
# ensure that they are C-contiguous.
A = np.ascontiguousarray(A.reshape(A.shape[0], -1))
B = np.ascontiguousarray(B.reshape(B.shape[0], -1))
# void type that views each row in A and B as a single item
t = np.dtype((np.void, A.dtype.itemsize * A.shape[1]))
# use in1d to find rows in A that are also in B
return np.in1d(A.view(t), B.view(t))
例えば:
gen = np.random.RandomState(0)
A = gen.randn(1000, 28, 28)
dupe_idx = gen.choice(A.shape[0], size=200, replace=False)
B = A[dupe_idx]
A_in_B = find_overlap(A, B)
print(np.all(np.where(A_in_B)[0] == np.sort(dupe_idx)))
# True
(m, n, ...)
このメソッドは、ブール配列へのブロードキャストを必要としないため、Divakar のメソッドよりもはるかにメモリ効率が高くなります。実際、A
とB
が行優先の場合、コピーはまったく必要ありません。
比較のために、Divakar と BM のソリューションを少し変更しました。
def divakar(A, B):
A.shape = A.shape[0], -1
B.shape = B.shape[0], -1
return (B[:,None] == A).all(axis=(2)).any(0)
def bm(A, B):
t = 'S' + str(A.size // A.shape[0] * A.dtype.itemsize)
ma = np.frombuffer(np.ascontiguousarray(A), t)
mb = np.frombuffer(np.ascontiguousarray(B), t)
return (mb[:, None] == ma).any(0)
In [1]: na = 1000; nb = 200; rowshape = 28, 28
In [2]: %%timeit A = gen.randn(na, *rowshape); idx = gen.choice(na, size=nb, replace=False); B = A[idx]
divakar(A, B)
....:
1 loops, best of 3: 244 ms per loop
In [3]: %%timeit A = gen.randn(na, *rowshape); idx = gen.choice(na, size=nb, replace=False); B = A[idx]
bm(A, B)
....:
100 loops, best of 3: 2.81 ms per loop
In [4]: %%timeit A = gen.randn(na, *rowshape); idx = gen.choice(na, size=nb, replace=False); B = A[idx]
find_overlap(A, B)
....:
100 loops, best of 3: 15 ms per loop
ご覧のとおり、BM のソリューションは小さなnの場合、私のソリューションよりもわずかに高速ですがnp.in1d
、すべての要素の等価性をテストするよりも優れています ( O(n²)の複雑さではなくO(n log n) )。
In [5]: na = 10000; nb = 2000; rowshape = 28, 28
In [6]: %%timeit A = gen.randn(na, *rowshape); idx = gen.choice(na, size=nb, replace=False); B = A[idx]
bm(A, B)
....:
1 loops, best of 3: 271 ms per loop
In [7]: %%timeit A = gen.randn(na, *rowshape); idx = gen.choice(na, size=nb, replace=False); B = A[idx]
find_overlap(A, B)
....:
10 loops, best of 3: 123 ms per loop
Divakar のソリューションは、私のラップトップではこのサイズの配列に対して扱いにくいです。これは、8GB の RAM しかないのに 15GB の中間配列を生成する必要があるためです。
完全なブロードキャストは、ここで 10000*2000*28*28 =150 Mo ブール配列を生成します。
効率化のために、次のことができます。
200 ko 配列のデータをパックします。
from pylab import *
N=10000
a=rand(N,28,28)
b=a[[randint(0,N,N//5)]]
packedtype='S'+ str(a.size//a.shape[0]*a.dtype.itemsize) # 'S6272'
ma=frombuffer(a,packedtype) # ma.shape=10000
mb=frombuffer(b,packedtype) # mb.shape=2000
%timeit a[:,None]==b : 102 s
%timeit ma[:,None]==mb : 800 ms
allclose((a[:,None]==b).all((2,3)),(ma[:,None]==mb)) : True
ここでは、最初の違いで壊れる遅延文字列比較により、メモリが少なくて済みます。
In [31]: %timeit a[:100]==b[:100]
10000 loops, best of 3: 175 µs per loop
In [32]: %timeit a[:100]==a[:100]
10000 loops, best of 3: 133 µs per loop
In [34]: %timeit ma[:100]==mb[:100]
100000 loops, best of 3: 7.55 µs per loop
In [35]: %timeit ma[:100]==ma[:100]
10000 loops, best of 3: 156 µs per loop
ソリューションはここで与えられます(ma[:,None]==mb).nonzero().
完全な比較に対してin1d
、(Na+Nb) ln(Na+Nb)
複雑さの
ために を使用します。Na*Nb
%timeit in1d(ma,mb).nonzero() : 590ms
ここでは大きな改善ではありませんが、漸近的に改善されています。
def overlap(a,b):
"""
returns a boolean index array for input array b representing
elements in b that are also found in a
"""
a.repeat(b.shape[0],axis=0)
b.repeat(a.shape[0],axis=0)
c = aa == bb
c = c[::a.shape[0]]
return c.all(axis=1)[:,0]
返されたインデックス配列を使用してインデックスを作成b
し、次の場所にもある要素を抽出できます。a
b[overlap(a,b)]
numpy
簡単にするために、この例からすべてをインポートしたと仮定します。
from numpy import *
たとえば、2 つの ndarray が与えられた場合、
a = arange(4*2*2).reshape(4,2,2)
b = arange(3*2*2).reshape(3,2,2)
同じ形になるようにa
andを繰り返しますb
aa = a.repeat(b.shape[0],axis=0)
bb = b.repeat(a.shape[0],axis=0)
aa
次に、との要素を単純に比較できます。bb
c = aa == bb
最後に、 の要素の4 番目ごと、または実際には 1番目ごとの要素を調べて、b
にもある要素のインデックスを取得します。a
shape(a)[0]
c
cc == c[::a.shape[0]]
最後に、サブ配列内のすべての要素が存在する要素のみを含むインデックス配列を抽出します。True
c.all(axis=1)[:,0]
この例では、
array([True, True, True], dtype=bool)
確認するには、の最初の要素を変更しますb
b[0] = array([[50,60],[70,80]])
そして、私たちは得る
array([False, True, True], dtype=bool)