倒置残差块(Inverted Residual Block),是MobileNetV2网络中提升效率的关键结构。
类定义和构造函数
class IRBlock(nn.Module):
def __init__(self, inp, oup, stride=1, expansion=4):
IRBlock
类继承自nn.Module
,是一个神经网络模块。__init__
方法是类的构造函数,用于初始化实例。inp
: 输入通道数。oup
: 输出通道数。stride
: 卷积的步长,决定了输出特征图的大小。expansion
: 扩展因子,用于控制内部隐藏层通道数的扩展。
内部变量和断言
self.stride = stride
assert stride in [1, 2]
hidden_dim = int(inp * expansion)
self.use_res_connect = self.stride == 1 and inp == oup
- 存储步长信息。
- 断言确保步长是1或2,这是典型的设计选择,用于控制特征图的下采样。
- 计算隐藏层的通道数,通过输入通道数乘以扩展因子。
- 判断是否使用残差连接,当步长为1且输入输出通道数相等时使用。
在编程中,断言(
assert
)是一种检查代码是否满足某些条件的方式。如果条件不成立,程序会抛出异常。在 IRBlock
类中,使用断言有几个目的:
assert stride in [1, 2]
这行代码确保传递给类的
stride
参数只能是1或2。步长(stride)在卷积神经网络中是一个关键参数,它决定了卷积层如何在输入数据上移动。步长为1意味着卷积核每次移动一个像素,