How Dropout Works
- Dropout means that, during training of a deep neural network, units are temporarily dropped from the network with a given probability. The key word is "temporarily": because the dropped units are chosen at random, under stochastic gradient descent each mini-batch effectively trains a different sub-network (see the sketch after this list).
- Dropout is widely used, for example in CNNs, to prevent overfitting.
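A minimal sketch of what this means in PyTorch (the all-ones tensor and p=0.5 are illustrative assumptions): in training mode each unit is zeroed with probability p and the survivors are scaled by 1/(1-p), while in evaluation mode dropout is a no-op.

import torch

torch.manual_seed(0)
drop = torch.nn.Dropout(p=0.5)   # drop each element with probability 0.5
x = torch.ones(8)                # dummy activations, all ones for clarity

drop.train()                     # training mode: random elements become 0,
print(drop(x))                   # survivors are scaled by 1/(1-p) = 2.0

drop.eval()                      # evaluation mode: dropout is disabled,
print(drop(x))                   # the input passes through unchanged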
Implementing Dropout in PyTorch
import torch
import torch.nn.functional as F

dim_in = 28 * 28
dim_hid = 158
dim_out = 10

class TwoLayerNet(torch.nn.Module):
    def __init__(self, dim_in, dim_hid, dim_out):
        super(TwoLayerNet, self).__init__()
        # define the model architecture
        self.fc1 = torch.nn.Linear(dim_in, dim_hid, bias=True)
        self.fc2 = torch.nn.Linear(dim_hid, dim_out, bias=True)

    def forward(self, x):
        x = x.view(x.size(0), -1)      # flatten each sample to a vector
        x = self.fc1(x)
        x = F.relu(x)
        # pass training=self.training so dropout is only active in train mode
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

# instantiate the model
model = TwoLayerNet(dim_in, dim_hid, dim_out)
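A hypothetical smoke test for the model above (the batch of random, MNIST-shaped inputs is an assumption for illustration): since forward passes training=self.training to F.dropout, calling model.train() makes the dropout mask change between forward passes, while model.eval() makes the output deterministic.

x = torch.randn(16, 28 * 28)          # dummy batch shaped like flattened MNIST images

model.train()                          # dropout active: two passes give different outputs
out_a, out_b = model(x), model(x)
print(torch.allclose(out_a, out_b))    # almost certainly False

model.eval()                           # dropout off: passes are deterministic
with torch.no_grad():
    out_c, out_d = model(x), model(x)
print(torch.allclose(out_c, out_d))    # True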
Implementing L1 and L2 Regularization in PyTorch
Note that in this example the penalties are applied to the layer activations returned by forward, and are added to the loss by hand.
import torch
from torch.nn import functional as F
from torch.autograd import Variable  # Variable is a no-op wrapper in modern PyTorch; plain tensors also work

class MLP(torch.nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.linear1 = torch.nn.Linear(128, 32)
        self.linear2 = torch.nn.Linear(32, 16)
        self.linear3 = torch.nn.Linear(16, 2)

    def forward(self, x):
        layer1_out = F.relu(self.linear1(x))
        layer2_out = F.relu(self.linear2(layer1_out))
        out = self.linear3(layer2_out)
        # also return the intermediate activations so they can be penalized
        return out, layer1_out, layer2_out

def l1_penalty(var):
    return torch.abs(var).sum()

def l2_penalty(var):
    return torch.sqrt(torch.pow(var, 2).sum())

batchsize = 4
lambda1, lambda2 = 0.5, 0.01

for i in range(10):
    # note: a fresh, randomly initialized model is built every iteration in this demo,
    # so the printed loss fluctuates instead of decreasing
    model = MLP()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)

    # usually the following code is looped over all batches,
    # but let's just do a dummy batch for brevity
    inputs = Variable(torch.rand(batchsize, 128))
    targets = Variable(torch.ones(batchsize).long())

    optimizer.zero_grad()
    outputs, layer1_out, layer2_out = model(inputs)
    cross_entropy_loss = F.cross_entropy(outputs, targets)
    # L1 penalty on the first layer's activations, L2 penalty on the second's
    l1_regularization = lambda1 * l1_penalty(layer1_out)
    l2_regularization = lambda2 * l2_penalty(layer2_out)
    loss = cross_entropy_loss + l1_regularization + l2_regularization
    print(i, loss.item())
    loss.backward()
    optimizer.step()
- Results
0 9.15099811553955
1 8.80173397064209
2 10.554376602172852
3 7.918191909790039
4 9.383484840393066
5 8.984872817993164
6 12.230097770690918
7 9.20627498626709
8 8.48149299621582
9 10.598611831665039
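As a side note, the example above regularizes layer activations. If the goal is the more common regularization of the weights, one option under the same MLP definition is sketched below (the weight_decay value and the 1e-4 L1 coefficient are illustrative assumptions); the optimizer's weight_decay argument applies an L2 penalty to all parameters without touching the loss.

model = MLP()
# weight_decay adds an L2 penalty on the parameters directly inside the optimizer step
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, weight_decay=1e-2)

inputs = torch.rand(batchsize, 128)
targets = torch.ones(batchsize).long()

optimizer.zero_grad()
outputs, _, _ = model(inputs)
# manual L1 penalty on the weights, summed over all parameters and added to the loss
l1_reg = sum(p.abs().sum() for p in model.parameters())
loss = F.cross_entropy(outputs, targets) + 1e-4 * l1_reg
loss.backward()
optimizer.step()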