利用感知损失(perceptual loss)获得更好的图像重建效果
传统的 MSE loss 在图像重建任务中容易丢失图像的高频信息,导致生成的图片出现模糊。感知损失通过比较卷积网络(此处为预训练的 VGG19)提取的高层特征,很好地缓解了上述问题。在此提供一段独立的 perceptual loss 代码,方便初学者在训练过程中使用:
def build_net(ntype,nin,nwb=None,name=None):
    """Build a single layer of the VGG graph on top of the tensor `nin`.

    Args:
        ntype: 'conv' for a stride-1, SAME-padded convolution with bias add
            and ReLU; 'pool' for 2x2 average pooling with stride 2 and SAME
            padding.
        nin: Input tensor of the layer.
        nwb: (weights, bias) pair used by the 'conv' layer; ignored by 'pool'.
        name: Name passed to the convolution op; ignored by 'pool'.

    Returns:
        The output tensor of the requested layer.

    Raises:
        ValueError: If `ntype` is neither 'conv' nor 'pool'.
    """
    if ntype == 'conv':
        conv = tf.nn.conv2d(nin, nwb[0], strides=[1, 1, 1, 1], padding='SAME', name=name)
        return tf.nn.relu(conv + nwb[1])
    if ntype == 'pool':
        return tf.nn.avg_pool(nin, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
    # Fail loudly here: silently returning None would only surface later as a
    # confusing error when the next layer tries to consume the result.
    raise ValueError("unknown layer type %r: expected 'conv' or 'pool'" % (ntype,))
def get_weight_bias(vgg_layers,i):
    """Return the convolution kernel and bias of layer `i` as TF constants.

    Args:
        vgg_layers: Layer array taken from the loaded VGG .mat data
            (presumably the MatConvNet layout; verify against the model file).
        i: Index of the convolution layer inside `vgg_layers`.

    Returns:
        A `(weights, bias)` tuple of constant tensors; the bias is flattened
        to one dimension before being wrapped.
    """
    # Kernel and bias sit next to each other under the same nested entry.
    params = vgg_layers[i][0][0][2][0]
    kernel = tf.constant(params[0])
    raw_bias = params[1]
    flat_bias = tf.constant(np.reshape(raw_bias, raw_bias.size))
    return kernel, flat_bias
# Load the pre-trained VGG19 parameters once, when the module is imported.
# Despite its name, `vgg_path` holds the data returned by scipy.io.loadmat
# (indexed by 'layers' in build_vgg19), not a filesystem path. Replace the
# placeholder string below with the real location of the VGG19 weights file.
vgg_path=scipy.io.loadmat('your vgg19path')
print("[i] Loaded pre-trained vgg19 parameters")
# build VGG19 to load pre-trained parameters
def build_vgg19(input,reuse=False):
    """Assemble the VGG19 feature extractor, up to conv5_2, on top of `input`.

    Args:
        input: Image tensor fed to the first convolution.
            NOTE(review): it is used as-is; no mean subtraction or rescaling
            happens here, so confirm callers supply the value range the
            pre-trained weights expect.
        reuse: When true, `reuse_variables()` is called on the "vgg19"
            variable scope. The weights are built with `tf.constant`, so this
            presumably has no effect on them; verify before relying on it.

    Returns:
        A dict mapping 'input', every 'convX_Y' and every 'poolX' stage to
        its output tensor.
    """
    # (key in the returned dict, index into the .mat layer list).
    # A None index marks a pooling stage, which carries no parameters.
    plan = (
        ('conv1_1', 0), ('conv1_2', 2), ('pool1', None),
        ('conv2_1', 5), ('conv2_2', 7), ('pool2', None),
        ('conv3_1', 10), ('conv3_2', 12), ('conv3_3', 14), ('conv3_4', 16),
        ('pool3', None),
        ('conv4_1', 19), ('conv4_2', 21), ('conv4_3', 23), ('conv4_4', 25),
        ('pool4', None),
        ('conv5_1', 28), ('conv5_2', 30),
    )
    with tf.variable_scope("vgg19"):
        if reuse:
            tf.get_variable_scope().reuse_variables()
        vgg_layers = vgg_path['layers'][0]
        net = {'input': input}
        previous = input
        # Chain the stages in order, feeding each one the previous output.
        for key, layer_idx in plan:
            if layer_idx is None:
                previous = build_net('pool', previous)
            else:
                previous = build_net(
                    'conv',
                    previous,
                    get_weight_bias(vgg_layers, layer_idx),
                    name='vgg_' + key,
                )
            net[key] = previous
        return net
def compute_l1_loss(input, output):
    """Return the mean absolute difference between `input` and `output`."""
    abs_diff = tf.abs(input - output)
    return tf.reduce_mean(abs_diff)
def compute_percep_loss(input, output, reuse=False):
    """Return the perceptual loss between `input` and `output`.

    Both tensors are passed through the VGG19 graph, and the L1 distances of
    the raw inputs and of five intermediate feature maps are combined with
    fixed per-layer scale factors.

    Args:
        input: First image tensor (its network is always built with reuse on).
        output: Second image tensor; its network is built with `reuse`.
        reuse: Forwarded to the first `build_vgg19` call.

    Returns:
        A scalar tensor holding the weighted sum of the per-layer L1 losses.
    """
    feats_target = build_vgg19(output, reuse=reuse)
    feats_pred = build_vgg19(input, reuse=True)

    def layer_gap(layer):
        # L1 distance between the two networks at the given stage.
        return compute_l1_loss(feats_target[layer], feats_pred[layer])

    # Per-layer terms with their hard-coded scale factors (provenance of the
    # constants is not shown in this file).
    terms = [
        layer_gap('input'),
        layer_gap('conv1_2') / 2.6,
        layer_gap('conv2_2') / 4.8,
        layer_gap('conv3_2') / 3.7,
        layer_gap('conv4_2') / 5.6,
        layer_gap('conv5_2') * 10 / 1.5,
    ]
    # Accumulate left to right, starting from the first term.
    total = terms[0]
    for term in terms[1:]:
        total = total + term
    return total
- 将上述代码加入你的模型中(需要先导入 tensorflow(tf)、numpy(np)和 scipy.io,并把 'your vgg19path' 替换为 VGG19 权重文件的实际路径),然后在 loss 中调用 compute_percep_loss 即可。VGG 模型可以在此处下载:VGG19。