0

私は maxpool を知っており、pytorch で使用しています。拡張されたパラメーターを持つ Maxpool は以下のとおりです: ここで、中心要素なしで maxpool を実行する特別な形式の maxpool が必要です。つまり、カーネル サイズは 3X3 ですが、中央の要素を削除する必要があります。したがって、結果は残りの 8 つの要素から得られるはずです。 今、私は for ループを使用しています。numpy や pytorch などを使用してこれを加速するにはどうすればよいですか?maxpool、拡張

import numpy as np
from timeit import default_timer as timer


def MaxPool_special(kh, kw, arr):
    """
    to do maxpool without central element
    :param kh:  should always be 3
    :param kw:  should always be 3
    :param arr:  the input array
    :return: arr_res: output array
    """
    h, w = arr.shape[:2]

    arr_res = np.array([[maxpool_ij(i, j, arr, kh, kw) for j in range(w)] for i in range(h)])

    return arr_res


def maxpool_ij(i, j, arr, dh, dw):
    """
    find the maximum value around point(i,j) with dilated parameter
    """
    Mmax = None
    imin, imax = i - dh, i + dh
    jmin, jmax = j - dw, j + dw
    if imin >= 0 and imax < h and jmin >= 0 and jmax < w:
        Mmax = np.max(
            arr[[imin, imin, imin, i, i, imax, imax, imax], [jmin, j, jmax, jmin, jmax, jmin, j, jmax]])
    elif imin < 0 and jmin < 0:
        Mmax = np.max(arr[[i, imax, imax], [jmax, j, jmax]])
    elif imin < 0 and jmax >= w:
        Mmax = np.max(arr[[i, imax, imax], [jmin, jmin, j]])
    elif imax >= h and jmin < 0:
        Mmax = np.max(arr[[imin, imin, i], [j, jmax, jmax]])
    elif imax >= h and jmax >= w:
        Mmax = np.max(arr[[imin, imin, i], [jmin, j, jmin]])
    elif imin < 0:
        Mmax = np.max(arr[[i, i, imax, imax, imax], [jmin, jmax, jmin, j, jmax]])
    elif imax >= h:
        Mmax = np.max(arr[[imin, imin, imin, i, i], [jmin, j, jmax, jmin, jmax]])
    elif jmin < 0:
        Mmax = np.max(arr[[imin, imin, i, imax, imax], [j, jmax, jmax, j, jmax]])
    elif jmax >= w:
        Mmax = np.max(arr[[imin, imin, i, imax, imax], [jmin, j, jmin, jmin, j]])

    assert Mmax, f'Wrong logic above!{imin, imax, jmin, jmax, h, w}'

    return Mmax

#  generate input array
h, w = 400, 500
arr = np.random.randint(0, 256, h * w).reshape(h, w)

tic = timer()
grayPool = MaxPool_special(3, 3, arr)
toc = timer()
print(f'time cost for for-loops: {toc - tic}')

このコードを高速化するのを手伝ってください、ありがとう!

4

1 に答える 1