np.random.dirichlet
単一のディリクレ分布のサンプルを生成するように書かれています。そのコードはガンマ分布に関して実装されており、その実装をベクトル化されたコードの基礎として使用して、さまざまな分布からサンプルを生成できます。以下でdirichlet_sample
は、形状 (n, k) の配列alphas
を受け取ります。ここで、各行はalpha
ディリクレ分布のベクトルです。これも形状 (n, k) の配列を返します。各行は からの対応する分布のサンプルですalphas
。スクリプトとして実行すると、 と を使用dirichlet_sample
してサンプルが生成さnp.random.dirichlet
れ、それらが同じサンプルを生成していることを確認します (通常の浮動小数点の違いまで)。
import numpy as np
def dirichlet_sample(alphas):
"""
Generate samples from an array of alpha distributions.
"""
r = np.random.standard_gamma(alphas)
return r / r.sum(-1, keepdims=True)
if __name__ == "__main__":
alphas = 2 ** np.random.randint(0, 4, size=(6, 3))
np.random.seed(1234)
d1 = dirichlet_sample(alphas)
print "dirichlet_sample:"
print d1
np.random.seed(1234)
d2 = np.empty(alphas.shape)
for k in range(len(alphas)):
d2[k] = np.random.dirichlet(alphas[k])
print "np.random.dirichlet:"
print d2
# Compare d1 and d2:
err = np.abs(d1 - d2).max()
print "max difference:", err
サンプルラン:
dirichlet_sample:
[[ 0.38980834 0.4043844 0.20580726]
[ 0.14076375 0.26906604 0.59017021]
[ 0.64223074 0.26099934 0.09676991]
[ 0.21880145 0.33775249 0.44344606]
[ 0.39879859 0.40984454 0.19135688]
[ 0.73976425 0.21467288 0.04556287]]
np.random.dirichlet:
[[ 0.38980834 0.4043844 0.20580726]
[ 0.14076375 0.26906604 0.59017021]
[ 0.64223074 0.26099934 0.09676991]
[ 0.21880145 0.33775249 0.44344606]
[ 0.39879859 0.40984454 0.19135688]
[ 0.73976425 0.21467288 0.04556287]]
max difference: 5.55111512313e-17