Android アプリで Tensorflow アルゴリズムを使用したいと考えています。Tensorflow Android の例は、モデル定義と重み (*.pb ファイル) を含む GraphDef をダウンロードすることから始まります。これは、私の Scikit Flow アルゴリズム (Tensorflow の一部) からのものである必要があります。
一見すると、classifier.save('model/') と言うだけで簡単に思えますが、そのフォルダーに保存されるファイルは *.ckpt、*.def、そしてもちろん *.pb ではありません。代わりに、*.pbtxt とチェックポイント (終了なし) ファイルを処理する必要があります。
私はかなり長い間そこに立ち往生しています。何かをエクスポートするコード例を次に示します。
#imports
import tensorflow as tf
import tensorflow.contrib.learn as skflow
import tensorflow.contrib.learn.python.learn as learn
from sklearn import datasets, metrics
#skflow example
iris = datasets.load_iris()
feature_columns = learn.infer_real_valued_columns_from_input(iris.data)
classifier = learn.LinearClassifier(n_classes=3, feature_columns=feature_columns,model_dir="modeltest")
classifier.fit(iris.data, iris.target, steps=200, batch_size=32)
iris_predictions = list(classifier.predict(iris.data, as_iterable=True))
score = metrics.accuracy_score(iris.target, iris_predictions)
print("Accuracy: %f" % score)
取得するファイルは次のとおりです。
- チェックポイント
- グラフ.pbtxt
- model.ckpt-1.meta
- model.ckpt-1-00000-of-00001
- model.ckpt-200.meta
- model.ckpt-200-00000-of-00001
私が見つけた多くの可能な回避策では、GraphDef を変数に含める必要があります (Scikit Flow の方法がわからない)。または、Scikit Flow を使用する必要がないように思われる Tensorflow セッション。