0

TFRecordTL;DRバッチ予測を行う際、Google Cloud AI Platform はどのようにファイルを解凍しますか?

トレーニング済みの Keras モデルを Google Cloud AI Platform にデプロイしましたが、バッチ予測のファイル形式に問題があります。トレーニングのために、私は a を使用して次tf.data.TFRecordDatasetのようなリストを読み込んでいますがTFRecord、これらはすべて正常に機能します。

def unpack_tfrecord(record):
    parsed = tf.io.parse_example(record, {
        'chunk': tf.io.FixedLenFeature([128, 2, 3], tf.float32),  # Input
        'class': tf.io.FixedLenFeature([2], tf.int64),            # One-hot classification (binary)
    })

    return (parsed['chunk'], parsed['class'])

files = [str(p) for p in training_chunks_path.glob('*.tfrecord')]
dataset = tf.data.TFRecordDataset(files).batch(32).map(unpack_tfrecord)
model.fit(x=dataset, epochs=train_epochs)
tf.saved_model.save(model, model_save_path)

保存したモデルを Cloud Storage にアップロードし、AI Platform で新しいモデルを作成します。AI プラットフォームのドキュメントには、「gcloud ツールを使用したバッチ [サポート] JSON インスタンス文字列または TFRecord ファイルを含むテキスト ファイル (圧縮されている可能性があります)」( https://cloud.google.com/ai-platform/prediction/docs/overview#prediction_input_data)。しかし、TFRecord ファイルを提供すると、エラーが発生します。

("'utf-8' codec can't decode byte 0xa4 in position 1: invalid start byte", 8)

私の TFRecord ファイルには、エンコードされた Protobuf の束が含まれていますtf.train.Example。私はunpack_tfrecordAI Platform に関数を提供していないので、適切に解凍できないのは理にかなっていると思いますが、ここからどこに行くべきかについてはノードのアイデアがあります。データが大きすぎるため、JSON 形式の使用には興味がありません。

4

1 に答える 1

0

これが最善の方法かどうかはわかりませんが、TF 2.x の場合は次のようにすることができます。

import tensorflow as tf

def make_serving_input_fn():
    # your feature spec
    feature_spec = {
        'chunk': tf.io.FixedLenFeature([128, 2, 3], tf.float32),  
        'class': tf.io.FixedLenFeature([2], tf.int64),
    }

    serialized_tf_examples = tf.keras.Input(
        shape=[], name='input_example_tensor', dtype=tf.string)

    examples = tf.io.parse_example(serialized_tf_examples, feature_spec)

    # any processing 
    processed_chunks = tf.map_fn(
        <PROCESSING_FN>, 
        examples['chunk'], # ?
        dtype=tf.float32)

    return tf.estimator.export.ServingInputReceiver(
        features={<MODEL_FIRST_LAYER_NAME>: processed_chunks},
        receiver_tensors={"input_example_tensor": serialized_tf_examples}
    )


estimator = tf.keras.estimator.model_to_estimator(
    keras_model=model,
    model_dir=<ESTIMATOR_SAVE_DIR>)

estimator.export_saved_model(
    export_dir_base=<WORKING_DIR>,
    serving_input_receiver_fn=make_serving_input_fn)
于 2020-10-20T23:04:17.420 に答える