これが私のサンプル LSTM クラスです。英語テキストの文字列内の文字を表す入力 ID は、ワンホット ベクトルとしてエンコードされ、単一の LSTM レイヤーに供給されます。
class MyLSTM(Chain):
def __init__(self, vocab_size, hidden_size):
super(MyLSTM, self).__init__(
mid=L.LSTM(vocab_size, hidden_size),
out=L.Linear(hidden_size, vocab_size),
)
self.W = np.identity(vocab_size).astype(np.float32)
def reset_state(self):
self.mid.reset_state()
def __call__(self, x):
x_1hot = F.embed_id(x, self.W)
h = self.mid(x_1hot)
y = self.out(h)
return y
完全なコードはこちらです。サンプルの txt ファイルを指定するだけで実行できます 。
最初のエポックでは、毎秒 20 回の反復を実行しています。3 番目のエポックまでに、1 回の反復に 1 秒以上かかります!