0

keras で交互更新ルールを使用したいと思います。つまり、バッチごとに通常のグラデーション ベースのステップを呼び出し、次にカスタム ステップを呼び出したいと考えています。

オプティマイザーまたはコールバックを継承して実装することを考えました (そしてオンバッチ呼び出しを使用します)。ただし、両方ともバッチデータとバッチラベルが不足しているため (両方が必要です)、どちらも行いません。

keras を使用してカスタムの交互更新を実装する方法について何か考えはありますか?

必要に応じて、keras フレームワーク (model.fit、model.predict .. を使用) でラップされたプロジェクトを使用し続けることができる限り、tensorflow 固有のメソッドを直接呼び出してもかまいません。

4

1 に答える 1

-1

カスタムコールバックを作成してみてください

import keras.callbacks as callbacks

class JSONMetrics(callbacks.Callback):

_model      = None
_each_epoch = None
_metrics    = None
_epoch      = None
_file_json  = None 

def __init__(self,model,each_epoch,logger=None):

    self._file_json = "file_log.json"
    self._model     = model
    self._each_epoch= each_epoch
    self._epoch     = 0
    self._metrics   = {'loss':[], 'acc':[]}

def on_epoch_begin(self, epoch, logs):
    # print('Epoch {0} begin'.format(epoch))
    try:
        with open(self._file_json, 'r') as f:   
            self._metrics = json.load(f)

def on_epoch_end(self, epoch, logs):
    self._logger.info('Nemesis: Epoch {0} end'.format(epoch))

    self._metrics['loss'].append(logs.get('loss'))
    self._metrics['acc'].append(logs.get('acc'))
    with open(self._file_json, 'w') as f:
        data = json.dump(self._metrics, f)

    if self._epoch % self._each_epoch == 0:

        file_name = 'weights%08d.h5' % self._epoch
        #print('Saving weights at {0} file'.format(file_name))
        self._model.save_weights(file_name)

    self._epoch += 1

self.model を呼び出して問題を解決し、たとえば acc と loss を保存できます。

于 2017-07-25T00:22:14.607 に答える