0. 前言
本文讲如何使用Pytorch实现残差网络,残差网络的结构如图:

输入x,输出output=F(x)+x:F(x)为输入经过卷积运算后的结果,x为输入本身。实际上就是将输出结果再次加上了本身。
1. 如何理解"+"号运算?
这里的加法实际上就是对应位的数值相加,比如[1,2,3]和[4,5,6]相加的结果为[5,7,9],并不是通道的拼接,如此便要求x和F(x)应当拥有相同的shape。我们通过实验来看加号(F(x)具体运算省略):
定义输入x:

F(x)为:

x+F(x)的结果为:

2. 要求:x和F(x)的shape一样
我们知道➕是对应位相加,则要求两者shape一样才能正确相加。一般F(x)经过卷积运算后,与x的通道数是不一样的,所以我们在将两者相加是,需要将x的通道数变为和F(x)一样,这里同样通过卷积操作完成。具体看下面代码。
3. 代码实现
3.1 定义ResBlock网络
import torch
import torch.nn as nn
class ResBlock(nn.Module):
def __init__(self, input_channels, out_channels, kernel_size):
super(ResBlock, self).__init__()
self.function=nn.Sequential(
nn.Conv2d(input_channels, out_channels, kernel_size, padding=1),
nn.Conv2d(out_channels, out_channels, kernel_size, padding=1)
)
self.downsample=nn.Sequential(
nn.Conv2d(input_channels,out_channels,kernel_size,padding=1)
)
def forward(self, x):
identify = x
identify=self.downsample(identify)
f = self.function(x)
out = f + identify
return out
model=ResBlock(3,8,3) # 定义一个输入通道为3,输出通道为8,卷积核size为3的ResBlock网络
input=torch.randn(1,3,100,100) # 产生一个shape=[1,3,100,100]的数据作为输入
print(model(input).shape) # 推理,并打印推理结果的shape

实现代码和网络结构图如上,downsample的功能就是将x原本的3通道变为8通道,以此保证相加时两个形状一致。downsample本身实际上就是一个卷积运算,因为卷积运算是可以增加通道数的。
本文详细介绍了如何在PyTorch中运用残差网络,重点讲解了 + 运算在保持通道数一致以便于相加的重要性,以及如何通过卷积操作调整输入和输出的通道数。通过实例展示了ResBlock的定义和代码实现,以及网络结构的设计。
4675

被折叠的 条评论
为什么被折叠?



