Pytorch训练技巧
1.自定义损失函数
pytorch在torch.nn模块里提供了很多损失函数,如MSELoss,L1Loss等,但同时也可以自定义损失函数:
两种定义损失函数方法:
-
以函数方式
直接定义一个函数即可
def my_loss(output, target): loss = torch.mean((output - target)**2) return loss
-
以类方式
损失函数继承自_Loss类,_WeightedLoss类,这两个都继承了nn.Module,故自定义类损失应该继承nn.Module。列如
class DiceLoss(nn.Module): def __init__(self,weight=None,size_average=True): super(DiceLoss,self).__init__() def forward(self,inputs,targets,smooth=1): inputs = F.sigmoid(inputs) inputs = inputs.view(-1) targets = targets.view(-1) intersection = (inputs * targets).sum() dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth) return 1 - dice # 使用方法 criterion = DiceLoss() loss = criterion(input,targets)
2.动态调整学习率
-
可使用官方api
# 选择一种优化器 optimizer = torch.optim.Adam(...) # 选择上面提到的一种或多种动态调整学习率的方法 scheduler1 = torch.optim.lr_scheduler.... scheduler2 = torch.optim.lr_scheduler.... ... schedulern = torch.optim.lr_scheduler.... # 进行训练 for epoch in range(100): train(...) validate(...) optimizer.step() # 需要在优化器参数更新之后再动态调整学习率 # scheduler的优化是在每一轮后面进行的 scheduler1.step() ... schedulern.step()
-
也可以自定义scheduler
自定义函数
adjust_learning_rate
来改变param_group
中lr
的值。def adjust_learning_rate(optimizer, epoch): lr = args.lr * (0.1 ** (epoch // 30)) for param_group in optimizer.param_groups: param_group['lr'] = lr
def adjust_learning_rate(optimizer,...): ... optimizer = torch.optim.SGD(model.parameters(),lr = args.lr,momentum = 0.9) for epoch in range(10): train(...) validate(...) adjust_learning_rate(optimizer,epoch)
3.模型微调
通过timm.create_model()
的方法来进行模型的创建,我们可以通过传入参数pretrained=True
,来使用预训练模型。
import timm
import torch
model = timm.create_model('resnet34',pretrained=True)
x = torch.randn(1,3,224,224)
output = model(x)
output.shape
4.半精度训练
将默认的单精度浮点数torch.float32改为torch.float16,可以节约运行内存