152

私は github (リンク)で LSTM 言語モデルのこの例を見ていました。それが一般的に何をするかは、私にはかなり明確です。contiguous()しかし、コード内で数回発生する呼び出しが何をするのかを理解するのにまだ苦労しています。

たとえば、コード入力の 74/75 行で、LSTM のターゲット シーケンスが作成されます。データ ( に格納ids) は 2 次元で、最初の次元はバッチ サイズです。

for i in range(0, ids.size(1) - seq_length, seq_length):
    # Get batch inputs and targets
    inputs = Variable(ids[:, i:i+seq_length])
    targets = Variable(ids[:, (i+1):(i+1)+seq_length].contiguous())

簡単な例として、バッチ サイズ 1 とseq_length10を使用すると、次のようinputsになりtargetsます。

inputs Variable containing:
0     1     2     3     4     5     6     7     8     9
[torch.LongTensor of size 1x10]

targets Variable containing:
1     2     3     4     5     6     7     8     9    10
[torch.LongTensor of size 1x10]

したがって、一般的に私の質問は、何が機能contiguous()し、なぜそれが必要なのですか?

さらに、両方の変数が同じデータで構成されているため、メソッドがターゲットシーケンスに対して呼び出され、入力シーケンスに対して呼び出されない理由がわかりません。

targets非連続でありながら連続している可能性はありinputsますか?


編集:

の呼び出しを省略しようとしましたcontiguous()が、損失を計算するときにエラー メッセージが表示されます。

RuntimeError: invalid argument 1: input is not contiguous at .../src/torch/lib/TH/generic/THTensor.c:231

したがって、明らかcontiguous()にこの例で呼び出す必要があります。

4

8 に答える 8

0

私がこれを理解していることから、より要約された答え:

連続とは、テンソルのメモリ レイアウトが、そのアドバタイズされたメタデータまたは形状情報と一致しないことを示すために使用される用語です。

私の意見では、連続という言葉は紛らわしい/誤解を招く用語です。通常のコンテキストでは、メモリが切断されたブロック (つまり、「連続/接続/連続」) に分散していないことを意味するためです。

一部の操作では、何らかの理由でこの連続したプロパティが必要になる場合があります (GPU での効率など)。

.viewこの問題を引き起こす可能性のある別の操作であることに注意してください。contiguous を呼び出すだけで修正した次のコードを見てください (これを引き起こす典型的な転置の問題ではなく、RNN がその入力に満足していない場合の例です)。

        # normal lstm([loss, grad_prep, train_err]) = lstm(xn)
        n_learner_params = xn_lstm.size(1)
        (lstmh, lstmc) = hs[0] # previous hx from first (standard) lstm i.e. lstm_hx = (lstmh, lstmc) = hs[0]
        if lstmh.size(1) != xn_lstm.size(1): # only true when prev lstm_hx is equal to decoder/controllers hx
            # make sure that h, c from decoder/controller has the right size to go into the meta-optimizer
            expand_size = torch.Size([1,n_learner_params,self.lstm.hidden_size])
            lstmh, lstmc = lstmh.squeeze(0).expand(expand_size).contiguous(), lstmc.squeeze(0).expand(expand_size).contiguous()
        lstm_out, (lstmh, lstmc) = self.lstm(input=xn_lstm, hx=(lstmh, lstmc))

私が得たエラー:

RuntimeError: rnn: hx is not contiguous


ソース/リソース:

于 2020-04-13T18:19:51.883 に答える