7

独自のイテレータを実装しています。リスト内の要素の合計量がわからないため、tqdm はプログレスバーを表示しません。"total=" は見た目が悪いので使いたくありません。むしろ、tqdm が合計を計算するために使用できる何かをイテレータに追加することをお勧めします。

class Batches:
    def __init__(self, batches, target_input):
        self.batches = batches
        self.pos = 0
        self.target_input = target_input

    def __iter__(self):
        return self

    def __next__(self):
        if self.pos < len(self.batches):
            minibatch = self.batches[self.pos]
            target = minibatch[:, :, self.target_input]
            self.pos += 1
            return minibatch, target
        else:
            raise StopIteration

    def __len__(self):
        return self.batches.len()

これは可能ですか?上記のコードに何を追加しますか...

以下のように tqdm を使用します。

for minibatch, target in tqdm(Batches(test, target_input)):

    output = lstm(minibatch)
    loss = criterion(output, target)
    writer.add_scalar('loss', loss, tensorboard_step)
4

2 に答える 2