I'm trying to implement a multi-layer RNN model in TensorFlow 2.0. I tried both tf.keras.layers.StackedRNNCells and tf.keras.layers.RNN, and they give the same results. Can anyone help me understand the difference between tf.keras.layers.RNN and tf.keras.layers.StackedRNNCells?
import tensorflow as tf

# driving parameters
sz_batch = 128
sz_latent = 200
sz_sequence = 196
sz_feature = 2
n_units = 120
n_layers = 3
Multi-layer RNN with tf.keras.layers.RNN:
inputs = tf.keras.layers.Input(batch_shape=(sz_batch, sz_sequence, sz_feature))
cells = [tf.keras.layers.GRUCell(n_units) for _ in range(n_layers)]
outputs = tf.keras.layers.RNN(cells, stateful=True, return_sequences=True, return_state=False)(inputs)
outputs = tf.keras.layers.Dense(1)(outputs)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
model.summary()
Output:
Model: "model_13"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_88 (InputLayer)        [(128, 196, 2)]           0
_________________________________________________________________
rnn_61 (RNN)                 (128, 196, 120)           218880
_________________________________________________________________
dense_19 (Dense)             (128, 196, 1)             121
=================================================================
Total params: 219,001
Trainable params: 219,001
Non-trainable params: 0
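As a sanity check on the summary, the parameter counts can be reproduced by hand. This is a minimal sketch in plain Python, assuming the TF 2.x default `reset_after=True` for `GRUCell` (which gives each gate two bias vectors); `gru_cell_params` is a hypothetical helper written for this illustration, not part of Keras.

```python
def gru_cell_params(input_dim, units, reset_after=True):
    """Count the trainable parameters of one GRU cell (hypothetical helper).

    A GRU has 3 gates (update, reset, candidate). Each gate has an input
    kernel (input_dim x units), a recurrent kernel (units x units) and a bias.
    Assumption: with reset_after=True the bias has two rows per gate.
    """
    bias_per_gate = 2 * units if reset_after else units
    return 3 * (input_dim * units + units * units + bias_per_gate)


# Same values as the driving parameters in the question.
sz_feature = 2
n_units = 120
n_layers = 3

rnn_params = 0
input_dim = sz_feature
for _ in range(n_layers):
    rnn_params += gru_cell_params(input_dim, n_units)
    # Every layer after the first receives the previous layer's output.
    input_dim = n_units

# Dense(1) on top of n_units features: weights plus one bias.
dense_params = n_units * 1 + 1

print(rnn_params, dense_params, rnn_params + dense_params)
# prints: 218880 121 219001
```

The totals match the summary, so the RNN layer holds exactly three stacked GRU cells (44,640 parameters for the first, 87,120 for each of the other two).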
Multi-layer RNN with tf.keras.layers.RNN and tf.keras.layers.StackedRNNCells:
inputs = tf.keras.layers.Input(batch_shape=(sz_batch, sz_sequence, sz_feature))
cells = [tf.keras.layers.GRUCell(n_units) for _ in range(n_layers)]
outputs = tf.keras.layers.RNN(tf.keras.layers.StackedRNNCells(cells),
                              stateful=True,
                              return_sequences=True,
                              return_state=False)(inputs)
outputs = tf.keras.layers.Dense(1)(outputs)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
model.summary()
Output:
Model: "model_14"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_89 (InputLayer)        [(128, 196, 2)]           0
_________________________________________________________________
rnn_62 (RNN)                 (128, 196, 120)           218880
_________________________________________________________________
dense_20 (Dense)             (128, 196, 1)             121
=================================================================
Total params: 219,001
Trainable params: 219,001
Non-trainable params: 0
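One way to probe whether the two constructions differ is to look at what each RNN layer actually stores in its `cell` attribute. The following is a minimal sketch under that assumption; it only builds the layers (no model, no data), and the variable names `rnn_a` and `rnn_b` are introduced here for illustration.

```python
import tensorflow as tf

n_units = 120
n_layers = 3

# Variant 1: pass a plain Python list of cells to RNN.
cells_a = [tf.keras.layers.GRUCell(n_units) for _ in range(n_layers)]
rnn_a = tf.keras.layers.RNN(cells_a, return_sequences=True)

# Variant 2: wrap the cells in StackedRNNCells explicitly.
cells_b = [tf.keras.layers.GRUCell(n_units) for _ in range(n_layers)]
rnn_b = tf.keras.layers.RNN(tf.keras.layers.StackedRNNCells(cells_b),
                            return_sequences=True)

# Inspect the object each layer holds as its cell.
for name, layer in [("list of cells", rnn_a), ("StackedRNNCells", rnn_b)]:
    print(name, "->", type(layer.cell).__name__)
```

If both lines report the same cell type, that would suggest the list form is a shorthand for the explicit wrapper, which would be consistent with the identical summaries above.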