1

JAXライブラリとその小さなニューラルネットワークサブモジュール「Stax」を使用して、ニューラルネットワークを実装およびトレーニングしようとしています。このライブラリにはバイナリ クロス エントロピーの実装が付属していないため、独自に作成しました。

def binary_cross_entropy(y_hat, y):
    bce = y * jnp.log(y_hat) + (1 - y) * jnp.log(1 - y_hat)
    return jnp.mean(-bce)

単純なニューラル ネットワークを実装して MNIST でトレーニングしたところ、得られた結果の一部に疑いを持ち始めました。そこで、Keras で同じセットアップを実装したところ、すぐにまったく異なる結果が得られました。同じデータで同じ方法でトレーニングされた同じモデルは、JAX では約 50% でしたが、Keras では 90% のトレーニング精度が得られました。最終的に、数値的に不安定であると思われるクロスエントロピーの単純な実装に問題の一部を突き止めました。この投稿と見つけたこのコードに従って、次の新しいバージョンを作成しました。

def binary_cross_entropy_stable(y_hat, y):
    y_hat = jnp.clip(y_hat, 0.000001, 0.9999999)
    logits = jnp.log(y_hat/(1 - y_hat))
    max_logit = jnp.clip(logits, 0, None)
    bces = logits - logits * y + max_logit + jnp.log(jnp.exp(-max_logit) + jnp.exp(-logits - max_logit))
    return jnp.mean(bces)

これは少しうまくいきます。現在、私の JAX 実装は最大 80% のトレーニング精度を達成していますが、それでも Keras の 90% よりもはるかに低いです。私が知りたいのは、何が起こっているのですか?2 つの実装が同じように動作しないのはなぜですか?

以下では、2 つの実装を 1 つのスクリプトに要約しました。このスクリプトでは、JAX と Keras で同じモデルを実装しています。両方を同じ weightsで初期化し、各モデルの同じデータである MNIST からの 1000 データポイントで 10 ステップのフルバッチ勾配降下法を使用してトレーニングします。JAX は 80% のトレーニング精度で終了しますが、Keras は 90% で終了します。具体的には、次の出力が得られます。

Initial Keras accuracy: 0.4350000023841858
Initial JAX accuracy: 0.435
Final JAX accuracy: 0.792
Final Keras accuracy: 0.9089999794960022
JAX accuracy (Keras weights): 0.909
Keras accuracy (JAX weights): 0.7919999957084656

実際、条件を少し変えると (異なるランダムな初期重みまたは異なるトレーニング セットを使用して)、50% の JAX 精度と 90% の Keras 精度が得られることがあります。

最後に重みを交換して、トレーニングから得られた重みが実際に問題であり、ネットワーク予測の実際の計算や精度の計算方法とは関係がないことを確認します。

コード:

import numpy as np

import jax
from   jax import jit, grad
from   jax.experimental import stax, optimizers
import jax.numpy as jnp

import keras
import keras.datasets.mnist

def binary_cross_entropy(y_hat, y):
    bce = y * jnp.log(y_hat) + (1 - y) * jnp.log(1 - y_hat)
    return jnp.mean(-bce)

def binary_cross_entropy_stable(y_hat, y):
    y_hat = jnp.clip(y_hat, 0.000001, 0.9999999)
    logits = jnp.log(y_hat/(1 - y_hat))
    max_logit = jnp.clip(logits, 0, None)
    bces = logits - logits * y + max_logit + jnp.log(jnp.exp(-max_logit) + jnp.exp(-logits - max_logit))
    return jnp.mean(bces)

def binary_accuracy(y_hat, y):
    return jnp.mean((y_hat >= 1/2) == (y >= 1/2))

########################################
#                                      #
#          Create dataset              #
#                                      #
########################################

input_dimension = 784

