# Code adapted from the code released with the CSS paper.
# It reports scores separately for the yes/no, other, and number question types.
def evaluate(model, dataloader, qid2type):
    """Score ``model`` on ``dataloader`` overall and per question type.

    Args:
        model: called as ``model(v, q, None, None, None)``; must return a
            3-tuple whose first item is the prediction logits and whose second
            item is a loss value (possibly ``None``).
        dataloader: yields ``(v, q, a, b, qids, _)`` batches; ``a`` holds the
            answer targets consumed by ``compute_score_with_logits``.
            ``dataloader.dataset`` must support ``len()``.
        qid2type: maps ``str(qid)`` to ``'yes/no'``, ``'other'`` or
            ``'number'``.

    Returns:
        ``(results, eval_loss)``. ``results`` has the keys ``score``,
        ``upper_bound``, ``score_yesno``, ``score_other`` and ``score_number``.
        Each is a float average (multiply by 100 for a percentage); a question
        type with no samples reports 0.0. ``eval_loss`` is the loss value
        returned for the LAST batch, or ``None`` if the loader was empty.
    """
    import torch  # local import: the file-level imports are not visible here

    # Question type -> results key; the order matches the results dict layout.
    type_keys = {'yes/no': 'score_yesno', 'other': 'score_other', 'number': 'score_number'}
    type_score = dict.fromkeys(type_keys, 0.0)
    type_count = dict.fromkeys(type_keys, 0)
    score = 0.0
    upper_bound = 0.0
    loss = None  # stays None if the loader yields no batches

    # Evaluation never back-propagates, so skip building autograd graphs.
    with torch.no_grad():
        for v, q, a, b, qids, _ in tqdm(dataloader, ncols=100, total=len(dataloader), desc="eval"):
            v = v.cuda()
            q = q.cuda()
            pred, loss, _ = model(v, q, None, None, None)
            # One score per question in the batch, aligned with ``qids``.
            batch_score = compute_score_with_logits(pred, a.cuda()).cpu().numpy().sum(1)
            score += float(batch_score.sum())
            # Best achievable score per question, summed over the batch.
            upper_bound += float(a.max(1)[0].sum())
            qids = qids.detach().cpu().int().numpy()
            for qid, sample_score in zip(qids, batch_score):
                typ = qid2type[str(qid)]
                if typ in type_score:
                    type_score[typ] += float(sample_score)
                    type_count[typ] += 1
                else:
                    print('Unknown question type %r for qid %s' % (typ, qid))

    n_samples = len(dataloader.dataset)
    results = {
        'score': score / n_samples if n_samples else 0.0,
        'upper_bound': upper_bound / n_samples if n_samples else 0.0,
    }
    for typ, key in type_keys.items():
        # A split may contain no questions of a given type; avoid ZeroDivisionError.
        count = type_count[typ]
        results[key] = type_score[typ] / count if count else 0.0
    return results, loss
# The evaluation function above is called from the training loop excerpt below.
# Excerpt from the per-epoch body of a training loop; the enclosing function is
# not part of this file. It relies on names defined there: run_eval, model,
# eval_loader, qid2type, epoch, total_step, total_loss, train_score, logger,
# time, t, writer, args, output, best_eval_score.
# NOTE(review): indentation was reconstructed from data flow. The statements
# reading `results` must sit under the first `if run_eval:`, and the checkpoint
# block reads `eval_score`, which only exists when run_eval is true.
if run_eval:
    # Put the model in eval mode while scoring, then switch back to training mode.
    model.train(False)
    results, eval_loss = evaluate(model, eval_loader, qid2type)
    # NOTE(review): eval_loss is not used anywhere in this excerpt.
    results["epoch"] = epoch + 1
    results["step"] = total_step
    results["train_loss"] = total_loss
    results["train_score"] = train_score
    model.train(True)
    # Scores from evaluate(); they are multiplied by 100 when logged below.
    eval_score = results["score"]
    bound = results["upper_bound"]
    yn = results['score_yesno']
    other = results['score_other']
    num = results['score_number']
# Log the epoch index and the elapsed time since `t`.
# NOTE(review): this logs the raw `epoch`, while results["epoch"] above stores
# epoch + 1; confirm which numbering is intended.
logger.write('epoch %d, time: %.2f' % (epoch, time.time() - t))
# Log the training loss and training score for this epoch.
logger.write('\ttrain_loss: %.2f, score: %.2f' % (total_loss, train_score))
writer.add_scalars(args.result_path + 'Train_val_loss', {args.result_path + 'train_loss': total_loss.data.item()},epoch)
if run_eval:
    # Overall eval score, with the upper bound in parentheses, both scaled by 100.
    logger.write('\teval score: %.2f (%.2f)' % (100 * eval_score, 100 * bound))
    # Per-question-type scores (yes/no, other, number), scaled by 100.
    logger.write('\tyn score: %.2f other score: %.2f num score: %.2f' % (100 * yn, 100 * other, 100 * num))
    # Save a checkpoint whenever the overall eval score beats the best so far.
    if eval_score > best_eval_score:
        model_path = os.path.join(output, 'model.pth')
        torch.save(model.state_dict(), model_path)
        best_eval_score = eval_score