Nvidia DIGITS を使用して LeNet-gray-28x28 画像検出 Tensorflow モデルをトレーニングしたところ、期待どおりの結果が得られました。ここで、DIGITS 以外のいくつかの画像を分類する必要があり、トレーニングしたモデルを使用したいと考えています。
そこで、DIGITS で使用される LeNet モデルを取得し、それを使用するクラスを作成します。
import tensorflow as tf
import tensorflow.contrib.slim as slim
import tflearn
from tflearn.layers.core import input_data
class LeNetModel():
def gray28(self, nclasses):
x = input_data(shape=[None, 28, 28, 1])
# scale (divide by MNIST std)
# x = x * 0.0125
with slim.arg_scope([slim.conv2d, slim.fully_connected],
weights_initializer=tf.contrib.layers.xavier_initializer(),
weights_regularizer=slim.l2_regularizer(0.0005)):
model = slim.conv2d(x, 20, [5, 5], padding='VALID', scope='conv1')
model = slim.max_pool2d(model, [2, 2], padding='VALID', scope='pool1')
model = slim.conv2d(model, 50, [5, 5], padding='VALID', scope='conv2')
model = slim.max_pool2d(model, [2, 2], padding='VALID', scope='pool2')
model = slim.flatten(model)
model = slim.fully_connected(model, 500, scope='fc1')
model = slim.dropout(model, 0.5, is_training=False, scope='do1')
model = slim.fully_connected(model, nclasses, activation_fn=None, scope='fc2')
return tflearn.DNN(model)
DIGITS からモデルをダウンロードし、(別のファイルで) を使用してインスタンス化します。
self.ballmodel = LeNetModel().gray28(2)
self.ballmodel.load("src/perftrack/prototype/models/ball/snapshot_5.ckpt")
しかし、スクリプトを起動すると、次の例外が発生します。
2017-11-26 14:55:50.330524: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key conv1/biases not found in checkpoint
2017-11-26 14:55:50.330948: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key Global_Step not found in checkpoint
2017-11-26 14:55:50.331270: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key is_training not found in checkpoint
2017-11-26 14:55:50.331564: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key conv2/weights not found in checkpoint
2017-11-26 14:55:50.332823: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key conv1/weights not found in checkpoint
2017-11-26 14:55:50.332891: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key conv2/biases not found in checkpoint
2017-11-26 14:55:50.333620: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key fc2/weights not found in checkpoint
2017-11-26 14:55:50.334021: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key fc1/weights not found in checkpoint
2017-11-26 14:55:50.334173: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key fc1/biases not found in checkpoint
2017-11-26 14:55:50.334431: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Key fc2/biases not found in checkpoint
...
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.NotFoundError: Key conv1/biases not found in checkpoint
[[Node: save_1/RestoreV2_1 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_arg_save_1/Const_0_0, save_1/RestoreV2_1/tensor_names, save_1/RestoreV2_1/shape_and_slices)]]
[[Node: save_1/RestoreV2_1/_19 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/gpu:0", send_device="/job:localhost/replica:0/task:0/cpu:0", send_device_incarnation=1, tensor_name="edge_38_save_1/RestoreV2_1", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/gpu:0"]()]]
したがって、https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/inspect_checkpoint.pyスクリプトを使用して、チェックポイントに含まれるキー名を検査すると、次のような結果が得られます。
model/conv1/biases
model/conv2/weights
...
だから私は自分のネットワークを書き直して、モデル/プレフィックスを手動で追加します:
model = slim.conv2d(x, 20, [5, 5], padding='VALID', scope='model/conv1')
model = slim.max_pool2d(model, [2, 2], padding='VALID', scope='model/pool1')
model = slim.conv2d(model, 50, [5, 5], padding='VALID', scope='model/conv2')
model = slim.max_pool2d(model, [2, 2], padding='VALID', scope='model/pool2')
model = slim.flatten(model)
model = slim.fully_connected(model, 500, scope='model/fc1')
model = slim.dropout(model, 0.5, is_training=False, scope='model/do1')
model = slim.fully_connected(model, nclasses,
欠落しているキーの警告の一部は修正されますが、次のようになります。
- これは正しい修正方法ではないと感じています
- 2 つのキーを修正できません。
- Global_Step (チェックポイントに global_step キーがあります)
- is_training (それが何かはわかりません)
私の質問は、ネットワークでこれらのキー名を再定義して、チェックポイントで見つけたものと一致させるにはどうすればよいですか?