3

Android アプリで Tensorflow アルゴリズムを使用したいと考えています。Tensorflow Android の例は、モデル定義と重み (*.pb ファイル) を含む GraphDef をダウンロードすることから始まります。これは、私の Scikit Flow アルゴリズム (Tensorflow の一部) からのものである必要があります。

一見すると、classifier.save('model/') と言うだけで簡単に思えますが、そのフォルダーに保存されるファイルは *.ckpt、*.def、そしてもちろん *.pb ではありません。代わりに、*.pbtxt とチェックポイント (終了なし) ファイルを処理する必要があります。

私はかなり長い間そこに立ち往生しています。何かをエクスポートするコード例を次に示します。

#imports
import tensorflow as tf
import tensorflow.contrib.learn as skflow
import tensorflow.contrib.learn.python.learn as learn
from sklearn import datasets, metrics

#skflow example
iris = datasets.load_iris()
feature_columns = learn.infer_real_valued_columns_from_input(iris.data)
classifier = learn.LinearClassifier(n_classes=3, feature_columns=feature_columns,model_dir="modeltest")
classifier.fit(iris.data, iris.target, steps=200, batch_size=32)
iris_predictions = list(classifier.predict(iris.data, as_iterable=True))
score = metrics.accuracy_score(iris.target, iris_predictions)
print("Accuracy: %f" % score)

取得するファイルは次のとおりです。

  • チェックポイント
  • グラフ.pbtxt
  • model.ckpt-1.meta
  • model.ckpt-1-00000-of-00001
  • model.ckpt-200.meta
  • model.ckpt-200-00000-of-00001

私が見つけた多くの可能な回避策では、GraphDef を変数に含める必要があります (Scikit Flow の方法がわからない)。または、Scikit Flow を使用する必要がないように思われる Tensorflow セッション。

4

1 に答える 1

2

pb ファイルとして保存するには、構築されたグラフから graph_def を抽出する必要があります。あなたはそれを次のように行うことができます--

from tensorflow.python.framework import tensor_shape, graph_util
from tensorflow.python.platform import gfile
sess = tf.Session()
final_tensor_name = 'results:0'     #Replace final_tensor_name with name of the final tensor in your graph
#########Build your graph and train########
## Your tensorflow code to build the graph
###########################################

outpt_filename = 'output_graph.pb'
output_graph_def = sess.graph.as_graph_def()
with gfile.FastGFile(outpt_filename, 'wb') as f:
  f.write(output_graph_def.SerializeToString())

トレーニング済みの変数を定数に変換する場合 (ckpt ファイルを使用して重みをロードしないようにするため)、次を使用できます。

output_graph_def = graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), [final_tensor_name])

お役に立てれば!

于 2016-10-02T01:36:09.507 に答える