Keras で seq2seq アーキテクチャを使用してテキスト要約モデルを構築しようとしています。このチュートリアルhttps://keras.io/examples/lstm_seq2seq/に従い、Embeddings レイヤーで実装しました。これは正常に機能します。でも今はBERTを使いたいです。このようなタスクで事前トレーニング済みの BERT 埋め込みを使用できますか。通常、テキスト分類が表示されますが、BERT で使用されるエンコーダー/デコーダー アーキテクチャは表示されません。
私は TF Hub から BERT モデルにアクセスし、このチュートリアルhttps://github.com/strongio/keras-bert/blob/master/keras-bert.ipynbから実装されたレイヤー クラスを持っています。以下の BERT トークナイザーで適宜トークン化します。私のモデルです
enc_in_id = Input(shape=(None, ), name="Encoder-Input-Ids")
enc_in_mask = Input(shape=(None, ), name="Encoder-Input-Masks")
enc_in_segment = Input(shape=(None, ), name="Encoder-Input-Segment-Ids")
bert_encoder_inputs = [enc_in_id, enc_in_mask, enc_in_segment]
encoder_embeddings = BertLayer(name='Encoder-Bert-Layer')(bert_encoder_inputs)
encoder_embeddings = BatchNormalization(name='Encoder-Batch-Normalization')(encoder_embeddings)
encoder_lstm = LSTM(latent_size, return_state=True, name='Encoder-LSTM')
encoder_out, e_state_h, e_state_c = encoder_lstm(encoder_embeddings)
encoder_states = [e_state_h, e_state_c]
dec_in_id = Input(shape=(None,), name="Decoder-Input-Ids")
dec_in_mask = Input(shape=(None,), name="Decoder-Input-Masks")
dec_in_segment = Input(shape=(None,), name="Decoder-Input-Segment-Ids")
bert_decoder_inputs = [dec_in_id, dec_in_mask, dec_in_segment]
decoder_embeddings_layer = BertLayer(name='Decoder-Bert-Layer')
decoder_embeddings = decoder_embeddings_layer(bert_decoder_inputs)
decoder_batchnorm_layer = BatchNormalization(name='Decoder-Batch-Normalization-1')
decoder_batchnorm = decoder_batchnorm_layer(decoder_embeddings)
decoder_lstm = LSTM(latent_size, return_state=True, return_sequences=True, name='Decoder-LSTM')
decoder_out, _, _ = decoder_lstm(decoder_batchnorm, initial_state=encoder_states)
dense_batchnorm_layer = BatchNormalization(name='Decoder-Batch-Normalization-2')
decoder_out_batchnorm = dense_batchnorm_layer(decoder_out)
decoder_dense_id = Dense(vocabulary_size, activation='softmax', name='Dense-Id')
dec_outputs_id = decoder_dense_id(decoder_out_batchnorm)
モデルが構築され、数エポック後に精度が 1 に上昇し、損失が 0.5 を下回りましたが、予測はひどいものでした。最大 30 個の WordPiece トークンを使用し、同じデータを予測する 5 つのサンプルで構成される開発セットに取り組んでいるため、最初のトークンまたはおそらく 2 つのトークンのみを正しく取得し、最後に見たトークン、または [PAD] を繰り返すだけです。トークン。