【深入浅出PyTorch】6-进阶训练技巧

6.1 自定义损失函数

6.1.1 以函数方式定义

# 以函数形式定义损失函数
def loss_func(output, target):
  # 均方误差
  loss = torch.mean((output-target)**2)
  return loss

6.1.2 以类的方式定义

虽然以函数定义的方式很简单,但是以类方式定义更加常用,在以类方式定义损失函数时,我们如果看每一个损失函数的继承关系我们就可以发现Loss函数部分继承自_loss, 部分继承自_WeightedLoss, 而_WeightedLoss继承自_loss_loss继承自 nn.Module。我们可以将其当作神经网络的一层来对待,同样地,我们的损失函数类就需要继承自nn.Module类,在下面的例子中我们以DiceLoss为例向大家讲述。

Dice Loss是一种在分割领域常见的损失函数,定义如下:

D S C = 2 ∣ X ∩ Y ∣ ∣ X ∣ + ∣ Y ∣ DSC = \frac{2|X∩Y|}{|X|+|Y|} DSC=X+Y2∣XY

class DiceLoss(nn.Module):
  def __init(self, weight=None, size_average=True):
    super(DiceLoss, self).__init__()

  def forward(self, inputs, targets, smooth=1):    
        inputs = F.sigmoid(inputs)       
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        intersection = (inputs * targets).sum()                   
        dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)  
        return 1 - dice

在使用这个巡视函数的时候只需要如下即可:

criterion = DiceLoss()
loss = criterion(input, targets)

在自定义损失函数时,涉及到数学运算时,我们最好全程使用PyTorch提供的张量计算接口,这样就不需要我们实现自动求导功能并且我们可以直接调用cuda,使用numpy或者scipy的数学运算时,操作会有些麻烦,大家可以自己下去进行探索。关于PyTorch使用Class定义损失函数的原因,可以参考PyTorch的讨论区

其实在pytorch中很多损失函数都是定义好了的


class _Loss(Module):
 
@weak_module
class L1Loss(_Loss):
 
@weak_module
class NLLLoss(_WeightedLoss):
 
@weak_module
class NLLLoss2d(NLLLoss):
 
@weak_module
class PoissonNLLLoss(_Loss):
 
@weak_module
class KLDivLoss(_Loss):
 
@weak_module
class MSELoss(_Loss):
 
@weak_module
class BCELoss(_WeightedLoss):
 
@weak_module
class BCEWithLogitsLoss(_Loss):
 
@weak_module
class HingeEmbeddingLoss(_Loss):
 
@weak_module
class MultiLabelMarginLoss(_Loss):
 
@weak_module
class SmoothL1Loss(_Loss):
# ————————————————
# 版权声明:本文为CSDN博主「LoveMIss-Y」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
# 原文链接:https://blog.csdn.net/qq_27825451/article/details/95165265

所有的类都继承自_Loss类,而——Loss类又继承自Module

6.1.3 继承nn.autograd.Function来实现

Function类和Module类最大的区别是Function类多了一个backword()方法

