2021SC@SDUSC
我们继续分析train.py,上次提到了train()函数,是train.py中的一个核心函数,它主要用于对数据集进行训练和测试。
首先看看它的四个参数以及调用时传的实际参数:
def train(m,o,ds,args):
train(m,o,ds,args)
m是一个model类的模型,model类在之前已详细分析过。
o是main函数里,通过对m的参数、权重和偏置,和学习率以及0.9的冲量调用了torch.optim.SGD函数的结果。
ds是数据集dataset。
args是pargs类的参数。
loss = 0
ex = 0
trainorder = [('1',ds.t1_iter),('2',ds.t2_iter),('3',ds.t3_iter)]
loss是训练集损失数,ex是训练集增加数。trainorder是一个列表,里面存放了三个元素。这三个元素的意思在dataset类中,分别对data进行了迭代批处理。
self.t1_iter = data.Iterator(t1d,args.t1size,device=args.device,sort_key=lambda x:len(x.out),repeat=False,train=True)
self.t2_iter = data.Iterator(t2d,args.t2size,device