0

ufunc を使用して、N * 1 numpy 配列の int を N * 3 numpy 配列の float に効率的にマップしようとしています。

私がこれまでに持っているもの:

map = {1: (0, 0, 0), 2: (0.5, 0.5, 0.5), 3: (1, 1, 1)}
ufunc = numpy.frompyfunc(lambda x: numpy.array(map[x], numpy.float32), 1, 1)

input = numpy.array([1, 2, 3], numpy.int32)

ufunc(input)dtype オブジェクトを含む 3 * 3 配列を返します。この配列が欲しいのですが、dtypeはfloat32です。

4

4 に答える 4

1

ndarray fancy indexを使用して同じ結果を得ることができます。これは、frompyfuncよりも高速である必要があると思います。

map_array = np.array([[0,0,0],[0,0,0],[0.5,0.5,0.5],[1,1,1]], dtype=np.float32)
index = np.array([1,2,3,1])
map_array[index]

または、リスト内包表記を使用することもできます。

map = {1: (0, 0, 0), 2: (0.5, 0.5, 0.5), 3: (1, 1, 1)}
np.array([map[i] for i in [1,2,3,1]], dtype=np.float32)    
于 2012-08-31T01:29:06.047 に答える
1

マッピングがnumpy配列の場合は、次のように派手なインデックスを使用できます。

>>> valmap = numpy.array([(0, 0, 0), (0.5, 0.5, 0.5), (1, 1, 1)])
>>> input = numpy.array([1, 2, 3], numpy.int32)
>>> valmap[input-1]
array([[ 0. ,  0. ,  0. ],
       [ 0.5,  0.5,  0.5],
       [ 1. ,  1. ,  1. ]])
于 2012-08-31T01:29:14.280 に答える
1

私がドキュメントを読み違えていなければnp.frompyfunc、スカラー オブジェクトの出力は確かに: andarrayを入力として使用すると、ndarraywithが得られますdtype=obj

np.vectorize回避策は、次の関数を使用することです。

F = np.vectorize(lambda x: mapper.get(x), 'fff')

ここでは、dtypeofFの出力を強制的に 3 つの float にします (したがって、'fff')。

>>> mapper = {1: (0, 0, 0), 2: (0.5, 1.0, 0.5), 3: (1, 2, 1)}
>>> inp = [1, 2, 3]
>>> F(inp)
(array([ 0. ,  0.5,  1. ], dtype=float32), array([ 0.,  0.5,  1.], dtype=float32), array([ 0. ,  0.5,  1. ], dtype=float32))

これは 3 つの float 配列のタプルであり ('fff' を指定したように)、最初の配列は と同等[mapper[i][0] for i in inp]です。したがって、少し操作すると、次のようになります。

>>> np.array(F(inp)).T
array([[ 0. ,  0. ,  0. ],
       [ 0.5,  0.5,  0.5],
       [ 1. ,  1. ,  1. ]], dtype=float32)
于 2012-08-31T11:59:58.677 に答える
1

np.hstackを使用できます:

import numpy as np
mapping = {1: (0, 0, 0), 2: (0.5, 0.5, 0.5), 3: (1, 1, 1)}
ufunc = np.frompyfunc(lambda x: np.array(mapping[x], np.float32), 1, 1, dtype = np.float32)

data = np.array([1, 2, 3], np.int32)
result = np.hstack(ufunc(data))
print(result)
# [ 0.   0.   0.   0.5  0.5  0.5  1.   1.   1. ]
print(result.dtype)
# float32
print(result.shape)
# (9,)
于 2012-08-31T01:24:10.130 に答える