读书笔记-Dropout


注:本文代码来自李沐书籍,这里只做代码解析

1. Dropout相关概念

Dropout的出现是为了解决模型的过拟合线性,主要应用于全连接层的,在一定程度上起到正则化的作用,提高了模型的鲁棒性。
在这里插入图片描述
我们假设隐藏层中有 n 个 h,我们希望以 概率 p 来丢掉隐藏单元,并且是在训练过程中启用,每次训练时丢弃的 h 时不一样的。具体分布如下:
h ′ = { 0 , 概 率 为 p h 1 − p , 其 他 (1) h'=\left\{ \begin{aligned} 0, 概率为 p \\ \frac{h}{1-p},其他\\ \end{aligned} \right. \tag1 h=0,p1ph,(1)

  • 期望
    E ( h ′ ) = 0 ⋅ p + h 1 − p ⋅ ( 1 − p ) = h (2) E(h')=0·p+\frac{h}{1-p}·(1-p)=h\tag{2} E(h)=0p+1ph(1p)=h(2)
    那么我们发现,在使用了 dropout,整个分布的期望是不变的
    E ( h ′ ) = h = E ( h ) (3) E(h')=h=E(h)\tag{3} E(h)=h=E(h)(3)

2. Dropout 函数定义

# 3. 自定义 dropout函数
def dropout_layer(x, dropout):
	assert 0 <= dropout <= 1  # 判断 dropout 的范围
	if dropout == 1:  # 当 dropout == 1 时,那么返回同样大小的零向量
		return torch.zeros_like(x)
	if dropout == 0:	# 当 dropout == 0 时,那么返回本身
		return x
	mask = (torch.rand(x.shape) > dropout).float() # 定义一个标签张量mask
	return mask * x / (1.0 - dropout)
  • 注:我们为了获得更快的运算速度,我们希望用mask去乘以一个矩阵。

3. Dropout 自定义函数的应用

  • 代码
# -*- coding: utf-8 -*-
# @Project: zc
# @Author: zc
# @File name: dropout
# @Create time: 2021/11/27 11:26

# 1. 导入数据库
import torch
from torch import nn
from d2l import torch as d2l
import matplotlib.pyplot as plt

# 2. 定义相关参数
dropout1, dropout2 = 0.2, 0.5
num_inputs, num_outpus, num_hiddens1, num_hiddens2 = 784, 10, 256, 256


# 3. 自定义 dropout函数
def dropout_layer(x, dropout):
	assert 0 <= dropout <= 1  # 判断 dropout 的范围
	if dropout == 1:  # 当 dropout == 1 时,那么返回同样大小的零向量
		return torch.zeros_like(x)
	if dropout == 0:	# 当 dropout == 0 时,那么返回本身
		return x
	mask = (torch.rand(x.shape) > dropout).float() # 定义一个标签张量mask
	return mask * x / (1.0 - dropout)


# 4. 定义模型类
class Net(nn.Module):
	def __init__(self, num_inputs, num_outputs, num_hiddens1, num_hiddens2,
				 is_training=True):
		super(Net, self).__init__()
		self.inputs = num_inputs
		self.training = is_training
		self.lin1 = nn.Linear(num_inputs, num_hiddens1)
		self.lin2 = nn.Linear(num_hiddens1, num_hiddens2)
		self.lin3 = nn.Linear(num_hiddens2, num_outputs)
		self.relu = nn.ReLU()

	def forward(self, x): # 定义网络前向传播流
		H1 = self.relu(self.lin1(x.reshape((-1, self.inputs))))
		if self.training == True: # 如果在训练中 ,那么就用 dropout
			H1 = dropout_layer(H1, dropout1)
		H2 = self.relu(self.lin2(H1))
		if self.training == True:
			H2 = dropout_layer(H2, dropout2)
		out = self.lin3(H2)
		return out


# 5. 实例化网络
net = Net(num_inputs, num_outpus, num_hiddens1, num_hiddens2)

# 6. 定义超参数
num_epochs, lr, batch_size = 100, 0.03, 256

# 7. 定义损失函数
loss = nn.CrossEntropyLoss()

# 8. 定义训练集和测试集
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

# 9. 定义更新器 随机梯度下降 SGD
trainer = torch.optim.SGD(net.parameters(), lr=lr)

# 10. 开始训练
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

# 11. 显示结果
plt.show()
  • 结果

在这里插入图片描述

4. Dropout 在 pytorch中调用

上面是我们自己实现的 dropout_layer 函数, 为了后续的方便使用,我们直接使用 pytorch中自带的 nn.Dropout 模块

  • 代码
# -*- coding: utf-8 -*-
# @Project: zc
# @Author: zc
# @File name: Dropout_test
# @Create time: 2021/11/27 16:09

# 1. 导入数据库
from torch import nn
import torch
from d2l import torch as d2l
import matplotlib.pyplot as plt

# 2. 定义相关参数
batch_size = 256
dropout1, dropout2 = 0.2, 0.5

# 3. 定义网络模型,运用 nn.Dropout()
net = nn.Sequential(
	nn.Flatten(),
	nn.Linear(784, 256),
	nn.ReLU(),
	nn.Dropout(p=dropout1),
	nn.Linear(256, 256),
	nn.ReLU(),
	nn.Dropout(p=dropout2),
	nn.Linear(256, 10)
)

# 4. 初始化模型参数
def ini_weights(m):
	if type(m) == nn.Linear:
		nn.init.normal_(m.weight, std=0.01)

net.apply(ini_weights)

# 5. 定义超参数
num_epochs, lr, batch_size = 100, 0.03, 256

# 6. 定义优化器
trainer = torch.optim.SGD(net.parameters(), lr=lr)

# 7. 定义训练集和测试集
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)

# 8. 定义损失函数
loss = nn.CrossEntropyLoss()

# 9. 开始训练
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

# 10. 显示结果
plt.show()
  • 结果

在这里插入图片描述

5. 小结

用 pytorch 自带的 nn.Dropout 真的很方便。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值