TFRecord
TL;DRバッチ予測を行う際、Google Cloud AI Platform はどのようにファイルを解凍しますか?
トレーニング済みの Keras モデルを Google Cloud AI Platform にデプロイしましたが、バッチ予測のファイル形式に問題があります。トレーニングのために、私は a を使用して次tf.data.TFRecordDataset
のようなリストを読み込んでいますがTFRecord
、これらはすべて正常に機能します。
def unpack_tfrecord(record):
parsed = tf.io.parse_example(record, {
'chunk': tf.io.FixedLenFeature([128, 2, 3], tf.float32), # Input
'class': tf.io.FixedLenFeature([2], tf.int64), # One-hot classification (binary)
})
return (parsed['chunk'], parsed['class'])
files = [str(p) for p in training_chunks_path.glob('*.tfrecord')]
dataset = tf.data.TFRecordDataset(files).batch(32).map(unpack_tfrecord)
model.fit(x=dataset, epochs=train_epochs)
tf.saved_model.save(model, model_save_path)
保存したモデルを Cloud Storage にアップロードし、AI Platform で新しいモデルを作成します。AI プラットフォームのドキュメントには、「gcloud ツールを使用したバッチ [サポート] JSON インスタンス文字列または TFRecord ファイルを含むテキスト ファイル (圧縮されている可能性があります)」( https://cloud.google.com/ai-platform/prediction/docs/overview#prediction_input_data)。しかし、TFRecord ファイルを提供すると、エラーが発生します。
("'utf-8' codec can't decode byte 0xa4 in position 1: invalid start byte", 8)
私の TFRecord ファイルには、エンコードされた Protobuf の束が含まれていますtf.train.Example
。私はunpack_tfrecord
AI Platform に関数を提供していないので、適切に解凍できないのは理にかなっていると思いますが、ここからどこに行くべきかについてはノードのアイデアがあります。データが大きすぎるため、JSON 形式の使用には興味がありません。