転移学習では、ネットワークを特徴抽出器として使用して特徴のデータセットを作成し、その上で別の分類器 (SVM など) をトレーニングします。
tf.contrib.data
Dataset API ( ) とを使用してこれを実装したいdataset.map()
:
# feature_extractor will create a CNN on top of the given tensor
def features(feature_extractor, ...):
dataset = inputs(...) # This creates a dataset of (image, label) pairs
def map_example(image, label):
features = feature_extractor(image, trainable=False)
# Leaving out initialization from a checkpoint here...
return features, label
dataset = dataset.map(map_example)
return dataset
データセットの反復子を作成するときに、これを行うと失敗します。
ValueError: Cannot capture a stateful node by value.
これは本当です。ネットワークのカーネルとバイアスは変数であり、したがってステートフルです。この特定の例では、そうである必要はありません。
tf.Variable
Ops、特にオブジェクトをステートレスにする方法はありますか?
私が使用しているのでtf.layers
、定数として単純に作成することはできず、設定しても定数は作成されませんが、変数がコレクションtrainable=False
に追加されることはありません。GraphKeys.TRAINABLE_VARIABLES