100

I am playing with numpy and digging through documentation and I have come across some magic. Namely I am talking about numpy.where():

>>> x = np.arange(9.).reshape(3, 3)
>>> np.where( x > 5 )
(array([2, 2, 2]), array([0, 1, 2]))

How do they achieve internally that you are able to pass something like x > 5 into a method? I guess it has something to do with __gt__ but I am looking for a detailed explanation.

4

3 に答える 3

77

x > 5 のようなものをメソッドに渡すことができるようにするにはどうすればよいでしょうか?

短い答えは、そうではないということです。

numpy 配列に対するあらゆる種類の論理演算は、ブール配列を返します。(つまり__gt____lt__などはすべて、指定された条件が true の場合にブール配列を返します)。

例えば

x = np.arange(9).reshape(3,3)
print x > 5

収量:

array([[False, False, False],
       [False, False, False],
       [ True,  True,  True]], dtype=bool)

これは、numpy 配列if x > 5:の場合に ValueError を発生させるのと同じ理由です。xこれは、単一の値ではなく、True/False 値の配列です。

さらに、numpy 配列はブール配列でインデックス付けできます。たとえば、この場合はx[x>5]yieldです。[6 7 8]

正直なところ、実際に必要になることはかなりまれですnumpy.whereが、ブール配列が であるインデックスを返すだけですTrue。通常、単純なブール値のインデックス付けで必要なことを行うことができます。

于 2011-04-12T22:48:27.520 に答える
25

古い回答 です。ややこしいです。それはあなたの声明が真である場所のLOCATIONS(それらすべて)を提供します。

それで:

>>> a = np.arange(100)
>>> np.where(a > 30)
(array([31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
       48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,
       65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
       82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98,
       99]),)
>>> np.where(a == 90)
(array([90]),)

a = a*40
>>> np.where(a > 1000)
(array([26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42,
       43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
       60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76,
       77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93,
       94, 95, 96, 97, 98, 99]),)
>>> a[25]
1000
>>> a[26]
1040

私は list.index() の代替として使用していますが、他にも多くの用途があります。2D配列で使用したことはありません。

http://docs.scipy.org/doc/numpy/reference/generated/numpy.where.html

新しい回答 その人はもっと根本的なことを尋ねていたようです。

問題は、関数 (どこなど) が要求されたものを知ることができるようにするものをどのように実装できるかということでした。

最初に、比較演算子を呼び出すと興味深いことが行われることに注意してください。

a > 1000
array([False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True`,  True,  True,  True,  True,  True,  True,  True,  True,  True], dtype=bool)`

これは、"__gt__" メソッドをオーバーロードすることによって行われます。例えば:

>>> class demo(object):
    def __gt__(self, item):
        print item


>>> a = demo()
>>> a > 4
4

ご覧のとおり、「a > 4」は有効なコードでした。

ここで、オーバーロードされたすべての関数の完全なリストとドキュメントを取得できます: http://docs.python.org/reference/datamodel.html

驚くべきことは、これを行うのがいかに簡単かということです。Python でのすべての操作は、このような方法で行われます。a > b と言うのは a と同じです。gt (ロ)!

于 2011-04-12T22:45:36.457 に答える
3

np.where呼び出された numpy ndarray の次元に等しい長さのタプルを返し (つまりndim)、タプルの各項目は、条件が True である初期 ndarray 内のすべての値のインデックスの numpy ndarray です。(寸法と形状を混同しないでください)

例えば:

x=np.arange(9).reshape(3,3)
print(x)
array([[0, 1, 2],
      [3, 4, 5],
      [6, 7, 8]])
y = np.where(x>4)
print(y)
array([1, 2, 2, 2], dtype=int64), array([2, 0, 1, 2], dtype=int64))


y は 2 であるため、長さ 2 のタプルですx.ndim。タプルの 1 番目の項目には 4 より大きいすべての要素の行番号が含まれ、2 番目の項目には 4 より大きいすべての項目の列番号が含まれます。ご覧のとおり、[1,2,2 ,2] は 5,6,7,8 の行番号に対応し、[2,0,1,2] は 5,6,7,8 の列番号に対応します。 )。

同様に、

x=np.arange(27).reshape(3,3,3)
np.where(x>4)


x は 3 次元であるため、長さ 3 のタプルを返します。

しかし、待ってください。np.where には他にもあります。

に 2 つの追加引数が追加されたときnp.where。上記のタプルによって取得されるすべての対ごとの行と列の組み合わせに対して置換操作を実行します。

x=np.arange(9).reshape(3,3)
y = np.where(x>4, 1, 0)
print(y)
array([[0, 0, 0],
   [0, 0, 1],
   [1, 1, 1]])
于 2018-03-28T18:33:31.710 に答える