深度学习过拟合解决方案

本文转自:https://blog.csdn.net/zhang2010hao/article/details/89339327

1.29.深度学习过拟合解决方案

1.29.1.解决方案

对于深度学习网络的过拟合,一般的解决方案有:
Early stop
在模型训练过程中,提前终止。这里可以根据具体指标设置early stop的条件,比如可以是loss的大小,或者acc/f1等值的epoch之间的大小对比。

More data
更多的数据集。增加样本也是一种解决方案,根据不同场景和数据不同的数据增强方法。

正则化
常用的有L1,L2正则化

Droup Out
以一定的概率使某些神经元停止工作

BatchNorm
对神经元作归一化

1.29.2.实现

这里主要讲述一下在pytorch中的过拟合解决方案,early stop和more data都是对于特定的任务去进行的,不同的任务有不同的解决方案,这里不做进一步说明。在pytorch框架下后面几种解决方案是有统一的结构或者解决办法的,这里一一道来。

正则化
torch.optim集成了很多优化器,如SGD,Adadelta,Adam,Adagrad,RMSprop等,这些优化器中有一个参数weight_decay,用于指定权值衰减率,相当于L2正则化中的λ参数,注意torch.optim集成的优化器只有L2正则化方法,api中参数weight_decay 的解析是:weight_decay (float, optional): weight decay (L2 penalty) (default: 0),这里可以看出其weight_decay就是正则化项的作用。可以如下设置L2正则化:

optimizer = optim.Adam(model.parameters(),lr=0.001,weight_decay=0.01)

但是这种方法存在几个问题:
(1)一般正则化,只是对模型的权重W参数进行惩罚,而偏置参数b是不进行惩罚的,而torch.optim的优化器weight_decay参数指定的权值衰减是对网络中的所有参数,包括权值w和偏置b同时进行惩罚。很多时候如果对b 进行L2正则化将会导致严重的欠拟合,因此这个时候一般只需要对权值w进行正则即可。(PS:这个我真不确定,源码解析是 weight decay (L2 penalty) ,但有些网友说这种方法会对参数偏置b也进行惩罚,可解惑的网友给个明确的答复)
(2)缺点:torch.optim的优化器只能实现L2正则化,不能实现L1正则化。
(3)根据正则化的公式,加入正则化后,loss会变原来大,比如weight_decay=1的loss为10,那么weight_decay=100时,loss输出应该也提高100倍左右。而采用torch.optim的优化器的方法,如果你依然采用loss_fun= nn.CrossEntropyLoss()进行计算loss,你会发现,不管你怎么改变weight_decay的大小,loss会跟之前没有加正则化的大小差不多。这是因为你的loss_fun损失函数没有把权重W的损失加上。

1.29.3.Drop out实现

pytorch中有两种方式可以实现dropout
1)使用nn.Dropout类,先初始化掉该类,然后可以在后面直接调用

# -*- coding: UTF-8 -*-

import torch.nn as nn

class Exmp(nn.Module):
    def __init__(drop_rate):
        self.dropout = nn.Dropout(drop_rate)
        ...

    def forward():
        ...
        output = self.dropout(input)
        ...

2)使用torch.nn.functional.dropout函数实现dropout

# -*- coding: UTF-8 -*-

import torch.nn as nn

class Exmp(nn.Module):
    def __init__(drop_rate):
        self.dropout = nn.Dropout(drop_rate)
        ...

    def forward():
        ...
        output = self.dropout(input)
        ...

上面只有一种示例,在实际使用中第二种更加灵活,可以在不同的层之间使用不同的drop_rate, 第一种的好处是可以一次初始化后面每次dropout保持一致。

1.29.4.BatchNorm

批标准化通俗来说就是对每一层神经网络进行标准化 (normalize) 处理,具体的原理我再次不做赘述,网上资料很多。

pytorch中BatchNorm有BatchNorm1d、BatchNorm2d、BatchNorm3d三种,根据具体数据选择不同的BatchNorm,BatchNorm层的使用与普通的层使用方法类似。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值