11

2次元のソートされた配列の各行の最初のゼロ以外の値を見つけるための最速の方法を見つけようとしています。技術的には、配列内の値は0と1のみであり、「ソート」されます。

たとえば、配列は次のようになります。

v =

0 0 0 1 1 1 1 
0 0 0 1 1 1 1 
0 0 0 0 1 1 1 
0 0 0 0 0 0 1 
0 0 0 0 0 0 1 
0 0 0 0 0 0 1 
0 0 0 0 0 0 0

argmax関数を使用できます

argmax(v, axis=1))

ゼロから1に変化するタイミングを見つけるために、しかしこれは各行に沿って徹底的な検索を行うと思います。私のアレイは適度なサイズ(〜2000x2000)になります。argmaxは、forループ内の各行に対して検索ソートされたアプローチを実行するだけで、それでも優れたパフォーマンスを発揮しますか、それともより良い代替手段がありますか?

また、配列は常に、行の1の最初の位置が常に> =その上の行の1の最初の位置になるようになります(ただし、最後の数行に1が存在することは保証されません)。 )。これをforループと、前の行の最初の1の位置に等しい各行の「開始インデックス値」で利用できますが、numpyargmax関数はPythonで記述されたループよりも優れていると考えるのは正しいです。 。

代替案をベンチマークするだけですが、配列のエッジの長さはかなり変わる可能性があります(250から10,000)。

4

2 に答える 2

7

np.whereを使用するのはかなり速いです:

>>> a
array([[0, 0, 0, 1, 1, 1, 1],
       [0, 0, 0, 1, 1, 1, 1],
       [0, 0, 0, 0, 1, 1, 1],
       [0, 0, 0, 0, 0, 0, 1],
       [0, 0, 0, 0, 0, 0, 1],
       [0, 0, 0, 0, 0, 0, 1],
       [0, 0, 0, 0, 0, 0, 0]])
>>> np.where(a>0)
(array([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 3, 4, 5]), array([3, 4, 5, 6, 3, 4, 5, 6, 4, 5, 6, 6, 6, 6]))

これにより、0より大きい値の座標を持つタプルが配信されます。

np.whereを使用して、各サブ配列をテストすることもできます。

def first_true1(a):
    """ return a dict of row: index with value in row > 0 """
    di={}
    for i in range(len(a)):
        idx=np.where(a[i]>0)
        try:
            di[i]=idx[0][0]
        except IndexError:
            di[i]=None    

    return di       

プリント:

{0: 3, 1: 3, 2: 4, 3: 6, 4: 6, 5: 6, 6: None}

つまり、行0:インデックス3> 0; 行4:インデックス4> 0; 行6:0より大きいインデックスはありません

ご想像のとおり、argmaxの方が速い場合があります。

def first_true2():
    di={}
    for i in range(len(a)):
        idx=np.argmax(a[i])
        if idx>0:
            di[i]=idx
        else:
            di[i]=None    

    return di       
    # same dict is returned...

Noneすべてのnaughtsのfor行がないというロジックに対処できる場合、これはさらに高速です。

def first_true3():
    di={}
    for i, j in zip(*np.where(a>0)):
        if i in di:
            continue
        else:
            di[i]=j

    return di      

そして、これがargmaxでaxisを使用するバージョンです(コメントで示唆されているように):

def first_true4():
    di={}
    for i, ele in enumerate(np.argmax(a,axis=1)):
        if ele==0 and a[i][0]==0:
            di[i]=None
        else:
            di[i]=ele

    return di          

(例の配列での)速度比較のために、私はこれを取得します:

            rate/sec usec/pass first_true1 first_true2 first_true3 first_true4
first_true1   23,818    41.986          --      -34.5%      -63.1%      -70.0%
first_true2   36,377    27.490       52.7%          --      -43.6%      -54.1%
first_true3   64,528    15.497      170.9%       77.4%          --      -18.6%
first_true4   79,287    12.612      232.9%      118.0%       22.9%          --

これを2000X2000 np配列にスケーリングすると、次のようになります。

            rate/sec  usec/pass first_true3 first_true1 first_true2 first_true4
first_true3        3 354380.107          --       -0.3%      -74.7%      -87.8%
first_true1        3 353327.084        0.3%          --      -74.6%      -87.7%
first_true2       11  89754.200      294.8%      293.7%          --      -51.7%
first_true4       23  43306.494      718.3%      715.9%      107.3%          --
于 2012-07-31T04:02:39.180 に答える
4

argmax()は経営幹部レベルのループを使用します。Pythonループよりもはるかに高速なので、Pythonでスマートなアルゴリズムを記述しても、argmax()に勝るものはないと思います。Cythonを使用して高速化できます。

@cython.boundscheck(False)
@cython.wraparound(False) 
def find(int[:,:] a):
    cdef int h = a.shape[0]
    cdef int w = a.shape[1]
    cdef int i, j
    cdef int idx = 0
    cdef list r = []
    for i in range(h):
        for j in range(idx, w):
            if a[i, j] == 1:
                idx = j
                r.append(idx)
                break
        else:
            r.append(-1)
    return r

私のPCfor2000x2000マトリックスでは、100us対3msです。

于 2012-07-31T02:15:59.487 に答える