1

テンソルフロー2.xでマイクロバッチ処理をどのように実装できますか? つまり、いくつかのバッチの勾配を蓄積し、これらの蓄積された勾配で重みを更新したいと考えています (これにより、実質的にバッチ サイズが累積ステップ * バッチ サイズに増加します)。

次のコードで試しました:

import numpy as np
import tensorflow as tf

class Model(tf.keras.Model):
    def __init__(self,  ):
        super().__init__()
    
        self.dense = tf.keras.layers.Dense(1)

    def call(self, inputs):
        return self.dense(inputs)


class Trainer:
    def __init__(self, model, num_accumulate):
        self.model = model
        self.num_accumulate = num_accumulate
        self.optimizer = tf.keras.optimizers.Adam()
        self.accumulated_gradients = None

    def _init_accumulated_gradients_maybe(self):
        if self.accumulated_gradients is None:
            self.accumulated_gradients = [tf.Variable(var, dtype=var.dtype, trainable=False) for var in self.model.trainable_weights]
            self._reset_gradients()

    def _reset_gradients(self):
        for grad in self.accumulated_gradients:
            grad.assign(tf.zeros_like(grad))

    def _accumulate_gradients(self, gradients):
        for acc_grad, grad in zip(self.accumulated_gradients, gradients):
            acc_grad.assign_add( grad / self.num_accumulate )

    def get_mae(self, targets, mean_pred):
        return tf.reduce_mean(tf.abs(targets - mean_pred))

    @tf.function
    def train_on_batch(self, dataset_iter):
        
        for _ in range(self.num_accumulate): # problematic
            inputs, target = next(dataset_iter)

            with tf.GradientTape() as tape:
                prediction = self.model(inputs, training=True)
                loss = self.get_mae(target, prediction)

            gradients = tape.gradient(loss, self.model.trainable_weights)

            self._init_accumulated_gradients_maybe()
            self._accumulate_gradients(gradients)
            gradients = self.accumulated_gradients

        self.optimizer.apply_gradients(zip(gradients, self.model.trainable_weights))
        self._reset_gradients()

        return loss

class DataProvider:
    def __init__(self,  
                        batch_size: int = 1, 
                ):
        self.batch_size = batch_size
        self.in_data = np.random.rand(100,10)
        self.out_data = np.random.rand(100,1)

    def get_dataset(self):
        def generator():
            while True:
                yield (tf.constant(self.in_data, dtype=tf.float32), tf.constant(self.out_data, dtype=tf.float32))

        return tf.data.Dataset.from_generator(
                generator,
                output_types=(tf.float32, tf.float32),
                output_shapes=([None,10], [None,1])
                )


num_accumulate = 4
batch_size = 25
nSteps = 10

model = Model()
trainer = Trainer(model, num_accumulate)
dataset_iter = iter(DataProvider(batch_size).get_dataset())

for step in range(1, nSteps):
    trainer.train_on_batch(dataset_iter)

ただし、tf.range を使用するか、tf.function 装飾関数内で範囲を使用するかによって、2 つの異なる問題に遭遇しました。

  1. 範囲の使用: 提供されたミニ モデルで動作しますが、私のユース ケースではモデルがかなり大きく (2.6 Mio パラメータ)、このような勾配を累積すると、次のエラーが発生します。

2021-04-24 18:19:28.349940: W tensorflow/core/common_runtime/process_function_library_runtime.cc:733] マルチデバイス関数の最適化の失敗を無視: 期限を超えました: meta_optimizer が期限を超えました。

私の推測では、この部分を繰り返して一度だけ追加するのではなく、範囲を使用すると (tf.function がどのように機能するかを理解している限り)、すべての勾配累積ステップがグラフに追加されます。

  1. range を tf.range に置き換えると、次のエラーが発生します。
Traceback (most recent call last):
  File "/mydirectory/model/test_train copy.py", line 89, in <module>
    trainer.train_on_batch(dataset_iter)
  File "/mydirectory/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 580, in __call__
    result = self._call(*args, **kwds)
  File "/mydirectory/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 627, in _call
    self._initialize(args, kwds, add_initializers_to=initializers)
  File "/mydirectory/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 505, in _initialize
    self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
  File "/mydirectory/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2446, in _get_concrete_function_internal_garbage_collected
    graph_function, _, _ = self._maybe_define_function(args, kwargs)
  File "/mydirectory/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2777, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/mydirectory/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2657, in _create_graph_function
    func_graph_module.func_graph_from_py_func(
  File "/mydirectory/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 981, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/mydirectory/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 441, in wrapped_fn
    return weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/mydirectory/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3299, in bound_method_wrapper
    return wrapped_fn(*args, **kwargs)
  File "/mydirectory/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 968, in wrapper
    raise e.ag_error_metadata.to_exception(e)
ValueError: in user code:

    /mydirectory/model/test_train copy.py:40 train_on_batch  *
        for _ in tf.range(self.num_accumulate):
    /mydirectory/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/autograph/operators/control_flow.py:343 for_stmt
        _tf_range_for_stmt(
    /mydirectory/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/autograph/operators/control_flow.py:526 _tf_range_for_stmt
        _tf_while_stmt(
    /mydirectory/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/autograph/operators/control_flow.py:862 _tf_while_stmt
        _verify_loop_init_vars(init_vars, symbol_names)
    /mydirectory/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/autograph/operators/control_flow.py:119 _verify_loop_init_vars
        raise ValueError('"{}" must be defined before the loop.'.format(name))

    ValueError: "loss" must be defined before the loop.

したがって、勾配、損失、予測など、発生するすべての変数を初期化してから動作しますが、(私の使用例では) 非常に遅いのはなぜですか?

何が欠けていますか?どんな助けでも大歓迎です。

4

0 に答える 0