1

これが私のサンプル LSTM クラスです。英語テキストの文字列内の文字を表す入力 ID は、ワンホット ベクトルとしてエンコードされ、単一の LSTM レイヤーに供給されます。

class MyLSTM(Chain):
    def __init__(self, vocab_size, hidden_size):
        super(MyLSTM, self).__init__(
            mid=L.LSTM(vocab_size, hidden_size),
            out=L.Linear(hidden_size, vocab_size),
        )
        self.W = np.identity(vocab_size).astype(np.float32)

    def reset_state(self):
        self.mid.reset_state()

    def __call__(self, x):
        x_1hot = F.embed_id(x, self.W)
        h = self.mid(x_1hot)
        y = self.out(h)
        return y

完全なコードはこちらです。サンプルの txt ファイルを指定するだけで実行できます

最初のエポックでは、毎秒 20 回の反復を実行しています。3 番目のエポックまでに、1 回の反復に 1 秒以上かかります!

4

0 に答える 0