High-Resolution Image Harmonization via Collaborative Dual Transformations (CVPR 2022)
Target
Image harmonization aims to adjust the foreground to make it compatible with the background.
Problem:
Conventional image harmonization methods learn a global RGB-to-RGB transformation, which scales effortlessly to high resolution but ignores diverse local context. Recent deep learning methods learn a dense pixel-to-pixel transformation, which can produce harmonious outputs but is highly constrained to low resolution.
This paper:
A high-resolution image harmonization network with Collaborative Dual Transformations (CDTNet), which combines pixel-to-pixel transformation and RGB-to-RGB transformation coherently in an end-to-end network.
Contribution:
- We unify pixel-to-pixel transformation and color-to-color transformation coherently in an end-to-end network named CDTNet.
- Extensive experiments demonstrate that our CDTNet achieves state-of-the-art results with less resource consumption.
Method
Input: a high-resolution composite image and its foreground mask; output: the harmonized high-resolution image.
First, the image and mask are downsampled, and a learned pixel-to-pixel mapping produces a low-resolution harmonized image.
Then, guided by the mask and intermediate features, weighted 3D LUTs enhance the high-resolution image; finally, the four inputs (low-resolution output, LUT output, mask, and decoder features) are fused and passed through a refinement module to produce the final result, as shown in the BaseGenerator code below.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Generator3DLUT_identity, Generator3DLUT_zero, ConvBlock and
# TrilinearInterpolation are provided by the official CDTNet / 3D-LUT code.

class BaseGenerator(nn.Module):
    def __init__(self, input_nc, output_nc, ngf, norm, opt):
        super(BaseGenerator, self).__init__()
        self.device = opt.device
        # Four learnable 3D LUTs: one initialized to identity, three to zero.
        self.LUT0 = Generator3DLUT_identity(dim=64)
        self.LUT1 = Generator3DLUT_zero(dim=64)
        self.LUT2 = Generator3DLUT_zero(dim=64)
        self.LUT3 = Generator3DLUT_zero(dim=64)
        # Project foreground/background encoder maps to LUT combination weights.
        self.linear_f = nn.Linear(256 * 8 * 8, 256, bias=True)
        self.linear_b = nn.Linear(256 * 8 * 8, 256, bias=True)
        self.linear_coef = nn.Linear(512, 4, bias=True)
        # Refinement module: fuses low-res result, LUT result, mask and features.
        self.refine0 = ConvBlock(39, 10, kernel_size=3, stride=1, padding=1)
        self.refine1 = ConvBlock(10, 5, kernel_size=3, stride=1, padding=1)
        self.conv_attention = nn.Conv2d(5, 1, kernel_size=1)
        self.to_rgb = nn.Conv2d(5, 3, kernel_size=1)
        self.trilinear_ = TrilinearInterpolation()

    def forward(self, out_lr_pix, F_map, B_map, F_dec, comp_hr, mask_hr, train_data):
        # Predict per-image LUT combination coefficients from fg/bg features.
        f_f = self.linear_f(F_map.view(F_map.shape[0], -1))
        f_b = self.linear_b(B_map.view(B_map.shape[0], -1))
        coef = self.linear_coef(torch.cat((f_f, f_b), dim=1))  # (B, 4)
        # Apply the four LUTs to the masked foreground; layout is (3, B, H, W).
        new_img = (comp_hr * mask_hr).permute(1, 0, 2, 3).contiguous()
        gen_A0 = self.LUT0(new_img)
        gen_A1 = self.LUT1(new_img)
        gen_A2 = self.LUT2(new_img)
        gen_A3 = self.LUT3(new_img)
        # Per-image weighted combination of the four LUT outputs.
        combine_A = torch.empty_like(new_img)
        for b in range(new_img.size(1)):
            combine_A[:, b, :, :] = (coef[b, 0] * gen_A0[:, b, :, :]
                                     + coef[b, 1] * gen_A1[:, b, :, :]
                                     + coef[b, 2] * gen_A2[:, b, :, :]
                                     + coef[b, 3] * gen_A3[:, b, :, :])
        if not train_data:
            # At test time, look up the combined LUT with trilinear interpolation.
            _, combine_A = self.trilinear_(combine_A, new_img)
        # Paste the harmonized foreground back onto the untouched background.
        out_hr_rgb = combine_A.permute(1, 0, 2, 3) + comp_hr * (1 - mask_hr)
        # Upsample the low-res result and decoder features to high resolution
        # (a single int size assumes square inputs).
        out_lr_pix = F.interpolate(out_lr_pix, size=comp_hr.size(2))
        F_dec = F.interpolate(F_dec, size=comp_hr.size(2))
        # Refinement: concatenate 3 + 3 + 1 + 32 = 39 channels, two conv blocks.
        out_hr = torch.cat((out_lr_pix, out_hr_rgb, mask_hr, F_dec), dim=1)
        out_hr = self.refine0(out_hr)
        out_hr = self.refine1(out_hr)
        # Attention-weighted blend of the input composite and the refined RGB.
        attention_map = torch.sigmoid(3.0 * self.conv_attention(out_hr))
        out_hr = attention_map * comp_hr + (1.0 - attention_map) * self.to_rgb(out_hr)
        return out_hr_rgb, out_hr
```
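A minimal smoke test of the forward pass, with hypothetical tensor shapes inferred from the layer definitions above (256×8×8 encoder maps, 32-channel decoder features, 256×256 high-resolution input); it still needs the repo's Generator3DLUT_* / ConvBlock / TrilinearInterpolation implementations to run:

```python
import types
import torch

opt = types.SimpleNamespace(device='cpu')      # stub for the real options object
net = BaseGenerator(input_nc=4, output_nc=3, ngf=32, norm=None, opt=opt)

out_lr_pix = torch.rand(2, 3, 64, 64)          # low-res pixel-to-pixel output
F_map = torch.rand(2, 256, 8, 8)               # foreground encoder features
B_map = torch.rand(2, 256, 8, 8)               # background encoder features
F_dec = torch.rand(2, 32, 64, 64)              # decoder features (32 channels)
comp_hr = torch.rand(2, 3, 256, 256)           # high-res composite
mask_hr = (torch.rand(2, 1, 256, 256) > 0.5).float()

out_hr_rgb, out_hr = net(out_lr_pix, F_map, B_map, F_dec, comp_hr, mask_hr,
                         train_data=True)
print(out_hr_rgb.shape, out_hr.shape)          # both torch.Size([2, 3, 256, 256])
```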
Pixel-to-pixel
A U-Net encodes and decodes the features, extracting intermediate features and producing the low-resolution harmonized output; a toy sketch follows below.
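The backbone itself is not in the snippet above. As a toy stand-in for the encoder-decoder idea (not the paper's actual architecture), the sketch returns both the low-resolution prediction and the bottleneck features that the later stages consume:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyUNet(nn.Module):
    """Toy U-Net-style encoder-decoder, a stand-in for the real backbone."""
    def __init__(self, in_ch=4, base=32):
        super().__init__()
        self.enc1 = nn.Sequential(nn.Conv2d(in_ch, base, 3, padding=1), nn.ReLU(inplace=True))
        self.enc2 = nn.Sequential(nn.Conv2d(base, base * 2, 3, stride=2, padding=1), nn.ReLU(inplace=True))
        self.dec1 = nn.Sequential(nn.Conv2d(base * 2, base, 3, padding=1), nn.ReLU(inplace=True))
        self.out = nn.Conv2d(base * 2, 3, kernel_size=1)

    def forward(self, comp_lr, mask_lr):
        x = torch.cat([comp_lr, mask_lr], dim=1)        # image + mask -> 4 channels
        e1 = self.enc1(x)                               # full-resolution features
        e2 = self.enc2(e1)                              # downsampled bottleneck features
        d1 = F.interpolate(self.dec1(e2), scale_factor=2, mode='bilinear', align_corners=False)
        out_lr = self.out(torch.cat([d1, e1], dim=1))   # skip connection, low-res image
        return out_lr, e2

out_lr, feats = TinyUNet()(torch.rand(1, 3, 64, 64), torch.rand(1, 1, 64, 64))
```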
RGB-to-RGB
3D LUTs are used to produce the RGB-to-RGB enhanced result.
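For intuition, here is a pure-PyTorch sketch of trilinear 3D LUT lookup via F.grid_sample; the official code uses a custom CUDA TrilinearInterpolation kernel instead, and the LUT axis ordering below is an assumption:

```python
import torch
import torch.nn.functional as F

def apply_3d_lut(lut, img):
    """Trilinear 3D LUT lookup via grid_sample.

    lut: (3, D, D, D) tensor; lut[:, b, g, r] holds the output RGB for
         normalized input color (r, g, b) (axis order is an assumption).
    img: (N, 3, H, W) tensor with values in [0, 1].
    """
    n, _, h, w = img.shape
    # Map colors to grid coordinates in [-1, 1]; grid's last dim is (x, y, z),
    # indexing the LUT's (r, g, b) axes respectively.
    grid = img.permute(0, 2, 3, 1).reshape(n, 1, h, w, 3) * 2.0 - 1.0
    lut_batch = lut.unsqueeze(0).expand(n, -1, -1, -1, -1)   # (N, 3, D, D, D)
    out = F.grid_sample(lut_batch, grid, mode='bilinear',
                        padding_mode='border', align_corners=True)
    return out.squeeze(2)                                    # (N, 3, H, W)

# Identity LUT: output color equals input color.
D = 64
r = torch.linspace(0, 1, D)
b_, g_, r_ = torch.meshgrid(r, r, r, indexing='ij')          # axes ordered (b, g, r)
identity_lut = torch.stack([r_, g_, b_], dim=0)              # channel order R, G, B
img = torch.rand(1, 3, 8, 8)
assert torch.allclose(apply_3d_lut(identity_lut, img), img, atol=1e-5)
```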
Refinement module
The refinement module fuses the low-resolution result, the RGB-enhanced image, the high-resolution input, and the mask to enhance the final output.
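Read in isolation, the refinement step is just a learned attention blend; a standalone restatement of the tail of BaseGenerator.forward with hypothetical shapes, approximating ConvBlock with conv + ReLU:

```python
import torch
import torch.nn as nn

# Hypothetical high-resolution shapes.
out_lr_up = torch.rand(1, 3, 256, 256)   # upsampled pixel-to-pixel result
out_rgb   = torch.rand(1, 3, 256, 256)   # RGB-to-RGB (3D LUT) result
comp_hr   = torch.rand(1, 3, 256, 256)   # high-resolution composite input
mask_hr   = torch.rand(1, 1, 256, 256)   # foreground mask
F_dec_up  = torch.rand(1, 32, 256, 256)  # upsampled decoder features

refine = nn.Sequential(
    nn.Conv2d(39, 10, 3, padding=1), nn.ReLU(inplace=True),
    nn.Conv2d(10, 5, 3, padding=1), nn.ReLU(inplace=True),
)
conv_attention = nn.Conv2d(5, 1, kernel_size=1)
to_rgb = nn.Conv2d(5, 3, kernel_size=1)

feat = refine(torch.cat([out_lr_up, out_rgb, mask_hr, F_dec_up], dim=1))  # 3+3+1+32 = 39
# Attention-weighted blend: keep the composite where the map is high,
# take the refined prediction where it is low.
a = torch.sigmoid(3.0 * conv_attention(feat))
out_hr = a * comp_hr + (1.0 - a) * to_rgb(feat)
```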