一、要解决的问题
- 无绿幕人像抠图
- I = α ∗ F + ( 1 − α ) ∗ B I=\alpha *F+(1-\alpha)*B I=α∗F+(1−α)∗B
二、创新点
- 无绿幕、无trimap人像端到端抠图
- SOC模型泛化迁移,OFD视频抠图增强
- Validation Benchmark
三、具体细节
MODNet网络结构如上图所示。主要包括三个自网络:Semantic Branch;Detail Branch;Fusion Branch。
Sematic Branch
Encoder-Decoder结构,采用Mobilenet-v2作为Encoder,并使用channel-wise Attention给Hidden Features添加权重,re-weight后的特征通过上采样-卷积-BN-ReLU套装恢复分辨率到原分辨率的1/8。Sigmoid激活后输出,作为Semantics S p S_p Sp
Detail Branch
Encoder-Decoder结构,输入包括Image以及Semantic Branch不同层的hidden features,通过上采样-卷积-BN-ReLU套装输出原分辨率的detail_alpha图
Fusion Branch
Encoder-Decoder结构,输入包括Semantic Branch以及Detail Branch的hidden features。通过上采样-卷积-BN-ReLU套装恢复到原分辨率,Sigmoid激活后输出,作为最终的 α \alpha α。
四、代码分析
网络结构较为简单,不分析此部分代码。
看一下各部分的损失函数。
Semantic Branch的损失函数,
G
(
α
g
)
G(\alpha_g)
G(αg)表示对gound truth alpha下采样。使用L2 Loss。
Detail Branch的损失函数,使用L1 Loss。
m
d
m_d
md表示边缘区域。
Fusion Branch的损失函数,除了L1损失,还引入合成损失。
整个网络的损失函数:
# forward the model
pred_semantic, pred_detail, pred_matte = modnet(image, False)
# calculate the boundary mask from the trimap
boundaries = (trimap < 0.5) + (trimap > 0.5)
# calculate the semantic loss
gt_semantic = F.interpolate(gt_matte, scale_factor=1/16, mode='bilinear')
gt_semantic = blurer(gt_semantic)
semantic_loss = torch.mean(F.mse_loss(pred_semantic, gt_semantic))
semantic_loss = semantic_scale * semantic_loss
# calculate the detail loss
pred_boundary_detail = torch.where(boundaries, trimap, pred_detail)
gt_detail = torch.where(boundaries, trimap, gt_matte)
detail_loss = torch.mean(F.l1_loss(pred_boundary_detail, gt_detail))
detail_loss = detail_scale * detail_loss
# calculate the matte loss
pred_boundary_matte = torch.where(boundaries, trimap, pred_matte)
matte_l1_loss = F.l1_loss(pred_matte, gt_matte) + 4.0 * F.l1_loss(pred_boundary_matte, gt_matte)
matte_compositional_loss = F.l1_loss(image * pred_matte, image * gt_matte) \
+ 4.0 * F.l1_loss(image * pred_boundary_matte, image * gt_matte)
matte_loss = torch.mean(matte_l1_loss + matte_compositional_loss)
matte_loss = matte_scale * matte_loss
# calculate the final loss, backward the loss, and update the model
loss = semantic_loss + detail_loss + matte_loss
loss.backward()
optimizer.step()
五、总结
MODnet结构清晰,优秀的训练数据是关键,可惜不开源。