tf.contrib.learn.DNNLinearCombinedClassifier
tensorflow サービング用のインスタンスをエクスポートしようとしています。次のコードを実行する場合:
estimator.export(
export_path,
signature_fn=tf.contrib.learn.utils.export.classification_signature_fn)
次の警告が表示されます。
警告: tensorflow: use_deprecated_input_fn=True を指定して (tensorflow.contrib.learn.python.learn.estimators.estimator から) export を呼び出すことは非推奨であり、2016 年 9 月 23 日以降に削除されます。更新の手順: export によって受け入れられる input_fn の署名は、tf.Learn Estimator の train/evaluate によって使用されるものと一致するように変更されています。input_fn と input_feature_key は必須の引数になり、use_deprecated_input_fn はデフォルトで False になり、完全に削除されます。
問題は、今のところこの警告を無視してもいいですか? もう 1 つの質問ですが、クライアントのコードをどのように記述すればよいでしょうか? protobuf を正しく準備するにはどうすればよいですか?
mnist クライアントの場合、protobuf
次のように準備されていることがわかります。
request = predict_pb2.PredictRequest()
request.model_spec.name = 'mnist'
image, label = test_data_set.next_batch(1)
request.inputs['images'].CopyFrom(
tf.contrib.util.make_tensor_proto(image[0], shape=[1, image[0].size]))
contrib.learn
エスティメータで使用される特徴列に対して同じことを行うにはどうすればよいですか? たとえば、機能列が次のようになっているとしますか?
country = sparse_column_with_vocabulary_file("country", VOCAB_FILE)
age = real_valued_column("age")
click_bucket = bucketized_column(real_valued_column("historical_click_ratio"),
boundaries=[i/10. for i in range(10)])
country_x_click = crossed_column([country, click_bucket], 10)
feature_columns = set([age, click_bucket, country_x_click])
...そしてhttps://www.tensorflow.org/versions/r0.11/tutorials/wide/index.htmlinput_fn
などのチュートリアルの 1 つから、クライアントから送信する実際のデータが提供されます。
アップデート:
エクスポートとクライアントの組み合わせを実行しましたが、結果が正しくありません。エクスポート コードの一部は次のとおりです。
def my_classification_signature_fn(examples, unused_features, predictions):
"""Creates classification signature from given examples and predictions.
Args:
examples: `Tensor`.
unused_features: `dict` of `Tensor`s.
predictions: `Tensor` or dict of tensors that contains the classes tensor
as in {'classes': `Tensor`}.
Returns:
Tuple of default classification signature and empty named signatures.
Raises:
ValueError: If examples is `None`.
"""
if examples is None:
raise ValueError('examples cannot be None when using this signature fn.')
if isinstance(predictions, dict):
default_signature = exporter.classification_signature(
examples, classes_tensor=predictions['classes'])
else:
print examples
print predictions
default_signature = exporter.classification_signature(
examples, classes_tensor=predictions)
named_graph_signatures={
'inputs': exporter.generic_signature({'x_values': examples}),
'outputs': exporter.generic_signature({'preds': predictions})}
return default_signature, named_graph_signatures
def export_input_fn(df,feature_defs,batch_size):
input_features = input_fn(df, feature_defs)
input_features["input_feature_dummy"] = tf.constant(np.array([["__SOME_VAL__" for __ in range(len(df.columns))] for _ in range(batch_size) ]),
dtype=tf.string,
shape=[batch_size, len(df.columns)])
return input_features,None
model.export(
"export_8",
input_fn=lambda : export_input_fn(extdf,training_run_data["features"],20),
input_feature_key="input_feature_dummy",
signature_fn=my_classification_signature_fn,
use_deprecated_input_fn=False)
クライアントコードの一部は次のとおりです。
pred_df = pd.read_csv("{}/client_test.csv".format(work_dir)).sort_index(axis=1)
host, port = hostport.split(':')
channel = implementations.insecure_channel(host, int(port))
stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)
result_counter = _ResultCounter(num_tests, concurrency)
x_values = numpy.array([["__SOME_VAL__" for __ in range(len(pred_df.columns))] for _ in range(20)])
labels = pred_df["__label"].values
input_values = pred_df.astype('str').values
input_size = min(20, input_values.shape[0])
x_values[0:input_size] = input_values[0:input_size]
labels_test = numpy.zeros(20)
labels_test[0:input_size] = labels[0:input_size]
request = predict_pb2.PredictRequest()
request.model_spec.name = 'capture_process'
request.inputs['x_values'].CopyFrom(
tf.contrib.util.make_tensor_proto(x_values, shape=x_values.shape))
result_counter.throttle()
result_future = stub.Predict.future(request, 5.0) # 5 seconds
result_future.add_done_callback(
_create_rpc_callback(labels_test, result_counter))
ご覧のとおり、pandas データフレームから抽出されたすべての入力列を という 1 つのテンソル プロトとして送信しようとしていますx_values
。これは正しいですか、それとも特徴列ごとに入力を作成する必要がありますか?