5

tensorflow の API に既にバッチ正規化関数があることに気付きました。私が理解していないことの1つは、トレーニングとテストの間で手順を変更する方法ですか?

バッチ正規化は、トレーニング中とテスト中では異なる動作をします。具体的には、トレーニング中に固定平均と分散を使用します。

どこかに良いサンプルコードはありますか?いくつか見ましたが、スコープ変数で混乱しました

4

1 に答える 1

9

そうです、tf.nn.batch_normalizationバッチ正規化を実装するための基本的な機能を提供するだけです。追加のロジックを追加して、トレーニング中に移動平均と分散を追跡し、推論中にトレーニング済み平均と分散を使用する必要があります。非常に一般的な実装のこのを見ることができますが、使用しない簡単なバージョンは次のgammaとおりです。

  beta = tf.Variable(tf.zeros(shape), name='beta')
  moving_mean = tf.Variable(tf.zeros(shape), name='moving_mean',
                                 trainable=False)
  moving_variance = tf.Variable(tf.ones(shape),
                                     name='moving_variance',
                                     trainable=False)
  control_inputs = []
  if is_training:
    mean, variance = tf.nn.moments(image, [0, 1, 2])
    update_moving_mean = moving_averages.assign_moving_average(
        moving_mean, mean, self.decay)
    update_moving_variance = moving_averages.assign_moving_average(
        moving_variance, variance, self.decay)
    control_inputs = [update_moving_mean, update_moving_variance]
  else:
    mean = moving_mean
    variance = moving_variance
  with tf.control_dependencies(control_inputs):
    return tf.nn.batch_normalization(
        image, mean=mean, variance=variance, offset=beta,
        scale=None, variance_epsilon=0.001)
于 2016-05-02T16:39:41.920 に答える