PyTorch 深度学习实践 第11讲(ResidualBlock的使用,解决梯度消失)

说明:1、要解决的问题:梯度消失

           2、跳连接,H(x) = F(x) + x,张量维度必须一样,加完后再激活。不要做pooling,张量的维度会发生变化。

#实现ResidualBlock model,解决梯度消失问题
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt

batch_size=64
transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,),(0.3081,))])

train_dataset=datasets.MNIST(root="dataset/mnist",train='True',download="False",transform=transform)
train_loader=DataLoader(train_dataset,shuffle="True",batch_size=batch_size)
test_dataset=datasets.MNIST(root="dataset/mnist",train="False",download="False",transform=transform)
test_loader=DataLoader(test_dataset,shuffle="False",batch_size=batch_size)

class ResidualBlock(nn.Module):#输入的通道数和输出的通道数需要一样
    def __init__(self,channel):
        super(ResidualBlock,self).__init__()
        self.channel=channel
        self.conv1=nn.Conv2d(channel,channel,kernel_size=3,padding=1)
        self.conv2=nn.Conv2d(channel,channel,kernel_size=3,padding=1)

    def forward(self,x):
        #输入batch_size*16*12*12---(conv1:16*16*3*3,padding=1)-->batch_size*16*12*12---relu---(conv2:16*16*3*3,padding=1)-->batch_size*16*12*12
        y=F.relu(self.conv1(x))
        y=self.conv2(y)
        return F.relu(x+y)#x+y:batch_size*16*12*12+batch_size*16*12*12
                          #x和y的各维度一样大小

class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.conv1=nn.Conv2d(1,16,kernel_size=5)
        self.conv2=nn.Conv2d(16,32,kernel_size=5)

        self.residualblock1=ResidualBlock(16)
        self.residualblock2=ResidualBlock(32)

        self.mp=nn.MaxPool2d(2)
        self.fc=nn.Linear(512,10)

    def forward(self,x):
        in_size=x.size(0)
        x=self.mp(F.relu(self.conv1(x)))
        #batch_size*1*28*28---(1*16*5*5)-->batch_size*16*24*24---relu--->---maxpool2d(2)--->batch_size*16*12*12
        x=self.residualblock1(x)#要求:经过residualblock不改变各维度大小
        #输入batch_size*16*12*12:

        x=self.mp(F.relu(self.conv2(x)))
        x=self.residualblock2(x)
        x=x.view(in_size,-1)
        x=self.fc(x)
        return x

model=Net()
criterion=torch.nn.CrossEntropyLoss()#定义损失函数
optimizer=optim.SGD(model.parameters(),lr=0.01,momentum=0.5)#定义优化器,momentum称之为动量,取值范围为0-1,用于加快优化过程,使得梯度更新更加平滑地进行,不容易陷入局部最小值,0.5的取值表示考虑过去一半的梯度

def train(epoch):#训练函数
    running_loss=0.0
    for batch_idx,data in enumerate(train_loader,0):#enumerate将加载的数据集在每次的epoch中生成一个分批次,每个批次给与一个索引,从0开始赋值给batch_id
        inputs,labels=data#给x,y赋值
        optimizer.zero_grad()#训练之前梯度清零
        outputs=model(inputs)#进行预测
        loss=criterion(outputs,labels)#根据损失函数计算损失值
        loss.backward()#反向传递
        optimizer.step()#梯度更新
        running_loss+=loss.item()#对损失值进行求和
        if batch_idx%300==299:#每300个批次进行暑促,输出300个批次的平均损失值
            print("[%d,%5d]loss:%.8f"%(epoch+1,batch_idx+1,running_loss/300))
            running_loss=0#归零,重新进行记录

accuracy_list = []
def test():
    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:#只是为了将真实值与预测值进行对比,不需要在分批次
            images,targets=data#给x,y赋值
            output=model(images)#得到预测值
            _,predicted=torch.max(output.data,dim=1)#torch.max()函数,用于在指定维度上找到最大值
            #output.data是一个二维张量,每一行对应一个样本,对应列值代表每个类别的概率值,dim用于指定维度,意思是从每一行中找出值最大的
            #_表示占位符,用于存储最大值,predicted用于存储概率最大的索引,也就是概率最大的列,也就是我们的预测值
            #如果想要获取每一列的最大值,那么dim取0,pytorch中的维度从0开始编号,0表示列,1表示行
            total+=targets.size(0)#得到源数据的条数
            correct+=(predicted==targets).sum().item()#predicted==targets进行逐个元素的张量比较,得到一个布尔型张量,sum()进行布尔张量求和,因为都是张量计算后item()得到数值
        accuracy_list.append(100*correct/total)  # 将准确率添加到列表中
        print("accuracy on test set:%d %%"%(100*correct/total))
if __name__=="__main__":
    for epoch in range(30):
        train(epoch)
        test()

plt.plot(range(1,31), accuracy_list)  # x轴为epoch,y轴为准确率
plt.xlabel("Epoch")
plt.ylabel("Accuracy (%)")
plt.title("Accuracy vs. Epoch")
plt.grid(True)
plt.show()

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值