1

フレームごとの分類のためにLSTMにフィードしたい非常に長い時系列があります。

私のデータはフレームごとにラベル付けされており、発生以来、分類に大きな影響を与えるいくつかのまれなイベントが発生することを知っています.

したがって、意味のある予測を得るには、シーケンス全体をフィードする必要があります。

非常に長いシーケンスを LSTM に入力するだけでは、通常の RNN と同様に勾配が消失または爆発するため、最適ではないことが知られています。


シーケンスをより短い (たとえば、100 の長さの) シーケンスにカットする単純な手法を使用し、それぞれで LSTM を実行してから、最終的な LSTM 非表示およびセル状態を次のフォワード パスの開始非表示およびセル状態として渡したいと思いました。 .

これは、まさにそれを行った人の例ですそこでは、「時間による切り捨てられた逆伝播」と呼ばれます。私は同じ作品を作ることができませんでした。


Pytorch ライトニングでの私の試み (無関係な部分を取り除いたもの):

def __init__(self, config, n_classes, datamodule):
    ...
    self._criterion = nn.CrossEntropyLoss(
        reduction='mean',
    )

    num_layers = 1
    hidden_size = 50
    batch_size=1

    self._lstm1 = nn.LSTM(input_size=len(self._in_features), hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
    self._log_probs = nn.Linear(hidden_size, self._n_predicted_classes)
    self._last_h_n = torch.zeros((num_layers, batch_size, hidden_size), device='cuda', dtype=torch.double, requires_grad=False)
    self._last_c_n = torch.zeros((num_layers, batch_size, hidden_size), device='cuda', dtype=torch.double, requires_grad=False)

def training_step(self, batch, batch_index):
    orig_batch, label_batch = batch
    n_labels_in_batch = np.prod(label_batch.shape)
    lstm_out, (self._last_h_n, self._last_c_n) = self._lstm1(orig_batch, (self._last_h_n, self._last_c_n))
    log_probs = self._log_probs(lstm_out)
    loss = self._criterion(log_probs.view(n_labels_in_batch, -1), label_batch.view(n_labels_in_batch))

    return loss

このコードを実行すると、次のエラーが発生します。

RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.

追加しても同じことが起こります

def on_after_backward(self) -> None:
    self._last_h_n.detach()
    self._last_c_n.detach()

使用するとエラーは発生しません

lstm_out, (self._last_h_n, self._last_c_n) = self._lstm1(orig_batch,)

しかし、現在のフレームバッチからの出力が次のフレームバッチに転送されないため、明らかにこれは役に立ちません。


このエラーの原因は何ですか? h_n出力を切り離すc_nだけで十分だと思いました。

前のフレーム バッチの出力を次のフレーム バッチに渡し、各フレーム バッチを個別にトーチ バック プロパゲートするにはどうすればよいですか?

4

1 に答える 1