# 定义一个继承了Function类的子类,实现y=f(x)的正向运算以及反向求导
class sqrt_and_inverse(torch.autograd.Function):
    '''
    本例子所采用的数学公式是:
    z=sqrt(x)+1/x+2*power(y,2)
    z是关于x,y的一个二元函数它的导数是
    z'(x)=1/(2*sqrt(x))-1/power(x,2)
    z'(y)=4*y
    forward和backward可以定义成静态方法,向定义中那样,也可以定义成实例方法
    '''
    # 前向运算
    def forward(self, input_x, input_y):
        '''
        self.save_for_backward(input_x,input_y) ,这个函数是定义在Function的父类_ContextMethodMixin中
             它是将函数的输入参数保存起来以便后面在求导时候再使用,起前向反向传播中协调作用
        '''
        self.save_for_backward(input_x, input_y)
        # 对输入和参数进行的操作,其实就是前向运算的函数表达式]
        output = torch.sqrt(input_x) + torch.reciprocal(input_x) + 2 * torch.pow(input_y, 2)
        return output

    def backward(self, grad_output):
        # 计算梯度是链式法则,输入的参数grad_output为反向传播上一级计算得到的梯度值
        input_x, input_y = self.saved_tensors  # 获取前面保存的参数,也可以使用self.saved_variables
        # 求函数forward(input)关于 parameters 的导数,其实就是反向运算的导数表达式
        # 这里上一级梯度值grad_output乘以当前级的梯度
        grad_x = grad_output * (torch.reciprocal(2 * torch.sqrt(input_x)) - torch.reciprocal(torch.pow(input_x, 2)))
        grad_y = grad_output * (4 * input_y)
        return grad_x, grad_y  # 需要注意的是,反向传播得到的结果需要与输入的参数相匹配
    
# ————————————————
# 版权声明:本文为CSDN博主「豆豆小朋友小笔记」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
# 原文链接:https://blog.csdn.net/qq_40728805/article/details/103906140

6.1.4 案例

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
# 第一步:自定义损失函数
 
# 继承nn.Mdule
class My_loss(nn.Module):
    def __init__(self):
        super().__init__()
        
    def forward(self, x, y):
        return torch.mean(torch.pow((x - y), 2))
# 第二步:准备数据集,模拟一个线性拟合过程
x_train = np.array([[3.3], [4.4], [5.5], [6.71], [6.93], [4.168], 
                    [9.779], [6.182], [7.59], [2.167], [7.042], 
                    [10.791], [5.313], [7.997], [3.1]], dtype=np.float32)
 
y_train = np.array([[1.7], [2.76], [2.09], [3.19], [1.694], [1.573], 
                    [3.366], [2.596], [2.53], [1.221], [2.827], 
                    [3.465], [1.65], [2.904], [1.3]], dtype=np.float32)
 
# 将numpy数据转化为torch的张量
inputs = torch.from_numpy(x_train)
targets = torch.from_numpy(y_train)
input_size = 1
output_size = 1
num_epochs = 60
learning_rate = 0.001
 
# 第三步: 构建模型,构建一个一层的网络模型
model = nn.Linear(input_size, output_size)
 
# 与模型相关的配置、损失函数、优化方式
# 使用自定义函数,等价于criterion = nn.MSELoss()
criterion = My_loss()
 
