torch学习笔记——ResNet残差网络

一:背景

        1、梯度消失与梯度爆炸

        在网络的权重更新中,主要依靠w -= w + n * d(loss)/dw。来进行更新,但是如果深度一些的网络,随着层数的增加,d(loss)/dw =展开的叠乘数目也在增加,在各个过程中,如果设计不合理,出现一些<1或者>1的数,则成指数被的无限趋近于0(或者+∞),则会引起w = w;w = +∞的状态,所谓的梯度消失或者梯度爆炸。

        2、解决办法

        残差网络的引入就是为了解决上述问题。在2层(或者2层以后)的卷积核输出位置加上输入x,则求导就可以变成dF(x)/dx = (dy(x)/dx +1),来保证就算是y(x)的导数小于1,随着不断地叠乘,最后d(loss)/dw也可以维持在1左右,继续实现梯度的更新。

二:整体思路

 

        

三:残差网络模块的设计

class ResNet(torch.nn.Module):
    def __init__(self,inchannels):
        super(ResNet1, self).__init__()
        self.cnn = torch.nn.Conv2d(inchannels,inchannels,3,padding=1)
    def forward(self,x):
        y = self.cnn(x)
        y = self.cnn(F.relu(y))
        return x + y

        在代码编写的过程中,如果遇到重复较多的结构,应该学会用定义class的办法给封装起来,到时候直接实例化引用。

四:模型设计

        

#模型设计
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.cnn1 = torch.nn.Conv2d(1,16,5)
        self.cnn2 = torch.nn.Conv2d(16,32,5)
        self.mp = torch.nn.MaxPool2d(2)
        self.line = torch.nn.Linear(512,10)
        self.RN1 =ResNet(16)
        self.RN2 =ResNet(32)


    def forward(self,x):
        in_size = x.size(0)
        x = self.mp(F.relu(self.cnn1(x)))
        x = F.relu(self.RN(x))
        x = self.mp(F.relu(self.cnn2(x)))
        x = F.relu(self.RN(x))
        x = x.view(in_size,-1)
        x = self.line(x)
        return x

        记住x.size(0)的技巧,先读出输出的整体行数这样,用x.view(insize,-1)就可以自动填写了。

五:整体代码

#残差网络的设计
import torch
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader



#数据集载入
batchsize = 64
transforms = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])
train_data = datasets.MNIST("../datasets/",train=True,transform=transforms)
train_data = DataLoader(train_data,batchsize,shuffle=True)
test_data = datasets.MNIST('../datasets/',train=False,transform=transforms)
test_data = DataLoader(test_data,batchsize,shuffle=True)



#残差模块设计

class ResNet(torch.nn.Module):
    def __init__(self,inchannels):
        super(ResNet, self).__init__()
        self.cnn = torch.nn.Conv2d(inchannels,inchannels,3,padding=1)
    def forward(self,x):
        y = self.cnn(x)
        y = self.cnn(F.relu(y))
        return x + y

#模型设计
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.cnn1 = torch.nn.Conv2d(1,16,5)
        self.cnn2 = torch.nn.Conv2d(16,32,5)
        self.mp = torch.nn.MaxPool2d(2)
        self.line = torch.nn.Linear(512,10)
        self.RN1 =ResNet(16)
        self.RN2 =ResNet(32)


    def forward(self,x):
        in_size = x.size(0)
        print(x.size(0))
        x = self.mp(F.relu(self.cnn1(x)))
        x = F.relu(self.RN1(x))
        x = self.mp(F.relu(self.cnn2(x)))
        x = F.relu(self.RN2(x))
        x = x.view(in_size,-1)
        print(x.size(0),x.size(-1))
        x = self.line(x)
        return x

model = Model()
MSE = torch.nn.CrossEntropyLoss()
opt = torch.optim.SGD(model.parameters(),lr = 0.01)

#训练函数编写

def train(epoch):
    total = 0
    for i,data in enumerate(train_data,0):
        images,index = data
        y = model(images)
        loss = MSE(y,index)
        total += loss.item()
        opt.zero_grad()
        loss.backward()
        opt.step()
        if i % 300 == 299 :
            print('epoch = %d, i = %d' % (epoch,i))
            print('loss = ', total / 300)
            total = 0



#定义测试函数

def test():
    total = 0
    correct = 0
    with torch.no_grad():
        for i,data in enumerate(test_data,0):
            images,index = data
            y =model(images)
            _,a = torch.max(y,1)
            correct += (a == index).sum().item()
            total += index.size(0)
        print('正确率 = ', correct / total)

if __name__ == '__main__':
    for epoch in range(10):
        train(epoch)
        test()











评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值