9

numpy配列を「せん断」したいのですが。「せん断」という用語を正しく使用しているかどうかはわかりません。せん断とは、次のような意味です。

最初の列を0桁
シフトします。2番目の列を1桁
シフトします。3番目の列を2桁シフトし
ます。

したがって、この配列:

array([[11, 12, 13],
       [17, 18, 19],
       [35, 36, 37]])

次のいずれかの配列になります。

array([[11, 36, 19],
       [17, 12, 37],
       [35, 18, 13]])

またはこの配列のようなもの:

array([[11,  0,  0],
       [17, 12,  0],
       [35, 18, 13]])

エッジの処理方法によって異なります。私はエッジの振る舞いにあまりこだわっていません。

これを行う関数での私の試みは次のとおりです。

import numpy

def shear(a, strength=1, shift_axis=0, increase_axis=1, edges='clip'):
    strength = int(strength)
    shift_axis = int(shift_axis)
    increase_axis = int(increase_axis)
    if shift_axis == increase_axis:
        raise UserWarning("Shear can't shift in the direction it increases")
    temp = numpy.zeros(a.shape, dtype=int)
    indices = []
    for d, num in enumerate(a.shape):
        coords = numpy.arange(num)
        shape = [1] * len(a.shape)
        shape[d] = num
        coords = coords.reshape(shape) + temp
        indices.append(coords)
    indices[shift_axis] -= strength * indices[increase_axis]
    if edges == 'clip':
        indices[shift_axis][indices[shift_axis] < 0] = -1
        indices[shift_axis][indices[shift_axis] >= a.shape[shift_axis]] = -1
        res = a[indices]
        res[indices[shift_axis] == -1] = 0
    elif edges == 'roll':
        indices[shift_axis] %= a.shape[shift_axis]
        res = a[indices]
    return res

if __name__ == '__main__':
    a = numpy.random.random((3,4))
    print a
    print shear(a)

うまくいくようです。そうでない場合は教えてください!

また、不格好でエレガントではないようです。これを行う組み込みのnumpy/scipy関数を見落としていますか?numpyでこれを行うためのよりクリーン/より良い/より効率的な方法はありますか?私は車輪の再発明をしていますか?

編集:
これが2Dの場合だけでなく、N次元配列で機能する場合のボーナスポイント。

この関数は、データ処理で何度も繰り返すループの中心にあるため、実際に最適化する価値があると思います。

2番目の編集:私はついにいくつかのベンチマークを行いました。ループにもかかわらず、numpy.rollが進むべき道のようです。ありがとう、tom10とSven Marnach!

ベンチマークコード:(Windowsで実行し、Linuxではtime.clockを使用しないでください)

import time, numpy

def shear_1(a, strength=1, shift_axis=0, increase_axis=1, edges='roll'):
    strength = int(strength)
    shift_axis = int(shift_axis)
    increase_axis = int(increase_axis)
    if shift_axis == increase_axis:
        raise UserWarning("Shear can't shift in the direction it increases")
    temp = numpy.zeros(a.shape, dtype=int)
    indices = []
    for d, num in enumerate(a.shape):
        coords = numpy.arange(num)
        shape = [1] * len(a.shape)
        shape[d] = num
        coords = coords.reshape(shape) + temp
        indices.append(coords)
    indices[shift_axis] -= strength * indices[increase_axis]
    if edges == 'clip':
        indices[shift_axis][indices[shift_axis] < 0] = -1
        indices[shift_axis][indices[shift_axis] >= a.shape[shift_axis]] = -1
        res = a[indices]
        res[indices[shift_axis] == -1] = 0
    elif edges == 'roll':
        indices[shift_axis] %= a.shape[shift_axis]
        res = a[indices]
    return res

def shear_2(a, strength=1, shift_axis=0, increase_axis=1, edges='roll'):
    indices = numpy.indices(a.shape)
    indices[shift_axis] -= strength * indices[increase_axis]
    indices[shift_axis] %= a.shape[shift_axis]
    res = a[tuple(indices)]
    if edges == 'clip':
        res[indices[shift_axis] < 0] = 0
        res[indices[shift_axis] >= a.shape[shift_axis]] = 0
    return res

