MDEQ代码解读模型部分
MDEQ简介
MDEQ模型来自于发表在NeurIPS 2020的论文Multiscale Deep Equilibrium Models,是对DEQ模型的扩展,将原本用于序列数据的DEQ模型,通过多尺度扩展到视觉任务上,包括图像分类和语义分割,并且取得了不错的结果。下面我们就来看一下MDEQ模型以及模型部分的代码。
MDEQ模型部分代码解读
残差块
残差块对应文章中的图 2:Input Injection 是由输入图片得到的注入项,BasicBlock 作用在隐变量 $z$ 上。
class BasicBlock(nn.Module):
    """Residual block used inside each MDEQ resolution stream (Figure 2 of the paper)."""
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, n_big_kernels=0, dropout=0.0, wnorm=False):
        """
        A canonical residual block with two 3x3 convolutions and an intermediate ReLU.
        Corresponds to Figure 2 in the paper.
        """
        super(BasicBlock, self).__init__()
        # The first `n_big_kernels` convolutions use 5x5 kernels; the rest use 3x3.
        first_conv = conv5x5 if n_big_kernels >= 1 else conv3x3
        second_conv = conv5x5 if n_big_kernels >= 2 else conv3x3
        # Internal (expanded) channel width of the block.
        width = int(DEQ_EXPAND * planes)
        self.conv1 = first_conv(inplanes, width)
        self.gn1 = nn.GroupNorm(NUM_GROUPS, width, affine=BLOCK_GN_AFFINE)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = second_conv(width, planes)
        self.gn2 = nn.GroupNorm(NUM_GROUPS, planes, affine=BLOCK_GN_AFFINE)
        self.gn3 = nn.GroupNorm(NUM_GROUPS, planes, affine=BLOCK_GN_AFFINE)
        self.relu3 = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.drop = VariationalHidDropout2d(dropout)
        if wnorm:
            self._wnorm()

    def _wnorm(self):
        """Register weight normalization on both convolutions."""
        self.conv1, self.conv1_fn = weight_norm(self.conv1, names=['weight'], dim=0)
        self.conv2, self.conv2_fn = weight_norm(self.conv2, names=['weight'], dim=0)

    def _reset(self, bsz, d, H, W):
        """Reset the dropout mask and recompute conv weights via weight normalization."""
        # The *_fn handles exist only when _wnorm() was called, so guard on the
        # instance dict before resetting each one.
        for fn_name, conv in (('conv1_fn', self.conv1), ('conv2_fn', self.conv2)):
            if fn_name in self.__dict__:
                self.__dict__[fn_name].reset(conv)
        self.drop.reset_mask(bsz, d, H, W)

    def forward(self, x, injection=None):
        """conv1 -> GN -> ReLU -> conv2 -> dropout (+injection) -> GN -> +residual -> ReLU -> GN."""
        if injection is None:
            injection = 0
        # Shortcut path; project it when the caller supplied a downsample module.
        shortcut = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.gn1(self.conv1(x)))
        out = self.gn2(self.drop(self.conv2(out)) + injection)
        out = out + shortcut
        return self.gn3(self.relu3(out))
# Registry mapping the config's block-type string to its implementing class.
blocks_dict = dict(BASIC=BasicBlock)
分支网络
将同一分辨率流(resolution stream)中的多个残差块串联起来。
class BranchNet(nn.Module):
    """The chain of residual blocks forming one resolution stream."""

    def __init__(self, blocks):
        """
        The residual block part of each resolution stream.

        :param blocks: sequence (e.g. nn.ModuleList) of residual blocks applied in order
        """
        super().__init__()
        self.blocks = blocks

    def forward(self, x, injection=None):
        # Only the first block receives the input injection; the remaining
        # blocks run on the hidden state alone.
        out = self.blocks[0](x, injection)
        for blk in self.blocks[1:]:
            out = blk(out)
        return out
下采样模块
class DownsampleModule(nn.Module):
    """Downsample from resolution j (`in_res`) to resolution i (`out_res`) with stride-2 convs."""

    def __init__(self, num_channels, in_res, out_res):
        """
        A downsample step from resolution j (with in_res) to resolution i (with out_res).
        A series of 2-strided convolutions.
        """
        super(DownsampleModule, self).__init__()
        # downsample (in_res=j, out_res=i); each stage halves the spatial size
        self.level_diff = level_diff = out_res - in_res
        src_chan = num_channels[in_res]
        dst_chan = num_channels[out_res]
        conv_kwargs = dict(kernel_size=3, stride=2, padding=1, bias=False)
        stages = []
        for step in range(level_diff):
            last = step == level_diff - 1
            # Keep the source channel count until the final stage, which maps to dst_chan.
            stage_out = dst_chan if last else src_chan
            layers = [('conv', nn.Conv2d(src_chan, stage_out, **conv_kwargs)),
                      ('gnorm', nn.GroupNorm(NUM_GROUPS, stage_out, affine=FUSE_GN_AFFINE))]
            if not last:
                # No activation after the final stage.
                layers.append(('relu', nn.ReLU(inplace=True)))
            stages.append(nn.Sequential(OrderedDict(layers)))
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)
上采样模块
上采样部分用 $1 \times 1$ 的卷积和插值(interpolation)来完成。
class UpsampleModule(nn.Module):
def __init__(self, num_channels, in_res, out_res):
"""
An upsample step from resolution j (with in_res) to resolution i (with out_res).
Simply a 1x1 convolution followed by an interpolation.
"""
super(UpsampleModule, self).__init__()
# upsample (in_res=j, out_res=i)
inp_chan = num_channels[in_res]
out_chan = num_channels[out_res]
self.level_diff = level_diff &#