一级目录
下面这段代码中,super(BasicBlock, self).__init__()
这一行的解释
在Python中,super()
函数用于调用父类(超类)的一个方法。super(BasicBlock, self).__init__()
这行代码的意思是调用BasicBlock
类的父类nn.Module
的构造函数__init__()
。
第一个参数BasicBlock
是指明要调用哪个类的父类的方法,这里是BasicBlock
的父类。
第二个参数self
是指要将哪个实例作为调用父类方法时的self
参数传递进去,这里是将创建的BasicBlock
实例作为self
参数传递给父类的__init__()
方法。
class BasicBlock(nn.Module):
def __init__(self, inchannel, outchannel, stride=1):
super(BasicBlock, self).__init__()
self.left = nn.Sequential(
nn.Conv2d(inchannel, outchannel, kernel_size=3, stride=stride, padding=1, bias=False),
nn.BatchNorm2d(outchannel),
nn.ReLU(inplace=True), #inplace=True表示进行原地操作,一般默认为False,表示新建一个变量存储操作
nn.Conv2d(outchannel, outchannel, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(outchannel)
)
self.shortcut = nn.Sequential()
#论文中模型架构的虚线部分,需要下采样
if stride != 1 or inchannel != outchannel:
self.shortcut = nn.Sequential(
nn.Conv2d(inchannel, outchannel, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(outchannel)
)
def forward(self, x):
out = self.left(x) #这是由于残差块需要保留原始输入
out += self.shortcut(x)#这是ResNet的核心,在输出上叠加了输入x
out = F.relu(out)
return out