0

誰でも私に以下のコードを説明してもらえますか:

import torch
import torch.nn as nn

input = torch.randn(5, 3, 10)
h0 = torch.randn(2, 3, 20)
c0 = torch.randn(2, 3, 20)

rnn = nn.LSTM(10,20,2)

output, (hn, cn) = rnn(input, (h0, c0))
print(input)

rnnrnn(input, (h0, c0))を呼び出すときに、引数 h0 と c0 を括弧で囲みました。それはどういう意味ですか?(h0, c0) が単一の値を表す場合、その値は何で、ここで渡される 3 番目の引数は何ですか? ただし、行rnn = nn.LSTM(10,20,2) では、括弧なしで LSTM 関数に引数を渡しています。この関数呼び出しがどのように機能するかを誰かに説明してもらえますか?

4

1 に答える 1

0

割り当ては、クラスを使用してrnn = nn.LSTM(10, 20, 2)新しいインスタンスを作成します。最初の 3 つの引数は(here )、(here )、(here ) です。nn.Modulenn.LSTMinput_size10hidden_size20num_layers2

一方rnn(input, (h0, c0))、クラスインスタンスを実際に呼び出すことに対応し、i.e.実行中です。これは、そのモジュール__call__の機能とほぼ同等です。forward__call__メソッドはnn.LSTM2 つのパラメーターを受け取ります: input (shaped (sequnce_length, batch_size, input_size)、および 2 つのテンソルのタプル(h_0, c_0)(どちらも(num_layers, batch_size, hidden_size) の基本的なユース ケースで整形されていますnn.LSTM)

ビルトインを使用するときはいつでも PyTorch のドキュメントを参照してください。パラメーター リスト (クラス インスタンスの初期化に使用される引数) の正確な定義と、入力/出力仕様 (そのモジュールで推論するときはいつでも) を見つけることができます。


表記に混乱するかもしれませんが、ここに役立つ小さな例があります:

  • 入力としてのタプル:

    def fn1(x, p):
        a, b = p # unpack input
        return a*x + b
    
    >>> fn1(2, (3, 1))
    >>> 7
    
  • 出力としてのタプル

    def fn2(x):
        return x, (3*x, x**2) # actually output is a tuple of int and tuple 
    
    >>> x, (a, b) = fn2(2) # unpacking
    (2, (6, 4))
    
    >>> x, a, b
    (2, 6, 4)
    
于 2021-07-20T21:22:20.913 に答える