今天来个轻松一点儿的内容。
按批次训练的时候,想显示一个周期还有多久能执行结束,怎么办,使用tqdm
第1种使用方式——适用于列表形式的对象
for i in tqdm(range(len(dataset))):
inputs, targets = dataset[i]
if len(inputs) == 3:
spans, words, features = inputs
label, costs, true_ant = targets
assert costs.numpy() == 0, "costs error"
assert true_ant.numpy() == 0, "true_ant error"
assert label.numpy()== 1, "label error"
第2种使用方式——适用于迭代器形式的对象
# 检查训练时,返回元组长度是否都是2
total = len(dataloader)
update_num = int(total*0.01)
with tqdm(total=total) as pbar:
for batch_i, (m_idx, n_pairs_l, batch) in enumerate(
zip(mentions_idx, n_pairs, dataloader)
):
assert len(batch) == 2
if batch_i % update_num == 0:
pbar.update(update_num)
有种启动系统的感觉,尽情玩耍吧。