3

Attention-OCRモデルを OpenCV-DNN でラップして、推論時間を増やしたいと考えています。公式TFモデルレポのTFコードを使用しています。

TF モデルを OpenCV-DNN でラップするには、このコードを参照しています。TF モデルを読み取るにはcv2.dnn.readNetFromTensorflow()、「凍結グラフ」と「グラフ構造」が必要です。

このコード スニペットを使用して、メタ チェックポイント ファイルから構造をインポートし、グラフ構造をファイルに保存し.pbtxtます。

# load graph from meta file
tf.reset_default_graph()  
imported_meta = tf.train.import_meta_graph("attention_ocr_2017_08_09/model_demo_inference.ckpt.meta")

# restore graph structure, variables in session's graph
sess = tf.Session()
imported_meta.restore(sess, 'attention_ocr_2017_08_09/model_demo_inference.ckpt')
# write graph structure to a pbtxt file
tf.train.write_graph(sess.graph_def, './', 'train_attention.pbtxt', as_text=True)

グラフを固定するためのコードは次のとおりです。

from tensorflow.python.tools import freeze_graph
freeze_graph.freeze_graph('train_attention.pbtxt', '', False, \
                          'attention_ocr_2017_08_09/model_demo_inference.ckpt', \
                          'AttentionOcr_v1_1/Softmax', \
                          'save/restore_all', 'save/Const:0', 'frozen_model.pb', True, "")

最終的なコードでは、関数内でpbtxtおよびpbファイルを使用しcv2.dnn.readNetFromTensorflow()ます。

# Wrap TF model in OpenCV DNN
import cv2

FROZEN_GRAPH = "frozen_model.pb"
PB_TXT = "train_attention.pbtxt"

img = cv2.imread('testdata/fsns_train_00.png')
blob = cv2.dnn.blobFromImage(img,1)

net = cv2.dnn.readNetFromTensorflow(FROZEN_GRAPH, PB_TXT)
out = net.forward()
out

発生したエラーは次のとおりです。

---------------------------------------------------------------------------
error                                     Traceback (most recent call last)
<ipython-input-128-09e46e8b88ed> in <module>
      9 blob = cv2.dnn.blobFromImage(img,1)
     10 
---> 11 net = cv2.dnn.readNetFromTensorflow(FROZEN_GRAPH, PB_TXT)
     12 out = net.forward()
     13 out

error: OpenCV(4.0.0) /Users/travis/build/skvark/opencv-python/opencv/modules/dnn/src/
tensorflow/tf_io.cpp:54: error: (-2:Unspecified error) 
FAILED: ReadProtoFromTextFile(param_file, param). 
Failed to parse GraphDef file: train_attention.pbtxt in function 'ReadTFNetParamsFromTextFileOrDie'

注:出力ノード名は、以下を使用して生成されたグラフ内のテンソルのリストを見て手動で設定されます。

# get names of all tensors
def get_names(graph=sess.graph):
    return [t.name for op in graph.get_operations() for t in op.values()]

l1 = get_names()
for ele in l1:
    print(ele)

SO コミュニティから提供されたヘルプに感謝します。

4

1 に答える 1