私は、自分のデータ (2 つのクラス、したがって、ピクセルごとのバイナリ分類) のセマンティック セグメンテーションのために、CNN アーキテクチャ (事前トレーニング済みの VGG16 モデルを使用した FCN-8s モデル) の実装に取り組んでいます。
私がこれをどのように行うつもりかは次のとおりです。
- 事前トレーニング済みのモデルに重みをロードします
- FCN に変換するための追加の上位レイヤーの追加/削除
- 事前トレーニング済みモデルの下位レイヤーをフリーズします (トレーニング フェーズ中に更新しないようにするため)。
- 特定のデータセットでネットワークをトレーニングする
これが正しいと仮定すると、テンソルフロー モデルの下位レイヤーをフリーズするにはどうすればよいですか? (特定の実装の詳細を探しています) TensorFlow チュートリアルでのインセプションの再トレーニングを見ましたが、まだよくわかりません。
これは私が念頭に置いているワークフローです:
既存の事前トレーニング済みモデルでデータを実行し、トレーニングせずに機能出力を抽出します。(どうやって?)
これらの機能の出力を、上位層を含む別のネットワークにフィードし、トレーニングに取り掛かります。
どんな提案も役に立ちます!
そうでなければ、私が間違っているとしたら、これをどのように考えればよいでしょうか?
アップデート:
以下のchasp255の提案を取り上げ、モデルの下位レイヤーを「フリーズ」するためにtf.stop_gradientを使用しようとしました。明らかに、私の実装には何か問題があります。可能な代替/提案?
このモデルは、FCN (セマンティック セグメンテーション用) 論文に基づいて構築されています。logits
モデル アーキテクチャ、つまり機能から抽出し、最初にloss
関数に直接入力して、softmax 分類器で最小化します。(ピクセルごとの分類)deconv_1
は、形状の[batch, h, w, num_classes] = [1, 750, 750, 2]
実装のロジット テンソルです。
logits = vgg_fcn.deconv_1
stopper = tf.stop_gradient(logits, 'stop_gradients')
loss = train_func.loss(stopper, labels_placeholder, 2)
with tf.name_scope('Optimizer'):
train_op = train_func.training(loss, FLAGS.learning_rate)
with tf.name_scope('Accuracy'):
eval_correct = train_func.accuracy_eval(logits, labels_placeholder)
accuracy_summary = tf.scalar_summary('Accuracy', eval_correct)
次に、これらのグラフ操作を次のように実行します。
_, acc, loss_value = sess.run([train_op,eval_correct, loss], feed_dict=feed_dict)
このようにトレーニング サイクルを実行すると、損失値の最適化が行われません。これは、tf.stop_gradient
Op.
詳細については、以下の私の損失関数:
def loss(logits, labels, num_classes):
logits = tf.reshape(logits, [-1, num_classes])
#epsilon = tf.constant(value=1e-4)
#logits = logits + epsilon
labels = tf.to_int64(tf.reshape(labels, [-1]))
print ('shape of logits: %s' % str(logits.get_shape()))
print ('shape of labels: %s' % str(labels.get_shape()))
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, labels, name='Cross_Entropy')
cross_entropy_mean = tf.reduce_mean(cross_entropy, name='xentropy_mean')
tf.add_to_collection('losses', cross_entropy_mean)
loss = tf.add_n(tf.get_collection('losses'), name='total_loss')
return loss