(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data(path="mnist.npz")

xs = np.concatenate([x_train, x_test])
xs = xs.reshape((70000, 784))

ys = np.concatenate([y_train, y_test])
ys = (ys >= 5).astype(np.float32)
ys = ys.reshape((70000, 1))

train_xs = xs[:1000]
train_ys = ys[:1000]

########################################
#                                      #
#           Create JAX model           #
#                                      #
########################################

jax_initializer, jax_model = stax.serial(
    stax.Dense(1000),
    stax.Relu,
    stax.Dense(1),
    stax.Sigmoid
)
rng_key = jax.random.PRNGKey(0)

_, initial_jax_weights = jax_initializer(rng_key, (1, input_dimension))

########################################
#                                      #
#         Create Keras model           #
#                                      #
########################################

initial_keras_weights  = [*initial_jax_weights[0], *initial_jax_weights[2]]

keras_model = keras.Sequential([
    keras.layers.Dense(1000, activation="relu"),
    keras.layers.Dense(1, activation="sigmoid")
])

keras_model.compile(
    optimizer=keras.optimizers.SGD(learning_rate=0.01),
    loss=keras.losses.binary_crossentropy,
    metrics=["accuracy"]
)

keras_model.build(input_shape=(1, input_dimension))

keras_model.set_weights(initial_keras_weights)

if __name__ == "__main__":

########################################
#                                      #
#      Compare untrained models        #
#                                      #
########################################

    initial_keras_predictions = keras_model.predict(train_xs, verbose=0)
    initial_jax_predictions   = jax_model(initial_jax_weights, train_xs)

    _, keras_initial_accuracy = keras_model.evaluate(train_xs, train_ys, verbose=0)
    jax_initial_accuracy = binary_accuracy(jax_model(initial_jax_weights, train_xs), train_ys)

    print("Initial Keras accuracy:", keras_initial_accuracy)
    print("Initial JAX accuracy:", jax_initial_accuracy)

########################################
#                                      #
#           Train JAX model            #
#                                      #
########################################
    
    L = jit(binary_cross_entropy_stable)
    gradL = jit(grad(lambda w, x, y: L(jax_model(w, x), y)))
    opt_init, opt_apply, get_params = optimizers.sgd(0.01)
    network_state = opt_init(initial_jax_weights)
    for _ in range(10):
        wT = get_params(network_state)
        gradient = gradL(wT, train_xs, train_ys)
        network_state = opt_apply(
            0,
            gradient,
            network_state
        )

    final_jax_weights = get_params(network_state)
    final_jax_training_predictions = jax_model(final_jax_weights, train_xs)
    final_jax_accuracy = binary_accuracy(final_jax_training_predictions, train_ys)

    print("Final JAX accuracy:", final_jax_accuracy)

########################################
#                                      #
#         Train Keras model            #
#                                      #
########################################

    for _ in range(10):
        keras_model.fit(
            train_xs,
            train_ys,
            epochs=1,
            batch_size=1000,
            verbose=0
        )

    final_keras_loss, final_keras_accuracy = keras_model.evaluate(train_xs, train_ys, verbose=0)

    print("Final Keras accuracy:", final_keras_accuracy)

########################################
#                                      #
#            Swap weights              #
#                                      #
########################################

    final_keras_weights = keras_model.get_weights()
    final_keras_weights_in_jax_format = [
        (final_keras_weights[0], final_keras_weights[1]),
        tuple(),
        (final_keras_weights[2], final_keras_weights[3]),
        tuple()
    ]
    jax_accuracy_with_keras_weights = binary_accuracy(
        jax_model(final_keras_weights_in_jax_format, train_xs),
        train_ys
    )
    print("JAX accuracy (Keras weights):", jax_accuracy_with_keras_weights)

    final_jax_weights_in_keras_format = [*final_jax_weights[0], *final_jax_weights[2]]
    keras_model.set_weights(final_jax_weights_in_keras_format)
    _, keras_accuracy_with_jax_weights = keras_model.evaluate(train_xs, train_ys, verbose=0)
    print("Keras accuracy (JAX weights):", keras_accuracy_with_jax_weights)

057 行目の PRNG シードを、異なる初期重みを使用して実験を実行する以外の値に変更してみてください。

4

1 に答える 1