6.1 Custom Loss Functions
6.1.1 Defining a Loss as a Function
import torch

# define a loss function as a plain Python function
def loss_func(output, target):
    # mean squared error
    loss = torch.mean((output - target)**2)
    return loss
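Because this is an ordinary Python function operating on tensors, it can be called exactly like a built-in criterion. A minimal usage sketch with random stand-in tensors:

output = torch.randn(4, 1, requires_grad=True)  # stand-in for a model output
target = torch.randn(4, 1)                      # stand-in for the ground truth
loss = loss_func(output, target)
loss.backward()                                 # autograd differentiates the tensor ops
print(loss.item())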
6.1.2 Defining a Loss as a Class
Although defining a loss as a function is simple, defining it as a class is more common in practice. If we look at the inheritance hierarchy of the built-in loss functions, we find that some of them inherit from _Loss and some from _WeightedLoss, that _WeightedLoss inherits from _Loss, and that _Loss itself inherits from nn.Module. A loss can therefore be treated as just another layer of the network, so a custom loss class likewise needs to inherit from nn.Module. The example below uses DiceLoss to illustrate this.
Dice Loss is a loss function commonly used in image segmentation, defined as:
$DSC = \frac{2|X \cap Y|}{|X| + |Y|}$
import torch
import torch.nn as nn

class DiceLoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(DiceLoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        inputs = torch.sigmoid(inputs)
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        intersection = (inputs * targets).sum()
        dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)
        return 1 - dice
Using this loss function is then as simple as:
criterion = DiceLoss()
loss = criterion(input, targets)
When a custom loss involves mathematical operations, it is best to use PyTorch's tensor operations throughout: autograd then handles differentiation for us automatically and the computation can run directly on CUDA. Doing the math with numpy or scipy instead is more cumbersome (gradients have to be provided by hand and data moved between devices); this is left for the reader to explore. For the reasoning behind defining loss functions as classes in PyTorch, see the PyTorch discussion forums.
In fact, many loss functions are already defined inside PyTorch itself:
class _Loss(Module):
@weak_module
class L1Loss(_Loss):
@weak_module
class NLLLoss(_WeightedLoss):
@weak_module
class NLLLoss2d(NLLLoss):
@weak_module
class PoissonNLLLoss(_Loss):
@weak_module
class KLDivLoss(_Loss):
@weak_module
class MSELoss(_Loss):
@weak_module
class BCELoss(_WeightedLoss):
@weak_module
class BCEWithLogitsLoss(_Loss):
@weak_module
class HingeEmbeddingLoss(_Loss):
@weak_module
class MultiLabelMarginLoss(_Loss):
@weak_module
class SmoothL1Loss(_Loss):
# ————————————————
# Copyright notice: this listing comes from an original article by the CSDN blogger "LoveMIss-Y", licensed under CC 4.0 BY-SA; reproduction requires the original link and this notice.
# Original link: https://blog.csdn.net/qq_27825451/article/details/95165265
All of these classes inherit (directly or via _WeightedLoss) from the _Loss class, and _Loss in turn inherits from Module.
6.1.3 Defining a Loss by Subclassing torch.autograd.Function
The biggest difference between the Function class and the Module class is that a Function also defines a backward() method.
# Define a subclass of torch.autograd.Function that implements the forward
# computation z = f(x, y) together with a hand-written backward pass
class sqrt_and_inverse(torch.autograd.Function):
    '''
    The function implemented here is
        z = sqrt(x) + 1/x + 2*y^2
    and its partial derivatives are
        dz/dx = 1/(2*sqrt(x)) - 1/x^2
        dz/dy = 4*y
    forward and backward are written as static methods, as current PyTorch
    versions expect, and receive a context object ctx as their first argument.
    '''
    # forward computation
    @staticmethod
    def forward(ctx, input_x, input_y):
        # ctx.save_for_backward (provided by the Function context) stores the
        # input tensors so they can be retrieved again when computing gradients
        ctx.save_for_backward(input_x, input_y)
        # the forward expression itself
        output = torch.sqrt(input_x) + torch.reciprocal(input_x) + 2 * torch.pow(input_y, 2)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # chain rule: grad_output is the gradient flowing in from the
        # operations that consumed this Function's output
        input_x, input_y = ctx.saved_tensors  # tensors saved in forward
        # derivative of the output w.r.t. each input, multiplied by grad_output
        grad_x = grad_output * (torch.reciprocal(2 * torch.sqrt(input_x)) - torch.reciprocal(torch.pow(input_x, 2)))
        grad_y = grad_output * (4 * input_y)
        # backward must return one gradient per input of forward, in the same order
        return grad_x, grad_y
# ————————————————
# Copyright notice: this listing comes from an original article by the CSDN blogger "豆豆小朋友小笔记", licensed under CC 4.0 BY-SA; reproduction requires the original link and this notice.
# Original link: https://blog.csdn.net/qq_40728805/article/details/103906140
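A minimal sketch of using this Function (the input values below are hypothetical): custom Functions are invoked through their apply method, after which autograd calls the backward we defined.

x = torch.tensor([4.0], requires_grad=True)
y = torch.tensor([3.0], requires_grad=True)
z = sqrt_and_inverse.apply(x, y)  # call the custom Function via .apply
z.backward()
print(x.grad)  # 1/(2*sqrt(4)) - 1/4^2 = 0.25 - 0.0625 = 0.1875
print(y.grad)  # 4*3 = 12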
6.1.4 A Complete Example
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

# Step 1: define the custom loss function
# by inheriting from nn.Module
class My_loss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        return torch.mean(torch.pow((x - y), 2))

# Step 2: prepare the dataset; here we simulate a linear-fitting problem
x_train = np.array([[3.3], [4.4], [5.5], [6.71], [6.93], [4.168],
                    [9.779], [6.182], [7.59], [2.167], [7.042],
                    [10.791], [5.313], [7.997], [3.1]], dtype=np.float32)

y_train = np.array([[1.7], [2.76], [2.09], [3.19], [1.694], [1.573],
                    [3.366], [2.596], [2.53], [1.221], [2.827],
                    [3.465], [1.65], [2.904], [1.3]], dtype=np.float32)

# convert the numpy arrays to torch tensors
inputs = torch.from_numpy(x_train)
targets = torch.from_numpy(y_train)

input_size = 1
output_size = 1
num_epochs = 60
learning_rate = 0.001

# Step 3: build the model, a single linear layer
model = nn.Linear(input_size, output_size)

# training configuration: loss function and optimizer
# use the custom loss, which is equivalent to criterion = nn.MSELoss()
criterion = My_loss()

# stochastic gradient descent as the optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

loss_history = []

# Step 4: train the model iteratively
for epoch in range(num_epochs):
    # forward pass: compute the model output
    outputs = model(inputs)
    # compute the loss
    loss = criterion(outputs, targets)
    # backward pass and parameter update: zero gradients -> backward -> step
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # log the loss and print progress
    loss_history.append(loss.item())
    if (epoch+1) % 5 == 0:
        print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))
Epoch [5/60], Loss: 2.8296
Epoch [10/60], Loss: 1.3841
Epoch [15/60], Loss: 0.7981
Epoch [20/60], Loss: 0.5603
Epoch [25/60], Loss: 0.4637
Epoch [30/60], Loss: 0.4242
Epoch [35/60], Loss: 0.4078
Epoch [40/60], Loss: 0.4008
Epoch [45/60], Loss: 0.3977
Epoch [50/60], Loss: 0.3960
Epoch [55/60], Loss: 0.3950
Epoch [60/60], Loss: 0.3943
# Step 5: visualize the results; plot the original data and the fitted line
predicted = model(torch.from_numpy(x_train)).detach().numpy()  # model predictions
plt.plot(x_train, y_train, 'ro', label='Original data')  # original data
plt.plot(x_train, predicted, label='Fitted line')         # fitted line
plt.legend()
plt.show()

# plot how the loss changed during training
plt.plot(loss_history, label='loss for every epoch')
plt.legend()
plt.show()
6.2 Adjusting the Learning Rate Dynamically
6.2.1 The Official Schedulers
PyTorch ships a number of ready-made learning-rate schedulers in torch.optim.lr_scheduler:
lr_scheduler.LambdaLR
lr_scheduler.MultiplicativeLR
lr_scheduler.StepLR
lr_scheduler.MultiStepLR
lr_scheduler.ExponentialLR
lr_scheduler.CosineAnnealingLR
lr_scheduler.ReduceLROnPlateau
lr_scheduler.CyclicLR
lr_scheduler.OneCycleLR
lr_scheduler.CosineAnnealingWarmRestarts
The usage template from the official PyTorch documentation:
# choose an optimizer
optimizer = torch.optim.Adam(...)
# choose one or more of the schedulers listed above
scheduler1 = torch.optim.lr_scheduler....
scheduler2 = torch.optim.lr_scheduler....
...
schedulern = torch.optim.lr_scheduler....
# training loop
for epoch in range(100):
    train(...)
    validate(...)
    optimizer.step()
    # the schedulers must be stepped after the optimizer has updated the parameters
    scheduler1.step()
    ...
    schedulern.step()
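As a concrete sketch, SGD can be paired with StepLR so that the learning rate is multiplied by 0.1 every 30 epochs. Here model, train and validate are placeholders for code defined elsewhere, and optimizer.step() is assumed to be called inside train:

optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

for epoch in range(100):
    train(...)                      # training step(s), defined elsewhere
    validate(...)                   # validation, defined elsewhere
    scheduler.step()                # update the learning rate once per epoch
    print(scheduler.get_last_lr())  # inspect the current learning rate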
6.2.2 Custom Schedulers
We can also write our own function adjust_learning_rate that directly changes the lr value in each param_group of the optimizer:
# custom learning-rate schedule:
# the learning rate is divided by 10 every 30 epochs
def adjust_learning_rate(optimizer, epoch):
    lr = args.lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

# training script
optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
for epoch in range(10):
    train(...)
    val(...)
    adjust_learning_rate(optimizer, epoch)
6.3 Model Fine-Tuning
6.3.1 The Fine-Tuning Workflow
- Pre-train a neural network model, the source model, on a source dataset (e.g. ImageNet).
- Create a new neural network model, the target model, which copies all of the source model's design and parameters except for its output layer. We assume these parameters contain knowledge learned on the source dataset that also applies to the target dataset, and that the source model's output layer is tightly coupled to the source labels and is therefore not reused.
- Add to the target model an output layer whose size equals the number of classes in the target dataset, and initialize its parameters randomly.
- Train the target model on the target dataset: the new output layer is trained from scratch, while the parameters of all other layers are fine-tuned from the source model's parameters.
6.3.2 Fine-Tuning with torchvision
# The pretrained flag decides whether to load pretrained weights:
# by default pretrained=False, meaning no pretrained weights are used;
# with pretrained=True the weights pretrained on ImageNet are downloaded and loaded.
import torchvision.models as models
resnet18 = models.resnet18(pretrained=True)
alexnet = models.alexnet(pretrained=True)
squeezenet = models.squeezenet1_0(pretrained=True)
vgg16 = models.vgg16(pretrained=True)
densenet = models.densenet161(pretrained=True)
inception = models.inception_v3(pretrained=True)
googlenet = models.googlenet(pretrained=True)
shufflenet = models.shufflenet_v2_x1_0(pretrained=True)
mobilenet_v2 = models.mobilenet_v2(pretrained=True)
mobilenet_v3_large = models.mobilenet_v3_large(pretrained=True)
mobilenet_v3_small = models.mobilenet_v3_small(pretrained=True)
resnext50_32x4d = models.resnext50_32x4d(pretrained=True)
wide_resnet50_2 = models.wide_resnet50_2(pretrained=True)
mnasnet = models.mnasnet1_0(pretrained=True)
/usr/local/lib/python3.7/dist-packages/torchvision/models/_utils.py:209: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and will be removed in 0.15, please use 'weights' instead.
/usr/local/lib/python3.7/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and will be removed in 0.15. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights.
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
(the same deprecation warning and a download progress bar are printed for each of the remaining models; every checkpoint is cached under /root/.cache/torch/hub/checkpoints/)
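When only the classification head is to be trained, a common pattern is to freeze the pretrained parameters by setting requires_grad=False and then replace the final fully connected layer. A sketch for resnet18 with a hypothetical 4-class target task:

import torch
import torch.nn as nn
import torchvision.models as models

model = models.resnet18(pretrained=True)

# freeze every pretrained parameter so no gradients are computed for them
for param in model.parameters():
    param.requires_grad = False

# replace the output layer with a new, randomly initialized one sized for the
# target dataset (4 classes is just an example); its parameters have
# requires_grad=True by default and are the only ones that will be updated
model.fc = nn.Linear(model.fc.in_features, 4)

# pass only the trainable parameters to the optimizer
optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)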
6.3.3 Fine-Tuning with timm
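timm is another popular source of pretrained backbones. A minimal sketch (assuming the timm package is installed): a pretrained model can be created with its classification head already resized for the target task, here with a hypothetical 10 classes.

import timm

# list a few of the architectures that ship with pretrained weights
print(timm.list_models(pretrained=True)[:5])

# create a pretrained backbone with a 10-class head;
# the new head is randomly initialized and ready for fine-tuning
model = timm.create_model('resnet18', pretrained=True, num_classes=10)
print(model.get_classifier())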
6.4 Half-Precision Training
Half precision reduces memory use and computation cost by storing floating-point numbers with fewer bits (torch.float16 instead of the default torch.float32).
# import the autocast decorator / context manager
from torch.cuda.amp import autocast

# in the model definition, decorate forward with @autocast()
@autocast()
def forward(self, x):
    ...
    return x

# in the training loop, run the forward pass inside a with autocast() block
for x in train_loader:
    x = x.cuda()
    # everything below runs under autocast
    with autocast():
        output = model(x)
        ...
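torch.cuda.amp also provides GradScaler, which is commonly combined with autocast to scale the loss and avoid fp16 underflow during the backward pass. A sketch of one training step, with model, criterion, optimizer and train_loader assumed to be defined as above and the loader assumed to yield input/label pairs:

from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

for x, y in train_loader:
    x, y = x.cuda(), y.cuda()
    optimizer.zero_grad()
    # forward pass and loss computation run in mixed precision
    with autocast():
        output = model(x)
        loss = criterion(output, y)
    # scale the loss before backward, then unscale and step the optimizer
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()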
6.5 Data Augmentation
6.5.1 Introducing and Installing imgaug
imgaug is a library for image data augmentation.
Installation:
conda config --add channels conda-forge
conda install imgaug
6.5.2 Using imgaug
6.5.2.1 Processing a Single Image
import imageio
import imgaug as ia
%matplotlib inline

# read the image
img = imageio.imread("/content/Lenna.png")

# alternatively, read it with PIL:
# from PIL import Image
# img = Image.open("./Lenna.jpg")
# image = np.array(img)
# ia.imshow(image)

# visualize the image
ia.imshow(img)

from imgaug import augmenters as iaa

# set the random seed
ia.seed(4)

# instantiate an augmenter and apply it
rotate = iaa.Affine(rotate=(-4, 45))
img_aug = rotate(image=img)
ia.imshow(img_aug)
6.5.3 Using imgaug with PyTorch
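The usual pattern is to run the imgaug pipeline on numpy arrays inside the Dataset and only then convert to tensors. A minimal sketch, where the image list and labels are hypothetical placeholders:

import torch
from torch.utils.data import Dataset
from imgaug import augmenters as iaa

class AugmentedDataset(Dataset):
    def __init__(self, images, labels):
        self.images = images   # list of HxWxC uint8 numpy arrays
        self.labels = labels
        # an example augmentation pipeline
        self.aug = iaa.Sequential([
            iaa.Fliplr(0.5),               # horizontal flip with probability 0.5
            iaa.Affine(rotate=(-20, 20)),  # random rotation
        ])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = self.aug(image=self.images[idx])  # imgaug returns a numpy array
        img = torch.from_numpy(img.copy()).permute(2, 0, 1).float() / 255.0
        return img, self.labels[idx]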
6.6 Using argparse for Hyperparameter Tuning
In short, using argparse comes down to the following three steps:
- Create an ArgumentParser() object
- Call add_argument() to add the arguments
- Call parse_args() to parse the arguments

The example below walks through these steps in practice.
# demo.py
import argparse

# create an ArgumentParser() object
parser = argparse.ArgumentParser()

# add arguments
parser.add_argument('-o', '--output', action='store_true',
                    help="shows output")
# action='store_true' records output as True when the flag is passed
# type specifies how the argument is parsed
# default specifies the default value
parser.add_argument('--lr', type=float, default=3e-5, help='select the learning rate, default=3e-5')
parser.add_argument('--batch_size', type=int, required=True, help='input batch size')

# parse the arguments with parse_args()
args = parser.parse_args()

if args.output:
    print("This is some output")
    print(f"learning rate: {args.lr}")
Running python demo.py -o --lr 3e-4 --batch_size 32 on the command line produces the following output:
This is some output
learning rate: 0.0003
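When experimenting in a notebook rather than from a shell, the same parser can be driven programmatically by passing a list of argument strings to parse_args; for example:

# simulate the command line: equivalent to `python demo.py -o --lr 3e-4 --batch_size 32`
args = parser.parse_args(['-o', '--lr', '3e-4', '--batch_size', '32'])
print(args.lr, args.batch_size, args.output)  # 0.0003 32 True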