独自のイテレータを実装しています。リスト内の要素の合計量がわからないため、tqdm はプログレスバーを表示しません。"total=" は見た目が悪いので使いたくありません。むしろ、tqdm が合計を計算するために使用できる何かをイテレータに追加することをお勧めします。
class Batches:
def __init__(self, batches, target_input):
self.batches = batches
self.pos = 0
self.target_input = target_input
def __iter__(self):
return self
def __next__(self):
if self.pos < len(self.batches):
minibatch = self.batches[self.pos]
target = minibatch[:, :, self.target_input]
self.pos += 1
return minibatch, target
else:
raise StopIteration
def __len__(self):
return self.batches.len()
これは可能ですか?上記のコードに何を追加しますか...
以下のように tqdm を使用します。
for minibatch, target in tqdm(Batches(test, target_input)):
output = lstm(minibatch)
loss = criterion(output, target)
writer.add_scalar('loss', loss, tensorboard_step)