问题
out = net(image) # 图像作为输入,经过net做正向传播,得到输出(分类/框/。。。)
你有没有一个疑问,上面这行代码是如何调用forward()函数得到结果的?
我会贴出源码并做解释
解答
一步一步跟踪,net(image)到底经历了什么?(以下引用该开源代码做讲解,其中会做适当简化,以达到说明的目的)
- net的定义
net = RetinaFace()
- RetinaFace类的定义
class RetinaFace(nn.Module):
def __init__(self):
# 定义层结构,举例如下
self.fpn = FPN()
def forward(self, inputs):
out = self.fpn(inputs)
return out
- FPN类的定义
class FPN(nn.Module):
def __init__(self,in_channels_list,out_channels):
self.output1 = conv_bn1X1(in_channels_list[0], out_channels, stride = 1)
self.output2 = conv_bn1X1(in_channels_list[1], out_channels, str