オブジェクト検出のための地域提案ネットワークに取り組んでいます。データを拡張したい場合は、画像と対応するバウンディング ボックス (拡大縮小、回転など) の両方を拡張する必要があります + 入力として画像マスクもあり、これも拡張する必要があります。そこで、データ拡張用のカスタム関数を作成しました。
パイプラインは次のとおりです。
tf.data.Dataset
読み込まれた画像、ラベル (バウンディング ボックスの座標)、およびマスクを含む Python ジェネレーターから作成します。
tf_dataset = tf.data.Dataset.from_generator(lambda: data_generator(img_paths, annotations),
output_shapes={'images': tf.TensorShape([None, None, 3]),
'labels': tf.TensorShape([None, 4]),
'masks': tf.TensorShape([None, None, 1])},
output_types={'images': tf.float32,
'labels': tf.int32,
'masks': tf.uint8})
if use_augument:
tf_dataset = augment_dataset(tf_dataset)
コードは追加の拡張を続けますが、拡張を追加するとメモリリークが発生し始めるため、この部分は重要です。
- 関数
augment_dataset(tf_dataset)
はデータセットの拡張を呼び出します。1 つの増強を除いてすべて正常に動作します (すべての増強はこのように行われますが、これだけが非常に高いメモリ リーク (RAM 使用量) を引き起こします):
dataset = dataset.map(
lambda x: tf.cond(
tf.keras.backend.random_uniform([], 0, 1) > 0.0,
lambda: aug.shrink(
x,
vertical=tf.keras.backend.random_uniform([], 0.8, 2),
horizontal=tf.keras.backend.random_uniform([], 0.8, 2)
),
lambda: x
), num_parallel_calls=8
)
- 関数は、比率
aug.shrink(tensor: tf.Tensor, vertical: float = 1, horizontal: float = 1)
に基づいて画像、ラベル、マスクのサイズを変更します。私のネットワークは、画像の動的サイズ (使用する理由) を受け入れ、ラベル (境界ボックス) の形式は次のとおりです。vertical
horizontal
tf.shape
(xcenter, ycenter, width, height)
def shrink(tensor: tf.Tensor, vertical: float = 1, horizontal: float = 1):
image = tensor['images']
label = tensor['labels']
mask = tensor['masks']
height = tf.cast(tf.shape(image)[0], tf.float32) / vertical
width = tf.cast(tf.shape(image)[1], tf.float32) / horizontal
height = tf.cast(height, tf.int32)
width = tf.cast(width, tf.int32)
# resize image
resize_image = tf.image.resize(image, (height, width))
resize_mask = tf.image.resize(mask, (height, width))
resize_mask = tf.cast(resize_mask, tf.uint8)
div_tensor = [horizontal, vertical, horizontal, vertical]
new_label = tf.math.divide(tf.cast(label, tf.float32), div_tensor)
new_label = tf.cast(new_label, tf.int32)
return {'images' : resize_image, 'labels' : new_label,'masks' : resize_mask}
サイズ変更にOpenCVとnumpyの組み合わせを使用したバージョンも試しましたtf.py_function
が、機能しませんでした。誰もそのような問題に遭遇しましたか?
ご協力ありがとうございました。