3

サイズ(149797, 64)の 2D UINT8 numpy 配列があります。各要素は 0 または 1 です。各行のこれらのバイナリ値をUINT64値にパックして、結果として形状 149797 の UINT64 配列を取得したいと考えています。numpy bitpack 関数を使用して次のコードを試しました。

test = np.random.randint(0, 2, (149797, 64),dtype=np.uint8)
col_pack=np.packbits(test.reshape(-1, 8, 8)[:, ::-1]).view(np.uint64)

packbits 関数の実行には約10 ミリ秒かかります。この配列自体の単純な 再形成には約7ミリ秒かかるようです。また、シフト操作を使用して2次元のnumpy配列を反復して、同じ結果を達成しようとしました。しかし、速度の改善はありませんでした。

最後に、CPU 用のnumbaを使用してコンパイルしたいとも考えています。

@njit
def shifting(bitlist):
    x=np.zeros(149797,dtype=np.uint64)  #54
    rows,cols=bitlist.shape
    for i in range(0,rows):             #56
      out=0
      for bit in range(0,cols):
         out = (out << 1) | bitlist[i][bit] # If i comment out bitlist, time=190 microsec
      x[i]=np.uint64(out)  # Reduces time to microseconds if line is commented in njit
    return x

njitを使用すると、約6 ミリ秒かかります。

これがパラレルnjitバージョンです

@njit(parallel=True)
def shifting(bitlist): 
    rows,cols=149797,64
    out=0
    z=np.zeros(rows,dtype=np.uint64)
    for i in prange(rows):
      for bit in range(cols):
         z[i] = (z[i] * 2) + bitlist[i,bit] # Time becomes 100 micro if i use 'out' instead of 'z[i] array'

    return z

3.24msの実行時間( google colabデュアル コア 2.2Ghz) でわずかに優れています。

この変換をさらに高速化するにはどうすればよいでしょうか? スピードアップを達成するために、ベクトル化 (または並列化)、ビット配列などを使用する余地はありますか?

参照: uint16 配列への numpy packbits パック

12コアのマシン(Intel(R) Xeon(R) CPU E5-1650 v2 @ 3.50GHz) では、

Pauls 法: 1595.0マイクロ秒 (マルチコアを使用していないと思われます)

Numba コード: 146.0マイクロ秒 (前述の parallel-numba)

つまり、約10倍のスピードアップ!!!

4

2 に答える 2

3

byteswap再形成などの代わりに使用することで、かなりのスピードアップを得ることができます。

test = np.random.randint(0, 2, (149797, 64),dtype=np.uint8)

np.packbits(test.reshape(-1, 8, 8)[:, ::-1]).view(np.uint64)
# array([ 1079982015491401631,   246233595099746297, 16216705265283876830,
#        ...,  1943876987915462704, 14189483758685514703,
       12753669247696755125], dtype=uint64)
np.packbits(test).view(np.uint64).byteswap()
# array([ 1079982015491401631,   246233595099746297, 16216705265283876830,
#        ...,  1943876987915462704, 14189483758685514703,
       12753669247696755125], dtype=uint64)

timeit(lambda:np.packbits(test.reshape(-1, 8, 8)[:, ::-1]).view(np.uint64),number=100)
# 1.1054180909413844

timeit(lambda:np.packbits(test).view(np.uint64).byteswap(),number=100)
# 0.18370431219227612
于 2020-02-07T21:49:24.630 に答える