9

たとえば、任意の NxM 行列があります。

1 2 3 4 5 6
7 8 9 0 1 2
3 4 5 6 7 8
9 0 1 2 3 4

このマトリックスのすべての 3x3 サブマトリックスのリストを取得したい:

1 2 3       2 3 4               0 1 2
7 8 9   ;   8 9 0   ;  ...  ;   6 7 8
3 4 5       4 5 6               2 3 4

ネストされた 2 つのループでこれを行うことができます。

rows, cols = input_matrix.shape
patches = []
for row in np.arange(0, rows - 3):
    for col in np.arange(0, cols - 3):
        patches.append(input_matrix[row:row+3, col:col+3])

しかし、大きな入力行列の場合、これは遅くなります。numpyでこれをより速く行う方法はありますか?

を見てきましたがnp.split、重複しない部分行列が得られますが、重複に関係なく、可能なすべての部分行列が必要です。

4

1 に答える 1

12

ウィンドウ ビューが必要な場合:

from numpy.lib.stride_tricks import as_strided

arr = np.arange(1, 25).reshape(4, 6) % 10
sub_shape = (3, 3)
view_shape = tuple(np.subtract(arr.shape, sub_shape) + 1) + sub_shape
arr_view = as_strided(arr, view_shape, arr.strides * 2
arr_view = arr_view.reshape((-1,) + sub_shape)

>>> arr_view
array([[[[1, 2, 3],
         [7, 8, 9],
         [3, 4, 5]],

        [[2, 3, 4],
         [8, 9, 0],
         [4, 5, 6]],

        ...

        [[9, 0, 1],
         [5, 6, 7],
         [1, 2, 3]],

        [[0, 1, 2],
         [6, 7, 8],
         [2, 3, 4]]]])

このようにすることの良い点は、データをコピーするのではなく、元の配列のデータに別の方法でアクセスすることです。大きな配列の場合、これによりメモリが大幅に節約されます。

于 2013-10-16T22:28:46.440 に答える