そのため、ニューラル ネットワークのトレーニング可能なパラメーターを最適化するために、Keras の Model.fit() と低レベルの TF GradientTape の両方を試してみたところ、Keras バージョンが大幅に優れていることがわかりました。
最終的な MSE が である Keras 最適化バージョンのコード:
from tensorflow import keras
import tensorflow as tf
from sklearn.datasets import load_boston
X,y = load_boston(return_X_y=True)
X_tf = tf.cast(X, dtype=tf.float32)
model = keras.Sequential()
model.add(keras.layers.Dense(100, activation = 'relu', input_shape = (13,)), )
model.add(keras.layers.Dense(100, activation = 'relu'))
model.add(keras.layers.Dense(100, activation = 'relu'))
model.add(keras.layers.Dense(1, activation = 'linear'))
model.compile(optimizer = tf.keras.optimizers.Adam(0.01),
loss = tf.keras.losses.MSE
)
model.fit(X, y, epochs=1000)enter code here
ただし、次のコードに示すように、tf.GradientTape を使用して Keras モデルを最適化すると、次のようになります。
from tensorflow import keras
import tensorflow as tf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import load_boston
X,y = load_boston(return_X_y=True)
X_tf = tf.cast(X, dtype=tf.float32)
model = keras.Sequential()
model.add(keras.layers.Dense(100, activation = 'relu', input_shape = (np.shape(X)[1],)),
)
model.add(keras.layers.Dense(100, activation = 'relu'))
model.add(keras.layers.Dense(100, activation = 'relu'))
model.add(keras.layers.Dense(1, activation = 'linear'))
optimizer = tf.keras.optimizers.Adam(learning_rate = 0.01)
def loss_func(pred, target):
return tf.reduce_mean(tf.square(pred - target))
trainable_params = model.trainable_variables
def train_step():
with tf.GradientTape() as tape:
y_tild = model(X_tf)
loss = loss_func(y_tild, y)
grads = tape.gradient(loss, trainable_params)
optimizer.apply_gradients(zip(grads, trainable_params))
print("Loss : " + str(loss.numpy()))
epochs = 1000
for ii in range(epochs):
train_step()
Keras フィット バージョンの値は、GradientTape を使用して取得した値よりも実際の値に近いことがわかります。また、Gradient Tape の値も、異なる入力に対してあまり変化せず、平均を回避しましたが、Keras の値はより多くの多様性を示しました。
では、GradientTape 低レベル API を使用して、Keras 高レベル API と同等のパフォーマンスを得るにはどうすればよいでしょうか? 私の実装よりもはるかに優れている Model.fit が行っていることは何ですか? ソースコードを調べてみましたが、本質的に特定できませんでした。
前もって感謝します。