Bahdanau の論文を読み、それを現在の tf.contrib.seq2seq API に翻訳した後、デコーダに何を入力するべきか混乱しています。特に、TrainingHelper はタイムシフトされたラベルのリストを受け取る必要があるようです。
以下は私の実際の例ですが、正しいかどうかはわかりません。
# Given:
# annotations: encoder outputs, reshaped to
# (batch_size, time, encoder_size)
# labels: ground truth, shaped (batch_size, FORECAST_HORIZON)
if params.get('ATTENTION') == 'Bahdanau':
bahdanau = tf.contrib.seq2seq.BahdanauAttention(
num_units=ATTENTION_SIZE,
memory=annotations,
normalize=False,
name='BahdanauAttention')
attn_cell = tf.contrib.seq2seq.AttentionWrapper(
cell=tf.nn.rnn_cell.BasicLSTMCell(DECODER_SIZE, forget_bias=1.0),
attention_mechanism=bahdanau,
output_attention=False,
name="attention_wrapper")
helper = tf.contrib.seq2seq.TrainingHelper(
inputs=annotations, # ??????
sequence_length=[WINDOW_LENGTH]*BATCH_SIZE,
name="TrainingDecoderHelper")
最後から 3 番目の行に注意してください。
TrainingHelper は、エンコーダーの注釈をアテンション ラップされたデコーダー システムにフィードすることになっていますか?
- 長所:
inputs
が のような形状でない場合annotations
、AttentionWrapper は形状について不平を言うことになります。システム内でそのような形状が発生する唯一の場所はエンコーダーです。 - con: これが正しい場合、デコーダはどこでグラウンド トゥルースを取得しますか?
- 短所: アテンション ラップ デコーダ (
attn_cell
) は、アノテーションを取得する場所を既に知っています (アテンション メカニズムのポイントではありませんか?)
とにかく、実際に言えば、トレーニング可能なシステムを取得していますが、何か怪しいように思えます (単純な LSTM に比べてパフォーマンスが低いという事実を含め、現時点では間違いなく接線です)。