一:背景
1、梯度消失与梯度爆炸
在网络的权重更新中,主要依靠 w = w − η·∂loss/∂w(即 w -= η * d(loss)/dw)来进行更新。但是对于较深的网络,随着层数的增加,d(loss)/dw 按链式法则展开后的叠乘项数目也在增加;在各个环节中,如果设计不合理,出现一些 <1 或者 >1 的因子,连乘后就会成指数倍地趋近于 0(或者 +∞),导致权重不再更新(w 保持不变)或发散为 ±∞,这就是所谓的梯度消失或者梯度爆炸。
2、解决办法
残差网络的引入就是为了解决上述问题。在 2 层(或者更多层)卷积的输出位置加上输入 x,即 F(x) = y(x) + x,则求导变成 dF(x)/dx = dy(x)/dx + 1。这样即使 y(x) 的导数小于 1,叠乘中的每一项都接近 1,最终 d(loss)/dw 也可以维持在 1 左右,继续实现梯度的更新。
二:整体思路
三:残差网络模块的设计
class ResNet(torch.nn.Module):
    """Channel-preserving residual block.

    Computes x + conv2(relu(conv1(x))); the identity shortcut keeps the
    gradient of the block's output w.r.t. its input near 1, countering
    vanishing gradients.
    """
    def __init__(self, inchannels):
        # BUG FIX: the original called super(ResNet1, ...) — ResNet1 does
        # not exist; the class is named ResNet.
        super(ResNet, self).__init__()
        # Two independent 3x3 convolutions. The original reused a single
        # layer for both applications, unintentionally tying their weights.
        self.cnn1 = torch.nn.Conv2d(inchannels, inchannels, 3, padding=1)
        self.cnn2 = torch.nn.Conv2d(inchannels, inchannels, 3, padding=1)

    def forward(self, x):
        y = self.cnn1(x)
        y = self.cnn2(F.relu(y))
        return x + y
在代码编写的过程中,如果遇到重复较多的结构,应该学会用定义 class 的办法将其封装起来,之后直接实例化引用即可。
四:模型设计
#模型设计
class Model(torch.nn.Module):
    """CNN classifier for 28x28 single-channel images.

    Two conv+maxpool stages, each followed by a residual block, then a
    linear layer producing 10 raw logits (pair with CrossEntropyLoss).
    """
    def __init__(self):
        super(Model, self).__init__()
        self.cnn1 = torch.nn.Conv2d(1, 16, 5)   # 28x28 -> 24x24
        self.cnn2 = torch.nn.Conv2d(16, 32, 5)  # 12x12 -> 8x8
        self.mp = torch.nn.MaxPool2d(2)         # halves spatial size
        self.line = torch.nn.Linear(512, 10)    # 32 * 4 * 4 = 512 features
        self.RN1 = ResNet(16)
        self.RN2 = ResNet(32)

    def forward(self, x):
        in_size = x.size(0)  # batch size, used to flatten below
        x = self.mp(F.relu(self.cnn1(x)))
        # BUG FIX: the original called self.RN, which does not exist;
        # the residual blocks are registered as RN1 and RN2.
        x = F.relu(self.RN1(x))
        x = self.mp(F.relu(self.cnn2(x)))
        x = F.relu(self.RN2(x))
        x = x.view(in_size, -1)
        x = self.line(x)
        return x
记住 x.size(0) 的技巧:先读出批量大小(张量的第 0 维),再用 x.view(in_size, -1) 就可以让展平后的列数自动推算出来。
五:整体代码
#残差网络的设计
import torch
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
# Dataset loading: MNIST train/test sets wrapped in shuffled DataLoaders.
batchsize = 64
# Keep the Compose under its own name so it does not shadow the imported
# `transforms` module (the original rebinding made the module unreachable).
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),  # MNIST mean / std
])
# download=True fetches the data on first run instead of raising.
train_dataset = datasets.MNIST("../datasets/", train=True, download=True,
                               transform=mnist_transform)
train_data = DataLoader(train_dataset, batchsize, shuffle=True)
test_dataset = datasets.MNIST('../datasets/', train=False, download=True,
                              transform=mnist_transform)
test_data = DataLoader(test_dataset, batchsize, shuffle=True)
#残差模块设计
class ResNet(torch.nn.Module):
    """Channel-preserving residual block.

    Computes x + conv2(relu(conv1(x))); the identity shortcut keeps the
    gradient of the block's output w.r.t. its input near 1, countering
    vanishing gradients.
    """
    def __init__(self, inchannels):
        super(ResNet, self).__init__()
        # Two independent 3x3 convolutions. The original applied a single
        # shared layer twice, unintentionally tying the weights of both
        # convolutions in the block.
        self.cnn1 = torch.nn.Conv2d(inchannels, inchannels, 3, padding=1)
        self.cnn2 = torch.nn.Conv2d(inchannels, inchannels, 3, padding=1)

    def forward(self, x):
        y = self.cnn1(x)
        y = self.cnn2(F.relu(y))
        return x + y
#模型设计
class Model(torch.nn.Module):
    """CNN classifier for 28x28 single-channel (MNIST) images.

    Two conv+maxpool stages, each followed by a residual block, then a
    linear layer producing 10 raw logits (pair with CrossEntropyLoss).
    """
    def __init__(self):
        super(Model, self).__init__()
        self.cnn1 = torch.nn.Conv2d(1, 16, 5)   # 28x28 -> 24x24
        self.cnn2 = torch.nn.Conv2d(16, 32, 5)  # 12x12 -> 8x8
        self.mp = torch.nn.MaxPool2d(2)         # halves spatial size
        self.line = torch.nn.Linear(512, 10)    # 32 * 4 * 4 = 512 features
        self.RN1 = ResNet(16)
        self.RN2 = ResNet(32)

    def forward(self, x):
        # Leftover debug print statements removed — they spammed stdout
        # on every forward pass during training.
        in_size = x.size(0)  # batch size, used to flatten below
        x = self.mp(F.relu(self.cnn1(x)))
        x = F.relu(self.RN1(x))
        x = self.mp(F.relu(self.cnn2(x)))
        x = F.relu(self.RN2(x))
        x = x.view(in_size, -1)
        x = self.line(x)
        return x
# Instantiate the network, the loss criterion, and the optimizer.
model = Model()
# NOTE(review): the name MSE is misleading — this is cross-entropy, not
# mean-squared error; kept unchanged because train() below refers to it.
MSE = torch.nn.CrossEntropyLoss()
opt = torch.optim.SGD(model.parameters(), lr=0.01)
#训练函数编写
def train(epoch):
    """Run one pass over train_data; print the average loss every 300 batches."""
    running_loss = 0.0
    for batch_idx, (images, labels) in enumerate(train_data):
        outputs = model(images)
        loss = MSE(outputs, labels)
        running_loss += loss.item()
        # Standard SGD step: clear stale grads, backprop, update weights.
        opt.zero_grad()
        loss.backward()
        opt.step()
        if batch_idx % 300 == 299:
            print('epoch = %d, i = %d' % (epoch, batch_idx))
            print('loss = ', running_loss / 300)
            running_loss = 0.0
#定义测试函数
def test():
    """Measure classification accuracy over test_data and print it."""
    correct = 0
    seen = 0
    with torch.no_grad():  # inference only — skip autograd bookkeeping
        for images, labels in test_data:
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)  # index of highest logit
            correct += (predicted == labels).sum().item()
            seen += labels.size(0)
    print('正确率 = ', correct / seen)
if __name__ == '__main__':
    # Alternate one epoch of training with a full evaluation pass.
    for epoch_idx in range(10):
        train(epoch_idx)
        test()