B站大佬的视频讲解
因为以后可能想用ResNet网络作为backbon,所以仔细阅读了源码,做一下个人笔记。就是给源码写自己看的懂得注释…
import torch.nn as nn
import torch
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, in_channel, out_channel, stride=1, downsample=None):
super(BasicBlock, self).__init__()
self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
kernel_size=3, stride=stride, padding=1, bias=False) ##在ResNet的结构图中可以发现,
self.bn1 = nn.BatchNorm2d(out_channel) ##不同Conv层的featuremap尺寸不一样,
self.relu = nn.ReLU() ##需要每个Conv层的第一层卷积的strdie=2来改变
self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
kernel_size=3, stride=1, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(out_channel)
self.downsample = downsample
def forward(self, x):
identity = x
if self.downsample is not None: ##有些输入与输出之间的featuremap尺寸不一样,需要通过下采样来实现。
identity = self.downsample(x)
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out