0

MNIST画像データセットから tfrecord を作成し、tfrecord を tf.data.dataset に変換します。「python3 tfrecord1.py」実行中は正常です。しかし、「mpirun -np 2 python3 tfrecord1.py」の実行中に DataLossError が発生しました。

私のコードに何か問題があるのか​​もしれません。

私のコンピューティング環境: ubuntu 20.04、tensorflow 2.6.0、horovod 0.23、32 CPU、GPU なし

threcord1.py

import tensorflow as tf
import horovod.tensorflow as hvd
from PIL import Image
import os, glob

hvd.init()

path = os.path.expanduser('~') # home directory
train_path = os.path.join(path, 'mnist_png/training') # mnist_png/training'
images = glob.glob(train_path + '/*/*.png')
tfrecord_filename = '/disfs/mnist.tfrecord'

# ----------- write tfrecord ---------------

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

writer = tf.io.TFRecordWriter(tfrecord_filename)

for image in images:
    img = open(image, 'rb').read()
    img_ = tf.image.decode_png(img)
    img_ = bytes(img)
    label = int (image.split('/')[5])
    feature = {
            'image' : _bytes_feature(img),
            'label' : _int64_feature(label)
            }
    example = tf.train.Example(features = tf.train.Features(feature=feature))
    writer.write(example.SerializeToString())
writer.close()

# --------- read tfrecord ----------
reader = tf.data.TFRecordDataset(tfrecord_filename)
feature_set = {
        'image' : tf.io.FixedLenFeature([], tf.string),
        'label' : tf.io.FixedLenFeature([], tf.int64)
        }

def _parse_function(exam_proto):
    feature_dict = tf.io.parse_single_example(exam_proto, feature_set)
    raw_image = tf.io.decode_jpeg(feature_dict['image'])
    raw_image = tf.image.resize(raw_image, [28,28])/255.0
    raw_image = tf.reshape(raw_image,[28,28,1])
    label = feature_dict['label']
    return (raw_image, label)

dataset = reader.map(_parse_function)
dataset = dataset.repeat().shuffle(len(images)).batch(128)
    
for batch, (images, labels) in enumerate(dataset.take(5)): # DataLossError
    print(batch,'\n',images,'\n',  labels)

ここに画像の説明を入力

4

0 に答える 0