# 定义迭代优化算法, 使用的是随机梯度下降算法
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)  
loss_history = []
# 第四步:训练模型,迭代训练
for epoch in range(num_epochs):
    #  前向传播计算网络结构的输出结果
    outputs = model(inputs)
 
    # 计算损失函数
    loss = criterion(outputs, targets)
    
    # 反向传播更新参数,三步策略,归零梯度——>反向传播——>更新参数
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
 
    # 打印训练信息和保存loss
    loss_history.append(loss.item()) 
    if (epoch+1) % 5 == 0:
        print ('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))

Epoch [5/60], Loss: 2.8296
Epoch [10/60], Loss: 1.3841
Epoch [15/60], Loss: 0.7981
Epoch [20/60], Loss: 0.5603
Epoch [25/60], Loss: 0.4637
Epoch [30/60], Loss: 0.4242
Epoch [35/60], Loss: 0.4078
Epoch [40/60], Loss: 0.4008
Epoch [45/60], Loss: 0.3977
Epoch [50/60], Loss: 0.3960
Epoch [55/60], Loss: 0.3950
Epoch [60/60], Loss: 0.3943
# 第五步:结果展示。画出原y与x的曲线与网络结构拟合后的曲线
predicted = model(torch.from_numpy(x_train)).detach().numpy() #模型输出结果
 
plt.plot(x_train, y_train, 'ro', label='Original data')       #原始数据
plt.plot(x_train, predicted, label='Fitted line')             #拟合之后的直线
plt.legend()
plt.show()
 
# 画loss在迭代过程中的变化情况
plt.plot(loss_history, label='loss for every epoch')
plt.legend()
plt.show()

output_18_0

output_18_1

6.2 动态调整学习率

6.2.1 官方的scheduler

PyTorch在torch.optim.lr_scheduler中为我们封装好了一些动态调整学习率的方法:

PyTorch的官方代码:

# 选择一种优化器
optimizer = torch.optim.Adam(...) 
# 选择上面提到的一种或多种动态调整学习率的方法
scheduler1 = torch.optim.lr_scheduler.... 
scheduler2 = torch.optim.lr_scheduler....
...
schedulern = torch.optim.lr_scheduler....
# 进行训练
for epoch in range(100):
    train(...)
    validate(...)
    optimizer.step()
    # 需要在优化器参数更新之后再动态调整学习率
	scheduler1.step() 
	...
    schedulern.step()

6.2.2 自定义scheduler

自定义函数adjust_learning_rate来改变param_group中lr的值

# 自定义学习率
# 学习率每39轮下降为原来的1/10
def adjust_learning_rate(optimizer, epoch):
    lr = args.lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
# 下面是训练脚本
optimizer = torch.optim.SGD(model.parameters(),lr = args.lr,momentum = 0.9)
for epoch in range(10):
    train(...)
    val(...)
    adjust_learning_rate(optimizer,epoch)

6.3 模型微调

6.3.1 模型微调流程

  1. 在源数据集(如ImageNet数据集)上预训练一个神经网络模型,即源模型

  2. 创建一个新的神经网络模型,即目标模型。它复制了源模型上除了输出层外的所有模型设计及其参数。我们假设这些模型参数包含了源数据集上学习到的知识,且这些知识同样适用于目标数据集。我们还假设源模型的输出层跟源数据集的标签紧密相关,因此在目标模型中不予采用。

  3. 为目标模型添加一个输出⼤小为⽬标数据集类别个数的输出层,并随机初始化该层的模型参数。

  4. 在目标数据集上训练目标模型。我们将从头训练输出层,而其余层的参数都是基于源模型的参数微调得到的。

6.3.1 模型微调-torchvision

# 通过True或者False来决定是否使用预训练好的权重,
# 在默认状态下pretrained = False,意味着我们不使用预训练得到的权重,
# 当pretrained = True,意味着我们将使用在一些数据集上预训练得到的权重。

import torchvision.models as models
resnet18 = models.resnet18(pretrained=True)
alexnet = models.alexnet(pretrained=True)
squeezenet = models.squeezenet1_0(pretrained=True)
vgg16 = models.vgg16(pretrained=True)
densenet = models.densenet161(pretrained=True)
inception = models.inception_v3(pretrained=True)
googlenet = models.googlenet(pretrained=True)
shufflenet = models.shufflenet_v2_x1_0(pretrained=True)
mobilenet_v2 = models.mobilenet_v2(pretrained=True)
mobilenet_v3_large = models.mobilenet_v3_large(pretrained=True)
mobilenet_v3_small = models.mobilenet_v3_small(pretrained=True)
resnext50_32x4d = models.resnext50_32x4d(pretrained=True)
wide_resnet50_2 = models.wide_resnet50_2(pretrained=True)
mnasnet = models.mnasnet1_0(pretrained=True)
/usr/local/lib/python3.7/dist-packages/torchvision/models/_utils.py:209: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and will be removed in 0.15, please use 'weights' instead.
  f"The parameter '{pretrained_param}' is deprecated since 0.13 and will be removed in 0.15, "
/usr/local/lib/python3.7/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and will be removed in 0.15. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth



  0%|          | 0.00/44.7M [00:00<?, ?B/s]


/usr/local/lib/python3.7/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and will be removed in 0.15. The current behavior is equivalent to passing `weights=AlexNet_Weights.IMAGENET1K_V1`. You can also use `weights=AlexNet_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
Downloading: "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth" to /root/.cache/torch/hub/checkpoints/alexnet-owt-7be5be79.pth



  0%|          | 0.00/233M [00:00<?, ?B/s]


/usr/local/lib/python3.7/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and will be removed in 0.15. The current behavior is equivalent to passing `weights=SqueezeNet1_0_Weights.IMAGENET1K_V1`. You can also use `weights=SqueezeNet1_0_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
Downloading: "https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth" to /root/.cache/torch/hub/checkpoints/squeezenet1_0-b66bff10.pth



  0%|          | 0.00/4.78M [00:00<?, ?B/s]


/usr/local/lib/python3.7/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and will be removed in 0.15. The current behavior is equivalent to passing `weights=VGG16_Weights.IMAGENET1K_V1`. You can also use `weights=VGG16_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth



  0%|          | 0.00/528M [00:00<?, ?B/s]


/usr/local/lib/python3.7/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and will be removed in 0.15. The current behavior is equivalent to passing `weights=DenseNet161_Weights.IMAGENET1K_V1`. You can also use `weights=DenseNet161_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
Downloading: "https://download.pytorch.org/models/densenet161-8d451a50.pth" to /root/.cache/torch/hub/checkpoints/densenet161-8d451a50.pth



  0%|          | 0.00/110M [00:00<?, ?B/s]


/usr/local/lib/python3.7/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and will be removed in 0.15. The current behavior is equivalent to passing `weights=Inception_V3_Weights.IMAGENET1K_V1`. You can also use `weights=Inception_V3_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
Downloading: "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth" to /root/.cache/torch/hub/checkpoints/inception_v3_google-0cc3c7bd.pth



  0%|          | 0.00/104M [00:00<?, ?B/s]


/usr/local/lib/python3.7/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and will be removed in 0.15. The current behavior is equivalent to passing `weights=GoogLeNet_Weights.IMAGENET1K_V1`. You can also use `weights=GoogLeNet_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
Downloading: "https://download.pytorch.org/models/googlenet-1378be20.pth" to /root/.cache/torch/hub/checkpoints/googlenet-1378be20.pth



  0%|          | 0.00/49.7M [00:00<?, ?B/s]


/usr/local/lib/python3.7/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and will be removed in 0.15. The current behavior is equivalent to passing `weights=ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1`. You can also use `weights=ShuffleNet_V2_X1_0_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
Downloading: "https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth" to /root/.cache/torch/hub/checkpoints/shufflenetv2_x1-5666bf0f80.pth



  0%|          | 0.00/8.79M [00:00<?, ?B/s]


/usr/local/lib/python3.7/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and will be removed in 0.15. The current behavior is equivalent to passing `weights=MobileNet_V2_Weights.IMAGENET1K_V1`. You can also use `weights=MobileNet_V2_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
Downloading: "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth" to /root/.cache/torch/hub/checkpoints/mobilenet_v2-b0353104.pth



  0%|          | 0.00/13.6M [00:00<?, ?B/s]


/usr/local/lib/python3.7/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and will be removed in 0.15. The current behavior is equivalent to passing `weights=MobileNet_V3_Large_Weights.IMAGENET1K_V1`. You can also use `weights=MobileNet_V3_Large_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
Downloading: "https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth" to /root/.cache/torch/hub/checkpoints/mobilenet_v3_large-8738ca79.pth



  0%|          | 0.00/21.1M [00:00<?, ?B/s]


/usr/local/lib/python3.7/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and will be removed in 0.15. The current behavior is equivalent to passing `weights=MobileNet_V3_Small_Weights.IMAGENET1K_V1`. You can also use `weights=MobileNet_V3_Small_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
Downloading: "https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth" to /root/.cache/torch/hub/checkpoints/mobilenet_v3_small-047dcff4.pth



  0%|          | 0.00/9.83M [00:00<?, ?B/s]


/usr/local/lib/python3.7/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and will be removed in 0.15. The current behavior is equivalent to passing `weights=ResNeXt50_32X4D_Weights.IMAGENET1K_V1`. You can also use `weights=ResNeXt50_32X4D_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
Downloading: "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth" to /root/.cache/torch/hub/checkpoints/resnext50_32x4d-7cdf4587.pth



  0%|          | 0.00/95.8M [00:00<?, ?B/s]


/usr/local/lib/python3.7/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and will be removed in 0.15. The current behavior is equivalent to passing `weights=Wide_ResNet50_2_Weights.IMAGENET1K_V1`. You can also use `weights=Wide_ResNet50_2_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
Downloading: "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth" to /root/.cache/torch/hub/checkpoints/wide_resnet50_2-95faca4d.pth



  0%|          | 0.00/132M [00:00<?, ?B/s]


/usr/local/lib/python3.7/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and will be removed in 0.15. The current behavior is equivalent to passing `weights=MNASNet1_0_Weights.IMAGENET1K_V1`. You can also use `weights=MNASNet1_0_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
Downloading: "https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth" to /root/.cache/torch/hub/checkpoints/mnasnet1.0_top1_73.512-f206786ef8.pth



  0%|          | 0.00/16.9M [00:00<?, ?B/s]

6.3.2 模型微调-timm

6.4 半精度训练

通过抠浮点数存储位数来减少计算开销

# 设置方法
from torch.cuda.amp import autocast

# 模型设置
@autocast()   
def forward(self, x):
    ...
    return x

# 训练过程
 for x in train_loader:
	x = x.cuda()
  # 后面的都是autocast
	with autocast():
        output = model(x)
        ...

6.5 数据增强

6.5.1 imgaug简介与安装

imgaug是一个数据增强库

  1. Github地址:imgaug
  2. Readthedocs:imgaug
  3. 官方提供notebook例程:notebook

安装方法:

conda config --add channels conda-forge
conda install imgaug

6.5.2 imgaug使用

6.5.2.1 单图处理

import imageio
import imgaug as ia
%matplotlib inline

# 图片的读取
img = imageio.imread("/content/Lenna.png")

# 使用Image进行读取
# img = Image.open("./Lenna.jpg")
# image = np.array(img)
# ia.imshow(image)

# 可视化图片
ia.imshow(img)

output_39_0

from imgaug import augmenters as iaa

# 设置随机数种子
ia.seed(4)

# 实例化方法
rotate = iaa.Affine(rotate=(-4,45))
img_aug = rotate(image=img)
ia.imshow(img_aug)


output_40_0

6.5.3 PyTorch与imgaug

6.6 使用argparse调参

总的来说,我们可以将argparse的使用归纳为以下三个步骤。

  • 创建ArgumentParser()对象

  • 调用add_argument()方法添加参数

  • 使用parse_args()解析参数 在接下来的内容中,我们将以实际操作来学习argparse的使用方法。

# demo.py
import argparse

# 创建ArgumentParser()对象
parser = argparse.ArgumentParser()

# 添加参数
parser.add_argument('-o', '--output', action='store_true', 
    help="shows output")
# action = `store_true` 会将output参数记录为True
# type 规定了参数的格式
# default 规定了默认值
parser.add_argument('--lr', type=float, default=3e-5, help='select the learning rate, default=1e-3') 

parser.add_argument('--batch_size', type=int, required=True, help='input batch size')  
# 使用parse_args()解析函数
args = parser.parse_args()

if args.output:
    print("This is some output")
    print(f"learning rate:{args.lr} ")

我们在命令行使用python demo.py --lr 3e-4 --batch_size 32,就可以看到以下的输出

This is some output
learning rate: 3e-4

PyTorch模型定义与进阶训练技巧

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值