tf2 モデルを呼び出すと、tf Model サブクラスで call() メソッドを定義した方法に従って返されるはずの値が返されません。
代わりに、モデルの call() メソッドを呼び出すと、build() メソッドで定義したテンソルが返されます
なぜこれが起こっているのですか?どうすれば修正できますか?
import numpy as np
import tensorflow as tf
num_items = 1000
emb_dim = 32
lstm_dim = 32
class rnn_model(tf.keras.Model):
def __init__(self, num_items, emb_dim):
super(rnn_model, self).__init__()
self.emb = tf.keras.layers.Embedding(num_items, emb_dim, name='embedding_layer')
self.GRU = tf.keras.layers.LSTM(lstm_dim, name='rnn_layer')
self.dense = tf.keras.layers.Dense(num_items, activation = 'softmax', name='final_layer')
def call(self, inp, is_training=True):
emb = self.emb(inp)
gru = self.GRU(emb)
# logits=self.dense(gru)
return gru # (bs, lstm_dim=50)
def build(self, inp_shape):
x = tf.keras.Input(shape=inp_shape, name='input_layer')
# return tf.keras.Model(inputs=[x], outputs=self.call(x))
return tf.keras.Model(inputs=[x], outputs=self.dense(self.call(x)))
maxlen = 10
model = rnn_model(num_items, emb_dim).build((maxlen, ))
model.summary()
gru_out = model(inp)
print(gru_out.shape) # should have been (bs=16, lstm_dim=32)
以下は、私が得ている出力です-
Model: "functional_11"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_layer (InputLayer) [(None, 10)] 0
_________________________________________________________________
embedding_layer (Embedding) (None, 10, 32) 32000
_________________________________________________________________
rnn_layer (LSTM) (None, 32) 8320
_________________________________________________________________
final_layer (Dense) (None, 1000) 33000
=================================================================
Total params: 73,320
Trainable params: 73,320
Non-trainable params: 0
_________________________________________________________________
(16, 1000)
モデルの最後にある「final_layer」または高密度レイヤーのみを使用して、サンプルされたソフトマックス関数にフィードし、gru_out と一緒に使用して損失を計算します (モデルをトレーニングするため)。
テスト中に、手動で gru_out を model.get_layer('final_layer') に渡して、最終的なロジットを取得するつもりです。