-1

このチュートリアルを例として使用して、カフェのカスタム トレーニング関数を作成しています。セクション 15 には、次のコードがあります。

def train():
    niter = 200
    test_interval = 25 
    train_loss = zeros(niter)
    test_acc = zeros(int(np.ceil(niter / test_interval)))

    ### HERE ###
    output = zeros((niter, 8, 10))
    ###      ###

行 8 には、ndarray(出力) があり、このコードの意味は何ですか。とはどういう意味ですか(niter, 8, 10)? なぜniter、なぜ 8、なぜ 10 なのですか? 自分のデータセットに従ってこの配列を変更する必要がありますか? はいの場合、どのディメンションを使用すればよいですか? 誰かが私にそれを説明できますか?

4

2 に答える 2

2

チュートリアルをよく読むと、数字の分類、つまり10 個のクラスを扱っていることがわかります。さらに、彼らはトリックを使って 8 つの例を並べて並べています (セクション 11、 の近くIn [11]:):

# 最初の 8 つの画像を並べて表示するためのちょっとしたトリックを使用します

したがって、8次元。

セクション 15 では、ネットワークの進行状況を追跡する例を示します。反復ごとの出力予測確率を保存します。反復ごとに10 個のクラス x 8 個の例があり、niter追跡する反復があります。この情報はすべて 3Doutput配列に格納されます。

于 2015-11-25T17:39:05.857 に答える
1

float 0 の 200 * 8 * 10 配列を作成するnumpy.zeroswhereの呼び出しのように見えます。shape = (niter, 8, 10)

于 2015-11-25T17:26:10.657 に答える