1.首先谈的第一个重要的 zero_grad() 梯度清零。
由于每次梯度计算完成后,网络中的优化器梯度不会自动清零,所以需要手动输入函数进行优化器梯度清零。例子:
for i, data in enumerate(dataLoader): # 获取图片和标签 inputs, labels = data inputs, labels = Variable(inputs), Variable(labels) optimizer.zero_grad() #优化器梯度清零 outputs = net(inputs) loss = criterion(outputs, labels) #计算损失函数 loss.backward() 损失函数回传/参数更新 optimizer.step() 优化器更新/参数更新
1
2.谈一下state_dict() 获取模型参数
里面存放一个字典,存放模型的权重、bias偏置,以字典形式返回。例子:
#导入所需要的库函数 import torch import torch.nn as nn import torch.optim as optim import torchvision impo