2021SC@SDUSC
我们继续看main(args)的那个循环。
for e in range(starte,args.epochs):
print("epoch ",e,"lr",o.param_groups[0]['lr'])
train(m,o,ds,args)
vloss = evaluate(m,ds,args)
if args.lrwarm:
update_lr(o,args,e)
print("Saving model")
torch.save(m.state_dict(),args.save+"/"+str(e)+".vloss-"+str(vloss)[:8]+".lr-"+str(o.param_groups[0]['lr']))
if vloss > lastloss:
if args.lrdecay:
print("decay lr")
o.param_groups[0]['lr'] *= 0.5
lastloss = vloss
上次我们分析了train函数,接下来我们分析evaluate函数。这个函数主要对数据集进行评估操作。
def evaluate(m,ds,args):
print("Evaluating",end="\t")
m.eval()
loss = 0
ex = 0
m、ds、args参数和train的函数参数相同,此处不再重复。然后对m变字符串。