2

私は速度を念頭に置いて数値アルゴリズムを書いています。scipy/numpy (scipy.linalg.expm2、scipy.linalg.expm) で 2 つの行列指数関数に遭遇しました。ただし、事前に対角であることがわかっている行列があります。これらの scipy 関数は、実行前に行列が対角かどうかをチェックしますか? 明らかに、べき乗アルゴリズムは対角行列に対してはるかに高速になる可能性があります。これらがそれでスマートなことをしていることを確認したいだけです-そうでない場合、それを行う簡単な方法はありますか?

4

3 に答える 3

3

A が対角であることがわかっていて、k 乗が必要な場合:

def dpow(a, k):
    return np.diag(np.diag(a) ** k)

行列が対角かどうかを確認します。

def isdiag(a):
    return np.all(a == np.diag(np.diag(a)))

それで :

def pow(a, k):
    if isdiag(a):
        return dpow(a, k)
    else:
        return np.asmatrix(a) ** k

同様に、指数 (一連の pow の展開から数学的に取得できます) の場合、次のことができます。

def dexp(a, k):
    return np.diag(np.exp(np.diag(a)))

def exp(a, k):
    if isdiag(a):
        return dexp(a, k)
    else:
        #use scipy.linalg.expm2 or whatever
于 2013-04-18T12:20:58.683 に答える
1

HYRY と同じことを高速化するのに役立つツールを開発しましたが、インプレースで実行することにより、次のようになります。

def diagonal(array):
    """ Return a **view** of the diagonal elements of 'array' """
    from numpy.lib.stride_tricks import as_strided
    return as_strided(array,shape=(min(array.shape),),strides=(sum(array.strides),))

# generate a random diagonal array
d  = np.diag(np.random.random(4000))

# in-place exponent of the diagonal elements
ddiag = diagonal(d)
ddiag[:] = np.exp(ddiag)

# timeit comparison with HYRY's method
%timeit -n10 np.diag(np.exp(np.diag(d)))   
    # out> 10 loops, best of 3: 52.1 ms per loop
%timeit -n10 ddiag = diagonal(d); ddiag[:] = np.exp(ddiag)
    # out> 10 loops, best of 3: 108 µs per loop

今、

  • HYRY の方法は、対角線の長さに関して (おそらく新しい配列メモリ割り当てのため) 2 次であるため、行列の次元が小さい場合、差はそれほど大きくない可能性があります。

  • インプレース計算で大丈夫である必要があります

  • 最後に、非対角要素は 0 なので、それらの指数は 1 になるはずですよね? どちらの方法でも、非対角は 0 です。

その最後の部分で、すべての非対角要素を 1 にしたい場合は、次のようにできます。

d2 = np.ones_like(d); 
diagonal(d2)[:] = np.exp(np.diag(d))

print (d2==np.exp(d)).all()  # True

しかし、これは配列サイズに対して線形であるため、対角線の長さに対して二次です。timeit は、4000x4000 アレイで約 90 ミリ秒、2000x2000 アレイで 22.3 ミリ秒になります。

最後に、少し高速化するためにインプレースで実行することもできます。

diag = np.diag(d)
d[:]=1
diagonal(d)[:] = np.exp(diag)

Timeit は、4000^2 配列で 66.1ms、2000^2 で 16.8ms を示します。

于 2013-04-19T12:18:44.673 に答える