1

トレーニング済みモデルのエクスポート後、BatchNorm レイヤーは引き続き存在します。私は、次の 2 つの理由から、推論のためにこれらを削除する必要があることをどこかで読みました。

  1. ネットワーク出力が間違っている可能性があります
  2. ネットワーク全体の高速化

さて、私は 1. には疑問がありますが、2 番目の事実は論理的に聞こえるので、私の質問は次のとおりです。

環境: Tensorflow GitHub のモデルで、Tensorflow 1.15.3 でトレーニングされています。

使用済みの輸出:

python deeplab/export_model.py \
--num_classes=2 --model_variant="mobilenet_v3_large_seg" \
--dataset="123" \
--checkpoint_path=training \
--crop_size=384 \
--crop_size=384 \
--export_path=graph.pb

ネットワーク グラフの抜粋:

(<tf.Tensor 'MobilenetV3/MobilenetV3/input:0' shape=(1, 768, 768, 3) dtype=float32>,)
(<tf.Tensor 'MobilenetV3/Conv/weights:0' shape=(3, 3, 3, 16) dtype=float32>,)
(<tf.Tensor 'MobilenetV3/Conv/weights/read:0' shape=(3, 3, 3, 16) dtype=float32>,)
(<tf.Tensor 'MobilenetV3/Conv/Conv2D:0' shape=(1, 384, 384, 16) dtype=float32>,)
(<tf.Tensor 'MobilenetV3/Conv/BatchNorm/gamma:0' shape=(16,) dtype=float32>,)
(<tf.Tensor 'MobilenetV3/Conv/BatchNorm/gamma/read:0' shape=(16,) dtype=float32>,)
(<tf.Tensor 'MobilenetV3/Conv/BatchNorm/beta:0' shape=(16,) dtype=float32>,)
(<tf.Tensor 'MobilenetV3/Conv/BatchNorm/beta/read:0' shape=(16,) dtype=float32>,)
(<tf.Tensor 'MobilenetV3/Conv/BatchNorm/moving_mean:0' shape=(16,) dtype=float32>,)
(<tf.Tensor 'MobilenetV3/Conv/BatchNorm/moving_mean/read:0' shape=(16,) dtype=float32>,)
(<tf.Tensor 'MobilenetV3/Conv/BatchNorm/moving_variance:0' shape=(16,) dtype=float32>,)
(<tf.Tensor 'MobilenetV3/Conv/BatchNorm/moving_variance/read:0' shape=(16,) dtype=float32>,)
(<tf.Tensor 'MobilenetV3/Conv/BatchNorm/FusedBatchNormV3:0' shape=(1, 384, 384, 16) dtype=float32>, <tf.Tensor 'MobilenetV3/Conv/BatchNorm/FusedBatchNormV3:1' shape=(16,) dtype=float32>, <tf.Tensor 'MobilenetV3/Conv/BatchNorm/FusedBatchNormV3:2' shape=(16,) dtype=float32>, <tf.Tensor 'MobilenetV3/Conv/BatchNorm/FusedBatchNormV3:3' shape=(16,) dtype=float32>, <tf.Tensor 'MobilenetV3/Conv/BatchNorm/FusedBatchNormV3:4' shape=(16,) dtype=float32>, <tf.Tensor 'MobilenetV3/Conv/BatchNorm/FusedBatchNormV3:5' shape=<unknown> dtype=float32>)
(<tf.Tensor 'MobilenetV3/Conv/hard_swish/add/y:0' shape=() dtype=float32>,)
4

1 に答える 1