4

派手なページで、彼らは次の例を示しています

s = np.random.dirichlet((10, 5, 3), 20)

これはすべて問題なく素晴らしいことです。しかし、アルファの 2D 配列からランダムなサンプルを生成したい場合はどうすればよいでしょうか?

alphas = np.random.randint(10, size=(20, 3))

、、またはを試すと
np.random.dirichlet(alphas)、結果は になり ます。動作しているように見える唯一のものは次のとおりです。
np.random.dirichlet([x for x in alphas])
np.random.dirichlet((x for x in alphas))

ValueError: object too deep for desired array

y = np.empty(alphas.shape)
for i in xrange(np.alen(alphas)):
    y[i] = np.random.dirichlet(alphas[i])
    print y

...これは、私のコード構造にとって理想とはほど遠いものです。なぜこれが当てはまるのですか?これを行うためのより「numpyのような」方法を誰かが考えることができますか?

前もって感謝します。

4

2 に答える 2

4

np.random.dirichlet単一のディリクレ分布のサンプルを生成するように書かれています。そのコードはガンマ分布に関して実装されており、その実装をベクトル化されたコードの基礎として使用して、さまざまな分布からサンプルを生成できます。以下でdirichlet_sampleは、形状 (n, k) の配列alphasを受け取ります。ここで、各行はalphaディリクレ分布のベクトルです。これも形状 (n, k) の配列を返します。各行は からの対応する分布のサンプルですalphas。スクリプトとして実行すると、 と を使用dirichlet_sampleしてサンプルが生成さnp.random.dirichletれ、それらが同じサンプルを生成していることを確認します (通常の浮動小数点の違いまで)。

import numpy as np


def dirichlet_sample(alphas):
    """
    Generate samples from an array of alpha distributions.
    """
    r = np.random.standard_gamma(alphas)
    return r / r.sum(-1, keepdims=True)


if __name__ == "__main__":
    alphas = 2 ** np.random.randint(0, 4, size=(6, 3))

    np.random.seed(1234)
    d1 = dirichlet_sample(alphas)
    print "dirichlet_sample:"
    print d1

    np.random.seed(1234)
    d2 = np.empty(alphas.shape)
    for k in range(len(alphas)):
        d2[k] = np.random.dirichlet(alphas[k])
    print "np.random.dirichlet:"
    print d2

    # Compare d1 and d2:
    err = np.abs(d1 - d2).max()
    print "max difference:", err

サンプルラン:

dirichlet_sample:
[[ 0.38980834  0.4043844   0.20580726]
 [ 0.14076375  0.26906604  0.59017021]
 [ 0.64223074  0.26099934  0.09676991]
 [ 0.21880145  0.33775249  0.44344606]
 [ 0.39879859  0.40984454  0.19135688]
 [ 0.73976425  0.21467288  0.04556287]]
np.random.dirichlet:
[[ 0.38980834  0.4043844   0.20580726]
 [ 0.14076375  0.26906604  0.59017021]
 [ 0.64223074  0.26099934  0.09676991]
 [ 0.21880145  0.33775249  0.44344606]
 [ 0.39879859  0.40984454  0.19135688]
 [ 0.73976425  0.21467288  0.04556287]]
max difference: 5.55111512313e-17
于 2013-04-10T05:01:56.260 に答える