1

Here's my array:

a = [[0.,0.,0.1,0.2], [0.,0.3,0.4,0.3], [0.,0.,0.1,0.]]

I would like to do a where clause which will return the indices of the elements in 'a' where the sum of the values for this element is equal to 1. Something like : where(sum(a) == 1)

Can someone guide me ?

Thanks.

4

2 に答える 2

8
In [1]: import numpy as np

In [2]: a = np.array([[0.,0.,0.1,0.2], [0.,0.3,0.4,0.3], [0.,0.,0.1,0.]])

In [3]: a
Out[3]:
array([[ 0. ,  0. ,  0.1,  0.2],
       [ 0. ,  0.3,  0.4,  0.3],
       [ 0. ,  0. ,  0.1,  0. ]])

In [4]: np.where(np.sum(a,axis=1) == 1)
Out[4]: (array([1]),)

したがって、2 行目 (インデックス == 1) の合計は 1.0 です。np.sum(a, axis=1)行全体の合計をとります。これは、リストの元のリストの要素に相当します。明示的な軸を指定しないと、numpy は配列のすべての要素の合計を取ります。python ビルトインsumnp.sum. from numpy import *これは、物事を明確にしておくべきではない正当な理由です。

アップデート:

@Jaimeが示唆したように、等値との比較は安全ではありません。理想的np.allcloseにはaxisオプションがあるはずですが、ありません。次を使用してこれを再作成できます。

np.where(np.abs(np.sum(a,1) - 1.0) <= 1E-5)

詳細については、ドキュメントを参照しnp.allcloseてください。

于 2013-07-27T15:51:00.420 に答える
2

enumerateを使用して、内包表記をリストします。

>>> a = [[0.,0.,0.1,0.2], [0.,0.3,0.4,0.3], [0.,0.,0.1,0.]]
>>> [i for i, xs in enumerate(a) if sum(xs) == 1]
[1]
于 2013-07-27T15:44:07.947 に答える