关键词短语生成的无监督方法13——Train.py

2021SC@SDUSC

一、源码解析

继上周对Train()的分析,我们接着来分析evaluate()函数。

def evaluate(iterator):
    #不启用Batch Normalization和Dropout
    model.eval()
    #定义损失值
    epoch_loss = 0
    #使不进行计算图构建
    with torch.no_grad():

代码分段解析:
此段代码首先设置不启用Batch Normalization和Dropout,声明初始损失值,在不进行计算图构建的前提下开始估值操作。
其中,with torch.no_grad()函数,指在使用pytorch时,并不是所有的操作都需要进行计算图的生成,即计算过程的构建,以便梯度反向传播等操作。对于tensor的计算操作,默认是要进行计算图的构建,在这种情况下,可以使用 with torch.no_grad():,强制之后的内容不进行计算图构建。
(1)使用with torch.no_grad():

with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))        
print(outputs)

运行结果:

Accuracy of the network on the 10000 test images: 55 %
tensor([[-2.9141, -3.8210,  2.1426,  3.0883,  2.6363,  2.6878,  2.8766,  0.3396,
         -4.7505, -3.8502],
        [-1.4012, -4.5747,  1.8557,  3.8178,  1.1430,  3.9522, -0.4563,  1.2740,
         -3.7763, -3.3633],
        [ 1.3090,  0.1812,  0.4852,  0.1315,  0.5297, -0.3215, -2.0045,  1.0426,
         -3.2699, -0.5084],
        [-0.5357, -1.9851, -0.2835, -0.3110,  2.6453,  0.7452, -1.4148,  5.6919,
         -6.3235, -1.6220]])

此时的outputs没有属性。
(2)不使用with torch.no_grad()

for data in testloader:
    images, labels = data
    outputs = net(images)
    _, predicted = torch.max(outputs.data, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum().item()
print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))
print(outputs)

运行结果:

Accuracy of the network on the 10000 test images: 55 %
tensor([[-2.9141, -3.8210,  2.1426,  3.0883,  2.6363,  2.6878,  2.8766,  0.3396,
         -4.7505, -3.8502],
        [-1.4012, -4.5747,  1.8557,  3.8178,  1.1430,  3.9522, -0.4563,  1.2740,
         -3.7763, -3.3633],
        [ 1.3090,  0.1812,  0.4852,  0.1315,  0.5297, -0.3215, -2.0045,  1.0426,
         -3.2699, -0.5084],
        [-0.5357, -1.9851, -0.2835, -0.3110,  2.6453,  0.7452, -1.4148,  5.6919,
         -6.3235, -1.6220]], grad_fn=<AddmmBackward>)

此时有grad_fn=< AddmmBackward >属性,表示,计算的结果在一计算图当中,可以进行梯度反传等操作。但是,两者计算的结果实际上是没有区别的。

        for i, (src, trg) in enumerate(iterator):
        	#存储数据集合silver label
            src = src.long().permute(1,0).to(device)
            trg = trg.long().permute(1,0).to(device)
        
            #前向传播求出预测值prediction和隐藏值hidden
            output = model.forward(src, trg) 
            output_dim = output.shape[-1]
            output = output[1:].view(-1, output_dim)
            trg = trg[1:].reshape(5*trg.shape[1])
 			
 			#求loss
            loss = criterion(output, trg)
            #用item取出唯一的元素,求loss的平均值
            epoch_loss += loss.item()
        
    return epoch_loss / len(iterator)

代码分段解析:
此段代码存储了数据集和silver label相应对,与Train()函数中类似,通过前向传播求预测值并求loss平均来定义学习率衰减。

for epoch in range(30):
    #调用train衰减
    train_loss = train(train_loader)
    #调用evaluate衰减
    valid_loss = evaluate(test_loader)

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        print('saved')
        #保存训练好的权重
        torch.save(model.state_dict(), 'train.pt')
    
    print(epoch,':')
    print(train_loss, valid_loss)
    print('****************************************\n')

代码分段解析:
在每个epoch训练前分别调用一次train()、evaluate(),实现学习率的衰减。然后使有效损失率达到最优并保存训练好的权重。至此,训练结束。其中,值得一提的是torch.save()函数。它保存一个序列化(serialized)的目标到磁盘。函数使用了Python的pickle程序用于序列化。模型(models),张量(tensors)和文件夹(dictionaries)都是可以用这个函数保存的目标类型。
torch.save(第一个参数,第二个参数):第一个参数是model则保存的是模型;当是model.state_dict()保存的是模型权重。

#保存整个模型:
torch.save(model,'save.pt')
#只保存训练好的权重:
torch.save(model.state_dict(), 'save.pt')

完整参数:

  1. obj:保存对象。
  2. f:要保存的文件名或一个保存文件名的字符串。
  3. pickle_modul:用于 pickling 元数据和对象的模块,可选。
  4. pickle_protocol:指定 pickle protocal 可以覆盖默认参数,可选。

二、总结

本周对Train.py剩余部分进行了分析,重点对torch.no_grad()、torch.save()函数展开了学习。至此,Seq2Seq模型已完成训练。

项目总体分析告一段落。本次课程开展,总共对Extract.py、Model.py、my_dataloader.py、Train.py以及Encoder-Decoder模型、Seq2Seq模型等关键源码等知识进行了学习与分析。理解了张量tensor、TFIDF、embedding similarity嵌入相似性、通过余弦向量计算相关性、LSTM、BiLSTM等等的相关概念,涉猎从机器学习、深度学习到NLP领域的相关知识,从起初接手项目时对它们的零认识零了解到现在熟悉每个概念及其运用,收获颇丰。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值