def shear_3(a, strength=1, shift_axis=0, increase_axis=1):
    if shift_axis > increase_axis:
        shift_axis -= 1
    res = numpy.empty_like(a)
    index = numpy.index_exp[:] * increase_axis
    roll = numpy.roll
    for i in range(0, a.shape[increase_axis]):
        index_i = index + (i,)
        res[index_i] = roll(a[index_i], i * strength, shift_axis)
    return res

numpy.random.seed(0)
for a in (
    numpy.random.random((3, 3, 3, 3)),
    numpy.random.random((50, 50, 50, 50)),
    numpy.random.random((300, 300, 10, 10)),
    ):
    print 'Array dimensions:', a.shape
    for sa, ia in ((0, 1), (1, 0), (2, 3), (0, 3)):
        print 'Shift axis:', sa
        print 'Increase axis:', ia
        ref = shear_1(a, shift_axis=sa, increase_axis=ia)
        for shear, label in ((shear_1, '1'), (shear_2, '2'), (shear_3, '3')):
            start = time.clock()
            b = shear(a, shift_axis=sa, increase_axis=ia)
            end = time.clock()
            print label + ': %0.6f seconds'%(end-start)
            if (b - ref).max() > 1e-9:
                print "Something's wrong."
        print
4

5 に答える 5

8

tom10's answerのアプローチは、任意の次元に拡張できます。

def shear3(a, strength=1, shift_axis=0, increase_axis=1):
    if shift_axis > increase_axis:
        shift_axis -= 1
    res = numpy.empty_like(a)
    index = numpy.index_exp[:] * increase_axis
    roll = numpy.roll
    for i in range(0, a.shape[increase_axis]):
        index_i = index + (i,)
        res[index_i] = roll(a[index_i], -i * strength, shift_axis)
    return res
于 2011-02-15T16:44:40.177 に答える
8

numpy rollがこれを行います。たとえば、元の配列が x の場合

for i in range(x.shape[1]):
    x[:,i] = np.roll(x[:,i], i)

生産する

[[11 36 19]
 [17 12 37]
 [35 18 13]]
于 2011-02-15T00:02:36.050 に答える
7

これは、Joe Kington によるこの回答で説明されているトリックを使用して実行できます。

from numpy.lib.stride_tricks import as_strided
a = numpy.array([[11, 12, 13],
                 [17, 18, 19],
                 [35, 36, 37]])
shift_axis = 0
increase_axis = 1
b = numpy.vstack((a, a))
strides = list(b.strides)
strides[increase_axis] -= strides[shift_axis]
strides = (b.strides[0], b.strides[1] - b.strides[0])
as_strided(b, shape=b.shape, strides=strides)[a.shape[0]:]
# array([[11, 36, 19],
#        [17, 12, 37],
#        [35, 18, 13]])

「ロール」の代わりに「クリップ」を取得するには、使用します

b = numpy.vstack((numpy.zeros(a.shape, int), a))

これは、Python ループをまったく使用しないため、おそらく最も効率的な方法です。

于 2011-02-14T23:54:26.237 に答える
2

これは、独自のアプローチのクリーンアップされたバージョンです。

def shear2(a, strength=1, shift_axis=0, increase_axis=1, edges='clip'):
    indices = numpy.indices(a.shape)
    indices[shift_axis] -= strength * indices[increase_axis]
    indices[shift_axis] %= a.shape[shift_axis]
    res = a[tuple(indices)]
    if edges == 'clip':
        res[indices[shift_axis] < 0] = 0
        res[indices[shift_axis] >= a.shape[shift_axis]] = 0
    return res

主な違いは、これnumpy.indices()の独自のバージョンをロールする代わりに使用することです。

于 2011-02-15T16:25:17.610 に答える
0
r = lambda l, n: l[n:]+l[:n]

transpose(map(r, transpose(a), range(0, len(a)))

おもう。おそらく、この疑似コードは実際の Python よりも多くを考慮する必要があります。基本的に配列を転置し、その上に一般的な回転関数をマップして回転を行い、転置して戻します。

于 2011-02-14T23:48:38.723 に答える