Learning Dynamic Routing for Semantic Segmentation (CVPR 2020 Open Access)
GitHub source code
Megvii Technology on Zhihu: paper interpretation
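Dynamic Routing adaptively generates different structures for feature encoding. The network can assign objects (or background) of different sizes to the level with the matching resolution, so that each feature transformation is targeted. Compared with Auto-DeepLab, Dynamic Routing supports multi-path connections and skip connections.
I had not read any dynamic-routing models before, and many parts of the paper were unclear to me, so I studied the source code to understand them. If anything here is wrong, please point it out; discussion is welcome (,ԾㅂԾ,).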
On the budget constraint
$$
\begin{aligned}
\mathcal{C}(Node^l_s) &= \mathcal{C}(Cell^l_s) + \mathcal{C}(Gate^l_s) + \mathcal{C}(Trans^l_s) \\
&= \max(\alpha_s^l)\sum_{O^i\in\mathcal{O}}\mathcal{C}(O^i) + \mathcal{C}(Gate^l_s) + \sum_j \alpha_{s\to j}^l\,\mathcal{C}(\mathcal{T}_{s\to j})
\end{aligned}
$$

$$
\mathcal{C}(Space)=\sum_{l\le L}\sum_{s\le 1/4}\mathcal{C}(Node_s^l)
$$

$$
\mathcal{L}_C=\left(\mathcal{C}(Space)/C-\mu\right)^2
$$

$$
\mathcal{L}=\lambda_1\mathcal{L}_N+\lambda_2\mathcal{L}_C
$$
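In the formulas above, $\mathcal{C}(\cdot)$ is the FLOPs of the corresponding operation and $C$ is the actual FLOPs of the network. The final loss consists of a semantic-segmentation term $\mathcal{L}_N$ and a budget-constraint term $\mathcal{L}_C$. The budget constraint depends on each Cell's $\alpha$ (the Soft Conditional Gate, which controls how strongly each feature is passed on).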
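To make the terms concrete, here is a plain-Python sketch that evaluates the formulas on hypothetical FLOP numbers. `node_cost`, `budget_loss` and `total_loss` are illustrative names, not functions from the repository; $C$ is passed in as `full_flops` and $\mu$ as `mu`. In the repository the same two gated terms show up in `Cell.forward` as `cell_flops = max(gate) * self.cell_flops` and `trans_flops_expt = gate * res_*_flops`.

```python
def node_cost(alphas, cell_op_flops, gate_flops, trans_flops):
    """Expected FLOPs of one node, term by term as in C(Node_s^l).

    alphas:        {target_scale: alpha_{s->j}}, gate value of each allowed path
    cell_op_flops: FLOPs of each candidate operation O^i inside the cell
    gate_flops:    FLOPs of the gate itself, C(Gate_s^l)
    trans_flops:   {target_scale: C(T_{s->j})}, FLOPs of each transition
    """
    cell = max(alphas.values()) * sum(cell_op_flops)         # max(alpha) * sum_i C(O^i)
    trans = sum(alphas[j] * trans_flops[j] for j in alphas)   # sum_j alpha_{s->j} * C(T_{s->j})
    return cell + gate_flops + trans


def budget_loss(node_costs, full_flops, mu):
    """L_C = (C(Space) / C - mu)^2, where C(Space) sums the cost of all nodes."""
    space = sum(node_costs)
    return (space / full_flops - mu) ** 2


def total_loss(loss_seg, loss_budget, lambda1, lambda2):
    """L = lambda1 * L_N + lambda2 * L_C."""
    return lambda1 * loss_seg + lambda2 * loss_budget


# Hypothetical node: the gate closes the "up" path and fully opens the "down" path.
alphas = {"up": 0.0, "keep": 0.5, "down": 1.0}
cost = node_cost(alphas, cell_op_flops=[100.0], gate_flops=2.0,
                 trans_flops={"up": 10.0, "keep": 1.0, "down": 20.0})
print(cost)  # 100 + 2 + (0 + 0.5 + 20) = 122.5
print(budget_loss([cost, 77.5], full_flops=400.0, mu=0.3))
```

Because the cell term uses $\max(\alpha_s^l)$, the cell's operations are only "free" when every outgoing path is closed, which is exactly when the cell can be skipped.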
In addition, the weight $\lambda_2$ of the budget term is warmed up. As `DynamicNet4Seg.forward` below shows, the budget loss is only added once current_step/total_step reaches UNUPDATE_RATE, after which

$$\lambda_2=\min\left(1,\ (current\_step/total\_step-UNUPDATE\_RATE)/0.02\right)$$
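A small sketch of this schedule, mirroring the `warm_up_rate` logic in `DynamicNet4Seg.forward`. The function name `budget_weight` and the value `unupdate_rate=0.2` are hypothetical; the ramp length is a parameter defaulting to the 0.02 used in the code.

```python
def budget_weight(step_rate, unupdate_rate, ramp=0.02):
    """Weight lambda_2 of the budget loss at a given training progress.

    step_rate:     current_step / total_step, in [0, 1]
    unupdate_rate: fraction of training that runs without the budget constraint
    ramp:          length of the linear warm-up, as a fraction of training
    Returns None while the constraint is still switched off.
    """
    if step_rate < unupdate_rate:
        return None  # loss_budget is not added to the losses at all yet
    return min(1.0, (step_rate - unupdate_rate) / ramp)


for rate in (0.1, 0.2, 0.21, 0.5):
    print(rate, budget_weight(rate, unupdate_rate=0.2))
```

With a 0.02 ramp the budget term reaches full weight after just 2% of training past UNUPDATE_RATE.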
These FLOPs are computed along the way during the model's forward pass.
dynamic4seg.py
The overall architecture of the semantic-segmentation model is in dl_lib/modeling/meta_arch/dynamic4seg.py, which defines the whole segmentation pipeline.
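After batched_inputs is obtained, each image is normalized and fed to the backbone, which returns the intermediate features `features`, the expected FLOPs `expt_flops` used by the budget constraint, and the backbone's actual FLOPs `real_flops` used to evaluate the model's computational cost. The features are then passed to sem_seg_head, which upsamples them and produces the segmentation result and the loss. During training, the budget loss computed from expt_flops is added to the losses for backpropagation, constraining the model's complexity. During inference, the result is upsampled (bilinear interpolation) to the original image size and returned, since input images are usually resized and differ in size from the originals.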
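The normalization and batching step can be sketched in NumPy as below. This assumes `ImageList.from_tensors` behaves like detectron2's: every image is padded at the bottom/right to the largest size in the batch, rounded up to a multiple of `size_divisibility` (`max_stride=32` in `DynamicNetwork`), and the unpadded sizes are recorded. `normalize_and_pad` is a hypothetical helper name. Note that normalization happens before padding, so padded pixels are 0 in normalized space; for the targets the code passes `sem_seg_head.ignore_value` as the pad value, so padded pixels are ignored by the loss.

```python
import numpy as np


def normalize_and_pad(images, pixel_mean, pixel_std, size_divisibility=32, pad_value=0.0):
    """Normalize (C, H, W) images and pad them into one (N, C, H_max, W_max) batch.

    H_max and W_max are rounded up to a multiple of size_divisibility so that
    every stride-2 step of the backbone divides the feature size evenly.
    """
    mean = np.asarray(pixel_mean, dtype=np.float32).reshape(-1, 1, 1)
    std = np.asarray(pixel_std, dtype=np.float32).reshape(-1, 1, 1)
    images = [(img.astype(np.float32) - mean) / std for img in images]

    max_h = max(img.shape[1] for img in images)
    max_w = max(img.shape[2] for img in images)
    if size_divisibility > 0:
        max_h = -(-max_h // size_divisibility) * size_divisibility  # ceil to a multiple
        max_w = -(-max_w // size_divisibility) * size_divisibility

    batch = np.full((len(images), images[0].shape[0], max_h, max_w), pad_value, dtype=np.float32)
    for i, img in enumerate(images):
        batch[i, :, : img.shape[1], : img.shape[2]] = img  # pad bottom/right only
    image_sizes = [(img.shape[1], img.shape[2]) for img in images]
    return batch, image_sizes


imgs = [np.zeros((3, 100, 120)), np.zeros((3, 64, 64))]
batch, sizes = normalize_and_pad(imgs, pixel_mean=[0.5] * 3, pixel_std=[0.25] * 3)
print(batch.shape, sizes)  # (2, 3, 128, 128) [(100, 120), (64, 64)]
```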
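The structures of backbone and sem_seg_head are defined in dl_lib/modeling/dynamic_arch/dynamic_backbone.py and in dynamic4seg.py itself, respectively.
For the model's initialization parameters, see the config files of each model under playground/Dynamic/.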
```python
# Excerpt from dl_lib/modeling/meta_arch/dynamic4seg.py. ImageList, BudgetConstraint
# and sem_seg_postprocess are defined in dl_lib (their imports are omitted here).
import torch
from torch import nn


class DynamicNet4Seg(nn.Module):
    """
    This module implements Dynamic Network for Semantic Segmentation.
    """
    def __init__(self, cfg):
        super().__init__()
        self.constrain_on = cfg.MODEL.BUDGET.CONSTRAIN
        self.unupdate_rate = cfg.MODEL.BUDGET.UNUPDATE_RATE
        self.device = torch.device(cfg.MODEL.DEVICE)
        self.backbone = cfg.build_backbone(cfg)
        self.sem_seg_head = cfg.build_sem_seg_head(
            cfg, self.backbone.output_shape())
        pixel_mean = torch.Tensor(cfg.MODEL.PIXEL_MEAN).to(self.device).view(
            -1, 1, 1)
        pixel_std = torch.Tensor(cfg.MODEL.PIXEL_STD).to(self.device).view(
            -1, 1, 1)
        self.normalizer = lambda x: (x - pixel_mean) / pixel_std
        self.budget_constrint = BudgetConstraint(cfg)
        self.to(self.device)

    def forward(self, batched_inputs, step_rate=0.0):
        """
        Args:
            batched_inputs: a list, batched outputs of :class:`DatasetMapper` .
                Each item in the list contains the inputs for one image.
            step_rate: a float, calculated by current_step/total_step,
                This parameter is used for Scheduled Drop Path.
        For now, each item in the list is a dict that contains:
            image: Tensor, image in (C, H, W) format.
            sem_seg: semantic segmentation ground truth
            Other information that's included in the original dicts, such as:
                "height", "width" (int): the output resolution of the model, used in inference.
                    See :meth:`postprocess` for details.
        Returns:
            list[dict]: Each dict is the output for one input image.
                The dict contains one key "sem_seg" whose value is a
                Tensor of the output resolution that represents the
                per-pixel segmentation prediction.
        """
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [self.normalizer(x) for x in images]
        images = ImageList.from_tensors(images,
                                        self.backbone.size_divisibility)
        features, expt_flops, real_flops = self.backbone(
            images.tensor, step_rate)
        if "sem_seg" in batched_inputs[0]:
            targets = [x["sem_seg"].to(self.device) for x in batched_inputs]
            targets = ImageList.from_tensors(
                targets, self.backbone.size_divisibility,
                self.sem_seg_head.ignore_value).tensor
        else:
            targets = None
        results, losses = self.sem_seg_head(features, targets)
        # calculate flops
        real_flops += self.sem_seg_head.flops
        flops = {'real_flops': real_flops, 'expt_flops': expt_flops}
        # use budget constraint for training
        if self.training:
            if self.constrain_on and step_rate >= self.unupdate_rate:
                warm_up_rate = min(
                    1.0, (step_rate - self.unupdate_rate) / 0.02
                )
                loss_budget = self.budget_constrint(
                    expt_flops, warm_up_rate=warm_up_rate
                )
                losses.update({'loss_budget': loss_budget})
            return losses, flops

        processed_results = []
        for result, input_per_image, image_size in zip(results, batched_inputs,
                                                        images.image_sizes):
            height = input_per_image.get("height")
            width = input_per_image.get("width")
            r = sem_seg_postprocess(result, image_size, height, width)
            processed_results.append({"sem_seg": r, "flops": flops})
        return processed_results
```
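The inference post-processing can be sketched as follows, assuming `sem_seg_postprocess` works like detectron2's function of the same name: crop the prediction to `image_size` to remove the padding, then resize it to the requested `(height, width)`. For brevity this NumPy version uses nearest-neighbour sampling instead of the bilinear interpolation used in the real code; `postprocess` is a hypothetical name.

```python
import numpy as np


def postprocess(scores, image_size, out_h, out_w):
    """Crop the padded prediction and resize it to the requested output size.

    scores:     (K, H_pad, W_pad) per-class scores for one image
    image_size: (h, w) of the image inside the padded batch
    Uses nearest-neighbour sampling instead of bilinear interpolation.
    """
    h, w = image_size
    cropped = scores[:, :h, :w]               # remove the size_divisibility padding
    rows = np.arange(out_h) * h // out_h      # source row for every output row
    cols = np.arange(out_w) * w // out_w      # source column for every output column
    return cropped[:, rows[:, None], cols[None, :]]


scores = np.arange(2 * 8 * 8).reshape(2, 8, 8)   # 2 classes on an 8x8 padded map
out = postprocess(scores, image_size=(6, 4), out_h=12, out_w=8)
print(out.shape)                 # (2, 12, 8)
label_map = out.argmax(axis=0)   # per-pixel class prediction
```

Cropping before resizing matters: resizing the padded map directly would stretch the padding into the final prediction.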
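sem_seg_head upsamples and fuses the extracted multi-scale features. Its computation is as follows. One decoder layer ($Conv1\times 1$ + batch norm + ReLU) is created for each input feature. Starting from the coarsest feature, each level is decoded by its decoder layer, upsampled with bilinear interpolation to the size of the next finer level, and fused with that level's feature by element-wise addition. Every decoder layer except the one for layer_0 halves the channel count, so the upsampled map has exactly as many channels as the next finer feature and the addition is well defined. Finally, a $Conv3\times 3$ predictor followed by ×4 bilinear upsampling back to the input resolution produces the result. During training, the mean per-pixel cross-entropy loss is computed and returned (with an empty result list); during inference, the prediction is returned.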
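The top-down fusion can be sketched in NumPy. This is a simplification, not the repository code: batch norm is omitted, the 1×1 convolutions use random weights, nearest-neighbour ×2 upsampling stands in for bilinear interpolation, and the channel widths (C, 2C, 4C, 8C with C = 8) are hypothetical. The final $Conv3\times 3$ predictor and ×4 upsampling would follow.

```python
import numpy as np


def upsample2x(x):
    """Nearest-neighbour x2 upsampling of an (N, C, H, W) array."""
    return x.repeat(2, axis=2).repeat(2, axis=3)


def make_decoder_weights(channels, rng):
    """One 1x1-conv weight per level; layer_0 keeps its width, deeper levels halve it."""
    weights = []
    for i, c in enumerate(channels):
        out_c = c if i == 0 else c // 2
        weights.append(rng.standard_normal((out_c, c)) * 0.1)
    return weights


def decode(features, weights):
    """Top-down fusion: features[0] is layer_0 (stride 4), features[-1] the coarsest."""
    pred = None
    for i in reversed(range(len(features))):
        x = features[i] if pred is None else features[i] + pred        # element-wise fusion
        x = np.maximum(np.einsum("oc,nchw->nohw", weights[i], x), 0.0)  # 1x1 conv + ReLU
        pred = upsample2x(x) if i > 0 else x                            # move to the next finer level
    return pred


rng = np.random.default_rng(0)
C, H = 8, 16
channels = [C, 2 * C, 4 * C, 8 * C]                                    # layer_0 .. layer_3
feats = [rng.standard_normal((1, c, H >> i, H >> i)) for i, c in enumerate(channels)]
out = decode(feats, make_decoder_weights(channels, rng))
print(out.shape)  # (1, 8, 16, 16)
```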
```python
# Excerpt from dl_lib/modeling/meta_arch/dynamic4seg.py. Conv2d, ShapeSpec, get_norm,
# weight_init and cal_op_flops are defined in dl_lib (their imports are omitted here).
from typing import Dict

import numpy as np
from torch import nn
from torch.nn import functional as F


class SemSegDecoderHead(nn.Module):
    """
    This module implements simple decoder head for Semantic Segmentation.
    It creates decoder on top of the dynamic backbone.
    """
    def __init__(self, cfg, input_shape: Dict[str, ShapeSpec]):
        super().__init__()
        # fmt: off
        self.in_features = cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES
        feature_strides = {k: v.stride for k, v in input_shape.items()}  # noqa:F841
        feature_channels = {k: v.channels for k, v in input_shape.items()}
        feature_resolution = {
            k: np.array([v.height, v.width])
            for k, v in input_shape.items()
        }
        self.ignore_value = cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE
        num_classes = cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES
        norm = cfg.MODEL.SEM_SEG_HEAD.NORM
        self.loss_weight = cfg.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT
        self.cal_flops = cfg.MODEL.CAL_FLOPS
        self.real_flops = 0.0
        # fmt: on
        self.layer_decoder_list = nn.ModuleList()
        # set affine in BatchNorm
        if 'Sync' in norm:
            affine = True
        else:
            affine = False
        # use simple decoder
        for _feat in self.in_features:
            res_size = feature_resolution[_feat]
            in_channel = feature_channels[_feat]
            if _feat == 'layer_0':
                out_channel = in_channel
            else:
                out_channel = in_channel // 2
            conv_1x1 = Conv2d(in_channel,
                              out_channel,
                              kernel_size=1,
                              stride=1,
                              padding=0,
                              bias=False,
                              norm=get_norm(norm, out_channel),
                              activation=nn.ReLU())
            self.real_flops += cal_op_flops.count_ConvBNReLU_flop(
                res_size[0],
                res_size[1],
                in_channel,
                out_channel, [1, 1],
                is_affine=affine)
            self.layer_decoder_list.append(conv_1x1)
        # using Kaiming init
        for layer in self.layer_decoder_list:
            weight_init.kaiming_init_module(layer, mode='fan_in')
        in_channel = feature_channels['layer_0']
        # the output layer
        self.predictor = Conv2d(in_channels=in_channel,
                                out_channels=num_classes,
                                kernel_size=3,
                                stride=1,
                                padding=1)
        self.real_flops += cal_op_flops.count_Conv_flop(
            feature_resolution['layer_0'][0], feature_resolution['layer_0'][1],
            in_channel, num_classes, [3, 3])
        # using Kaiming init
        weight_init.kaiming_init_module(self.predictor, mode='fan_in')

    def forward(self, features, targets=None):
        pred, pred_output = None, None
        # walk from the coarsest feature (last index) down to layer_0
        for _index in range(len(self.in_features)):
            out_index = len(self.in_features) - _index - 1
            out_feat = features[self.in_features[out_index]]
            if out_index <= 2:
                out_feat = pred + out_feat
            pred = self.layer_decoder_list[out_index](out_feat)
            if out_index > 0:
                pred = F.interpolate(input=pred,
                                     scale_factor=2,
                                     mode='bilinear',
                                     align_corners=False)
            else:
                pred_output = pred
        # pred output
        pred_output = self.predictor(pred_output)
        pred_output = F.interpolate(input=pred_output,
                                    scale_factor=4,
                                    mode='bilinear',
                                    align_corners=False)
        if self.training:
            losses = {}
            losses["loss_sem_seg"] = (
                F.cross_entropy(
                    pred_output, targets, reduction="mean",
                    ignore_index=self.ignore_value
                ) * self.loss_weight
            )
            return [], losses
        else:
            return pred_output, {}

    @property
    def flops(self):
        return self.real_flops
```
dynamic_backbone.py
dl_lib/modeling/dynamic_arch/dynamic_backbone.py defines the backbone's structure and forward pass. For the initialization parameters, see the config files of each model under playground/Dynamic/.
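The input first passes through the STEM module (three $Conv3\times 3$ layers), which downsamples it to 1/4 of the input resolution. An initial Cell (`init_layer`, which may keep or reduce the resolution but not increase it) then processes the STEM output and provides the 1/4 and 1/8 features consumed by the first layer of Cells. The input of the Cell at scale $s$ in layer $l$ is the sum of the features that layer $l-1$ routes to scale $s$ from the neighbouring scales $s/2$ and $2s$ and from scale $s$ itself (where those paths exist):

$$X_s^l=Y^{l-1}_{s/2}+Y^{l-1}_{s}+Y^{l-1}_{2s}$$

after which the output of each Cell is computed in turn.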
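The three terms can be summed element-wise because every transition lands on the target scale's shape: `res_up` halves the channels before a ×2 upsampling, and `res_down` doubles them with a stride-2 convolution. The sketch below tabulates the stride, width and feature size of each scale, mirroring `channel_multi`, `in_channel_cell` and `res_size` in `DynamicNetwork.__init__`. `scale_table` is a hypothetical helper, and the 512×1024 input and `init_channel=64` are illustrative values.

```python
def scale_table(input_hw, init_channel, num_scales=4, stem_stride=4):
    """Stride, channel count and feature size of every scale in the routing space."""
    h, w = input_hw[0] // stem_stride, input_hw[1] // stem_stride  # size after the STEM
    table = []
    for cell_index in range(num_scales):
        channel_multi = 2 ** cell_index   # stride 4 -> C, stride 8 -> 2C, ...
        table.append({
            "name": f"layer_{cell_index}",
            "stride": stem_stride * channel_multi,
            "channels": init_channel * channel_multi,
            "size": (h // channel_multi, w // channel_multi),
        })
    return table


for row in scale_table((512, 1024), init_channel=64):
    print(row)
```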
```python
# Excerpt from dl_lib/modeling/dynamic_arch/dynamic_backbone.py. Backbone, ShapeSpec,
# DynamicStem and Cell are defined in dl_lib (their imports are omitted here).
import numpy as np
import torch
from torch import nn


class DynamicNetwork(Backbone):
    """
    This module implements Dynamic Routing Network.
    It creates dense connected network on top of some input feature maps.
    """
    def __init__(
        self, init_channel, input_shape, cell_num_list, layer_num,
        ext_layer=None, norm="", cal_flops=True, cell_type='',
        max_stride=32, sep_stem=True, using_gate=False,
        small_gate=False, gate_bias=1.5, drop_prob=0.0,
    ):
        super(DynamicNetwork, self).__init__()
        # set affine in BatchNorm
        if 'Sync' in norm:
            self.affine = True
        else:
            self.affine = False
        # set scheduled drop path
        self.drop_prob = drop_prob
        if self.drop_prob > 0.0001:
            self.drop_path = True
        else:
            self.drop_path = False
        self.cal_flops = cal_flops
        self._size_divisibility = max_stride
        input_res = np.array(input_shape[1:3])
        self.stem = DynamicStem(
            3, out_channels=init_channel, input_res=input_res,
            sept_stem=sep_stem, norm=norm, affine=self.affine
        )
        self.stem_flops = self.stem.flops
        self._out_feature_strides = {"stem": self.stem.stride}
        self._out_feature_channels = {"stem": self.stem.out_channels}
        self._out_feature_resolution = {"stem": self.stem.out_resolution}
        assert self.stem.out_channels == init_channel
        self.all_cell_list = nn.ModuleList()
        self.all_cell_type_list = []
        self.cell_num_list = cell_num_list[:layer_num]
        self._out_features = []
        # using the initial layer
        input_res = input_res // self.stem.stride
        in_channel = out_channel = init_channel
        self.init_layer = Cell(
            C_in=in_channel, C_out=out_channel, norm=norm, allow_up=False,
            allow_down=True, input_size=input_res, cell_type=cell_type,
            cal_flops=False, using_gate=using_gate, small_gate=small_gate,
            gate_bias=gate_bias, affine=self.affine
        )
        # add cells in each layer
        for layer_index in range(len(self.cell_num_list)):
            layer_cell_list = nn.ModuleList()
            layer_cell_type = []
            for cell_index in range(self.cell_num_list[layer_index]):
                # channel multi, when stride:4 -> channel:C, stride:8 -> channel:2C ...
                channel_multi = pow(2, cell_index)
                in_channel_cell = in_channel * channel_multi
                # add res and dim switch to each cell
                allow_up = True
                allow_down = True
                # add res up and dim down by 2
                if cell_index == 0 or layer_index == layer_num - 1:
                    allow_up = False
                # res down and dim up by 2
                if cell_index == 3 or layer_index == layer_num - 1:
                    allow_down = False
                res_size = input_res // channel_multi
                layer_cell_list.append(
                    Cell(
                        C_in=in_channel_cell, C_out=in_channel_cell, norm=norm,
                        allow_up=allow_up, allow_down=allow_down,
                        input_size=res_size, cell_type=cell_type,
                        cal_flops=cal_flops, using_gate=using_gate,
                        small_gate=small_gate, gate_bias=gate_bias,
                        affine=self.affine
                    )
                )
                # allow dim change in each aggregation
                dim_up, dim_down, dim_keep = False, False, True
                # dim up and resolution down by 2
                if cell_index > 0:
                    dim_up = True
                # dim down and resolution up by 2
                if (cell_index < self.cell_num_list[layer_index] - 1) and layer_index > 2:
                    dim_down = True
                elif (cell_index < self.cell_num_list[layer_index] - 2) and layer_index <= 2:
                    dim_down = True
                # dim keep unchanged
                if layer_index <= 2 and cell_index == self.cell_num_list[layer_index] - 1:
                    dim_keep = False
                # allowed cell operations
                layer_cell_type.append([dim_up, dim_keep, dim_down])
                if layer_index == len(self.cell_num_list) - 1:
                    name = 'layer_' + str(cell_index)
                    self._out_feature_strides[name] = channel_multi * self.stem.stride
                    self._out_feature_channels[name] = in_channel_cell
                    self._out_feature_resolution[name] = res_size
                    self._out_features.append(name)
            self.all_cell_list.append(layer_cell_list)
            self.all_cell_type_list.append(layer_cell_type)

    @property
    def size_divisibility(self):
        return self._size_divisibility

    def forward(self, x, step_rate=0.0):
        h_l1 = self.stem(x)
        # the initial layer
        h_l1_list, h_beta_list, trans_flops, trans_flops_real = self.init_layer(h_l1=h_l1)
        prev_beta_list, prev_out_list = [h_beta_list], [h_l1_list]  # noqa: F841
        prev_trans_flops, prev_trans_flops_real = [trans_flops], [trans_flops_real]
        # build forward outputs
        cell_flops_list, cell_flops_real_list = [], []
        for layer_index in range(len(self.cell_num_list)):
            layer_input, layer_output = [], []
            layer_trans_flops, layer_trans_flops_real = [], []
            flops_in_expt_list, flops_in_real_list = [], []
            layer_rate = (layer_index + 1) / float(len(self.cell_num_list))
            # aggregate cell input: X_s^l = Y_{s/2} + Y_s + Y_{2s}
            for cell_index in range(len(self.all_cell_type_list[layer_index])):
                cell_input, trans_flops_input, trans_flops_real_input = [], [], []
                if self.all_cell_type_list[layer_index][cell_index][0]:
                    cell_input.append(prev_out_list[cell_index - 1][2][0])
                    trans_flops_input.append(prev_trans_flops[cell_index - 1][2][0])
                    trans_flops_real_input.append(prev_trans_flops_real[cell_index - 1][2][0])
                if self.all_cell_type_list[layer_index][cell_index][1]:
                    cell_input.append(prev_out_list[cell_index][1][0])
                    trans_flops_input.append(prev_trans_flops[cell_index][1][0])
                    trans_flops_real_input.append(prev_trans_flops_real[cell_index][1][0])
                if self.all_cell_type_list[layer_index][cell_index][2]:
                    cell_input.append(prev_out_list[cell_index + 1][0][0])
                    trans_flops_input.append(prev_trans_flops[cell_index + 1][0][0])
                    trans_flops_real_input.append(prev_trans_flops_real[cell_index + 1][0][0])
                h_l1 = sum(cell_input)
                # calculate input for gate
                layer_input.append(h_l1)
                # calculate FLOPs input
                flops_in_expt = sum(_flops for _flops in trans_flops_input)
                flop_in_real = sum(_flops for _flops in trans_flops_real_input)
                flops_in_expt_list.append(flops_in_expt)
                flops_in_real_list.append(flop_in_real)
            # calculate each cell
            for _cell_index in range(len(self.all_cell_type_list[layer_index])):
                if self.cal_flops:
                    cell_output, gate_weights_beta, cell_flops, \
                        cell_flops_real, trans_flops, trans_flops_real = \
                        self.all_cell_list[layer_index][_cell_index](
                            h_l1=layer_input[_cell_index],
                            flops_in_expt=flops_in_expt_list[_cell_index],
                            flops_in_real=flops_in_real_list[_cell_index],
                            is_drop_path=self.drop_path, drop_prob=self.drop_prob,
                            layer_rate=layer_rate, step_rate=step_rate
                        )
                    # calculate real flops
                    cell_flops_list.append(cell_flops)
                    cell_flops_real_list.append(cell_flops_real)
                else:
                    cell_output, gate_weights_beta, trans_flops, trans_flops_real = \
                        self.all_cell_list[layer_index][_cell_index](
                            h_l1=layer_input[_cell_index],
                            flops_in_expt=flops_in_expt_list[_cell_index],
                            flops_in_real=flops_in_real_list[_cell_index],
                            is_drop_path=self.drop_path, drop_prob=self.drop_prob,
                            layer_rate=layer_rate, step_rate=step_rate
                        )
                layer_output.append(cell_output)
                # update trans flops output
                layer_trans_flops.append(trans_flops)
                layer_trans_flops_real.append(trans_flops_real)
            # update layer output
            prev_out_list = layer_output
            prev_trans_flops = layer_trans_flops
            prev_trans_flops_real = layer_trans_flops_real
        final_out_list = [prev_out_list[_i][1][0] for _i in range(len(prev_out_list))]
        final_out_dict = dict(zip(self._out_features, final_out_list))
        if self.cal_flops:
            all_cell_flops = torch.mean(sum(cell_flops_list))
            all_flops_real = torch.mean(sum(cell_flops_real_list)) + self.stem_flops
        else:
            all_cell_flops, all_flops_real = None, None
        return final_out_dict, all_cell_flops, all_flops_real

    def output_shape(self):
        return {
            name: ShapeSpec(
                channels=self._out_feature_channels[name],
                height=self._out_feature_resolution[name][0],
                width=self._out_feature_resolution[name][1],  # index 1 is the width
                stride=self._out_feature_strides[name]
            )
            for name in self._out_features
        }
```
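The `[dim_up, dim_keep, dim_down]` rules in `__init__` decide which terms of $X_s^l$ actually exist. The plain-Python sketch below reproduces those rules and lists, for each cell, which previous-layer outputs are summed (the index into `prev_out_list` and the output slot: `[2]` down, `[1]` keep, `[0]` up). `routing_table` and `input_sources` are hypothetical helpers, and `cell_num_list=[2, 3, 4, 4]` is an illustrative value; check the actual configs under playground/Dynamic/.

```python
def routing_table(cell_num_list):
    """[dim_up, dim_keep, dim_down] flags for every cell, as in DynamicNetwork.__init__."""
    table = []
    for layer_index, num_cells in enumerate(cell_num_list):
        layer = []
        for cell_index in range(num_cells):
            dim_up = cell_index > 0                     # from the finer neighbour's down path
            if layer_index > 2:
                dim_down = cell_index < num_cells - 1   # from the coarser neighbour's up path
            else:
                dim_down = cell_index < num_cells - 2
            # in the first three layers the newest (coarsest) cell has no same-scale predecessor
            dim_keep = not (layer_index <= 2 and cell_index == num_cells - 1)
            layer.append([dim_up, dim_keep, dim_down])
        table.append(layer)
    return table


def input_sources(flags, cell_index):
    """Previous-layer outputs summed into X_s^l, as (prev cell index, output path)."""
    dim_up, dim_keep, dim_down = flags
    sources = []
    if dim_up:
        sources.append((cell_index - 1, "down"))   # prev_out_list[cell_index - 1][2]
    if dim_keep:
        sources.append((cell_index, "keep"))       # prev_out_list[cell_index][1]
    if dim_down:
        sources.append((cell_index + 1, "up"))     # prev_out_list[cell_index + 1][0]
    return sources


table = routing_table([2, 3, 4, 4])
for layer_index, layer in enumerate(table):
    print(layer_index, [input_sources(f, c) for c, f in enumerate(layer)])
```

The first three layers grow the routing space one scale at a time, which is why their last cell only receives the downsampled output of its finer neighbour.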
dynamic_cell.py
dl_lib/modeling/dynamic_arch/dynamic_cell.py defines the computation of a Cell in the model.
The Cell first computes

$$G^l_s = \mathcal{F}(w_{s,2}^l,\mathcal{G}(\sigma(\mathcal{N}(\mathcal{F}(w_{s,1}^l, X^l_s))))) + \beta_s^l$$

and from it the soft gate $\alpha^l_s$. In `gate_conv_beta`, $\mathcal{F}$ is a $1\times 1$ convolution, $\mathcal{N}$ batch normalization, $\sigma$ ReLU, $\mathcal{G}$ global average pooling (`AdaptiveAvgPool2d`), and $\beta_s^l$ the bias of the last convolution (initialized to `gate_bias`).
If $\alpha^l_s$ is small enough (the gates sum to less than 1e-4 in the code), the cell's outputs are set to zero and only the keep path forwards the input unchanged. Otherwise, the Cell Operation ($SepConv3\times 3$, a depthwise-separable convolution, i.e. $Conv3\times 3+Conv1\times 1+Conv3\times 3+Conv1\times 1$) extracts the feature $H_s^l$ from the input. $H_s^l$ is then upsampled, kept at the same scale, and downsampled, and each output of the cell is the transformed feature scaled by its gate:

$$Y_{s\to j}^l=\alpha_{s\to j}^l\,\mathcal{T}_{s\to j}(H_s^l)$$
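A NumPy sketch of the gate computation, assuming `soft_gate` applies the activation described in the paper, $\alpha=\max(0,\tanh(G))$, which keeps every gate in $[0, 1)$ and makes exactly-zero (skippable) paths possible. To keep it short, batch norm uses batch statistics without an affine transform, the output has shape (N, K) instead of (N, K, 1, 1), and `soft_conditional_gate` plus all weights are hypothetical.

```python
import numpy as np


def soft_conditional_gate(x, w1, w2, beta, eps=1e-5):
    """alpha = max(0, tanh(G)) with G = F(w2, GAP(ReLU(BN(F(w1, X))))) + beta.

    x:    (N, C, H, W) cell input X_s^l
    w1:   (C // 2, C) weights of the first 1x1 convolution
    w2:   (K, C // 2) weights of the second 1x1 convolution, K = number of paths
    beta: (K,) bias of the second convolution
    Returns alpha with shape (N, K).
    """
    h = np.einsum("oc,nchw->nohw", w1, x)        # F(w1, X): 1x1 convolution
    mean = h.mean(axis=(0, 2, 3), keepdims=True)
    var = h.var(axis=(0, 2, 3), keepdims=True)
    h = (h - mean) / np.sqrt(var + eps)          # N: batch norm (batch stats, no affine)
    h = np.maximum(h, 0.0)                       # sigma: ReLU
    h = h.mean(axis=(2, 3))                      # G: global average pooling -> (N, C // 2)
    g = h @ w2.T + beta                          # F(w2, .) + beta on the pooled vector
    return np.maximum(np.tanh(g), 0.0)


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 16, 8, 8))
w1 = rng.standard_normal((8, 16)) * 0.1
w2 = np.zeros((3, 8))                 # zero weights: alpha depends only on the bias
beta = np.array([1.5, -1.0, 0.0])     # 1.5 mirrors the default gate_bias
alpha = soft_conditional_gate(x, w1, w2, beta)
print(alpha)  # each row is [tanh(1.5), 0, 0]
```

With a positive initial bias such as `gate_bias=1.5`, a path starts mostly open (about 0.905), while a negative pre-activation is clipped to exactly 0.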
```python
# Excerpt from dl_lib/modeling/dynamic_arch/dynamic_cell.py. Mixed_OP, Conv2d, get_norm,
# soft_gate, weight_init and cal_op_flops are defined in dl_lib (imports omitted here).
import torch
from torch import nn
from torch.nn import functional as F


class Cell(nn.Module):
    def __init__(
        self, C_in, C_out, norm, allow_up, allow_down, input_size,
        cell_type, cal_flops=True, using_gate=False,
        small_gate=False, gate_bias=1.5, affine=True
    ):
        super(Cell, self).__init__()
        self.channel_in = C_in
        self.channel_out = C_out
        self.allow_up = allow_up
        self.allow_down = allow_down
        self.cal_flops = cal_flops
        self.using_gate = using_gate
        self.small_gate = small_gate
        self.cell_ops = Mixed_OP(
            inplanes=self.channel_in, outplanes=self.channel_out,
            stride=1, cell_type=cell_type, norm=norm,
            affine=affine, input_size=input_size
        )
        self.cell_flops = self.cell_ops.flops
        # resolution keep
        self.res_keep = nn.ReLU()
        self.res_keep_flops = cal_op_flops.count_ReLU_flop(
            input_size[0], input_size[1], self.channel_out
        )
        # resolution up and dim down
        if self.allow_up:
            self.res_up = nn.Sequential(
                nn.ReLU(),
                Conv2d(
                    self.channel_out, self.channel_out // 2, kernel_size=1,
                    stride=1, padding=0, bias=False,
                    norm=get_norm(norm, self.channel_out // 2),
                    activation=nn.ReLU()
                )
            )
            # calculate Flops
            self.res_up_flops = cal_op_flops.count_ReLU_flop(
                input_size[0], input_size[1], self.channel_out
            ) + cal_op_flops.count_ConvBNReLU_flop(
                input_size[0], input_size[1], self.channel_out,
                self.channel_out // 2, [1, 1], is_affine=affine
            )
            # using Kaiming init
            weight_init.kaiming_init_module(self.res_up, mode='fan_in')
        # resolution down and dim up
        if self.allow_down:
            self.res_down = nn.Sequential(
                nn.ReLU(),
                Conv2d(
                    self.channel_out, 2 * self.channel_out,
                    kernel_size=1, stride=2, padding=0, bias=False,
                    norm=get_norm(norm, 2 * self.channel_out),
                    activation=nn.ReLU()
                )
            )
            # calculate Flops
            self.res_down_flops = cal_op_flops.count_ReLU_flop(
                input_size[0], input_size[1], self.channel_out
            ) + cal_op_flops.count_ConvBNReLU_flop(
                input_size[0], input_size[1], self.channel_out,
                2 * self.channel_out, [1, 1], stride=2, is_affine=affine
            )
            # using Kaiming init
            weight_init.kaiming_init_module(self.res_down, mode='fan_in')
        if self.allow_up and self.allow_down:
            self.gate_num = 3
        elif self.allow_up or self.allow_down:
            self.gate_num = 2
        else:
            self.gate_num = 1
        if self.using_gate:
            self.gate_conv_beta = nn.Sequential(
                Conv2d(
                    self.channel_in, self.channel_in // 2, kernel_size=1,
                    stride=1, padding=0, bias=False,
                    norm=get_norm(norm, self.channel_in // 2),
                    activation=nn.ReLU()
                ),
                nn.AdaptiveAvgPool2d((1, 1)),
                Conv2d(
                    self.channel_in // 2, self.gate_num, kernel_size=1,
                    stride=1, padding=0, bias=True
                )
            )
            if self.small_gate:
                input_size = input_size // 4
            self.gate_flops = cal_op_flops.count_ConvBNReLU_flop(
                input_size[0], input_size[1], self.channel_in,
                self.channel_in // 2, [1, 1], is_affine=affine
            ) + cal_op_flops.count_Pool2d_flop(
                input_size[0], input_size[1], self.channel_in // 2, [1, 1], 1
            ) + cal_op_flops.count_Conv_flop(
                1, 1, self.channel_in // 2, self.gate_num, [1, 1]
            )
            # using Kaiming init and predefined bias for gate
            weight_init.kaiming_init_module(
                self.gate_conv_beta, mode='fan_in', bias=gate_bias
            )
        else:
            self.register_buffer(
                'gate_weights_beta', torch.ones(1, self.gate_num, 1, 1).cuda()
            )
            self.gate_flops = 0.0

    def forward(
        self, h_l1, flops_in_expt=None, flops_in_real=None,
        is_drop_path=False, drop_prob=0.0,
        layer_rate=0.0, step_rate=0.0
    ):
        """
        :param h_l1: # the former hidden layer output
        :return: current hidden cell result h_l
        """
        drop_cell = False
        # drop the cell if input type is float
        if not isinstance(h_l1, float):
            # calculate soft conditional gate
            if self.using_gate:
                if self.small_gate:
                    h_l1_gate = F.interpolate(
                        input=h_l1, scale_factor=0.25,
                        mode='bilinear', align_corners=False
                    )
                else:
                    h_l1_gate = h_l1
                gate_feat_beta = self.gate_conv_beta(h_l1_gate)
                gate_weights_beta = soft_gate(gate_feat_beta)
            else:
                gate_weights_beta = self.gate_weights_beta
        else:
            drop_cell = True
        # use for inference
        if not self.training:
            if not drop_cell:
                drop_cell = gate_weights_beta.sum() < 0.0001
            if drop_cell:
                result_list = [[0.0], [h_l1], [0.0]]
                weights_list_beta = [[0.0], [0.0], [0.0]]
                trans_flops_expt = [[0.0], [0.0], [0.0]]
                trans_flops_real = [[0.0], [0.0], [0.0]]
                if self.cal_flops:
                    h_l_flops = flops_in_expt
                    h_l_flops_real = flops_in_real + self.gate_flops
                    return (
                        result_list, weights_list_beta, h_l_flops,
                        h_l_flops_real, trans_flops_expt, trans_flops_real
                    )
                else:
                    return (
                        result_list, weights_list_beta,
                        trans_flops_expt, trans_flops_real
                    )
        h_l = self.cell_ops(h_l1, is_drop_path, drop_prob, layer_rate, step_rate)
        # resolution and dimension change
        # resolution: [up, keep, down]
        h_l_keep = self.res_keep(h_l)
        gate_weights_beta_keep = gate_weights_beta[:, 0].unsqueeze(-1)
        # using residual connection if drop cell
        gate_mask = (gate_weights_beta.sum(dim=1, keepdim=True) < 0.0001).float()
        result_list = [[], [gate_mask * h_l1 + gate_weights_beta_keep * h_l_keep], []]
        weights_list_beta = [[], [gate_mask * 1.0 + gate_weights_beta_keep], []]
        # calculate flops for keep res
        gate_mask_keep = (gate_weights_beta_keep > 0.0001).float()
        trans_flops_real = [[], [gate_mask_keep * self.res_keep_flops], []]
        # calculate trans flops
        trans_flops_expt = [[], [self.res_keep_flops * gate_weights_beta_keep], []]
        if self.allow_up:
            h_l_up = self.res_up(h_l)
            h_l_up = F.interpolate(
                input=h_l_up, scale_factor=2, mode='bilinear', align_corners=False
            )
            gate_weights_beta_up = gate_weights_beta[:, 1].unsqueeze(-1)
            result_list[0].append(h_l_up * gate_weights_beta_up)
            weights_list_beta[0].append(gate_weights_beta_up)
            trans_flops_expt[0].append(self.res_up_flops * gate_weights_beta_up)
            # calculate flops for up res
            gate_mask_up = (gate_weights_beta_up > 0.0001).float()
            trans_flops_real[0].append(gate_mask_up * self.res_up_flops)
        if self.allow_down:
            h_l_down = self.res_down(h_l)
            gate_weights_beta_down = gate_weights_beta[:, -1].unsqueeze(-1)
            result_list[2].append(h_l_down * gate_weights_beta_down)
            weights_list_beta[2].append(gate_weights_beta_down)
            trans_flops_expt[2].append(self.res_down_flops * gate_weights_beta_down)
            # calculate flops for down res
            gate_mask_down = (gate_weights_beta_down > 0.0001).float()
            trans_flops_real[2].append(gate_mask_down * self.res_down_flops)
        if self.cal_flops:
            cell_flops = gate_weights_beta.max(dim=1, keepdim=True)[0] * self.cell_flops
            cell_flops_real = (
                gate_weights_beta.sum(dim=1, keepdim=True) > 0.0001
            ).float() * self.cell_flops
            h_l_flops = cell_flops + flops_in_expt
            h_l_flops_real = cell_flops_real + flops_in_real + self.gate_flops
            return (
                result_list, weights_list_beta, h_l_flops,
                h_l_flops_real, trans_flops_expt, trans_flops_real
            )
        else:
            return result_list, weights_list_beta, trans_flops_expt, trans_flops_real
```
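The FLOPs bookkeeping at the end of `Cell.forward` can be summarized for a single sample in plain Python. The expected count is differentiable in the gates and feeds the budget loss; the real count only includes paths whose gate exceeds 1e-4. `flops_in_expt` / `flops_in_real` are the transition FLOPs flowing in from the previous layer (summed in `DynamicNetwork.forward`), and the gate's own FLOPs appear only in the real count. `cell_flops` here is a hypothetical scalar helper and the numbers are illustrative; the real code works on per-sample tensors.

```python
def cell_flops(gates, cell_op_flops, trans_flops, gate_flops,
               flops_in_expt, flops_in_real, eps=1e-4):
    """Expected vs. real FLOPs of one cell, mirroring the bookkeeping in Cell.forward.

    gates:       gate value of each output path (e.g. keep / up / down)
    trans_flops: FLOPs of the transition on each path, same order as gates
    """
    # expected: weighted by the gates, so gradients can push FLOPs down
    expt = max(gates) * cell_op_flops + flops_in_expt
    trans_expt = [g * f for g, f in zip(gates, trans_flops)]
    # real: count only what is actually executed (gate above eps)
    ran = sum(gates) > eps
    real = (cell_op_flops if ran else 0.0) + flops_in_real + gate_flops
    trans_real = [f if g > eps else 0.0 for g, f in zip(gates, trans_flops)]
    return expt, real, trans_expt, trans_real


print(cell_flops([0.5, 0.0, 0.8], 100.0, [1.0, 10.0, 20.0], 3.0, 5.0, 6.0))
print(cell_flops([0.0, 0.0, 0.0], 100.0, [1.0, 10.0, 20.0], 3.0, 5.0, 6.0))  # dropped cell
```

For a dropped cell only the incoming transitions and the gate itself are paid for, which matches the early-return branch used during inference.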