1.图像梯度差loss(image gradient difference loss)函数
阅读文献看到梯度差损失函数,在github上搜索,可以查到相关代码,做此记录:
def loss_gradient_difference(real_image,generated): # b x c x h x w
true_x_shifted_right = real_image[:,:,1:,:]# 32 x 3 x 255 x 256
true_x_shifted_left = real_image[:,:,:-1,:]
true_x_gradient = torch.abs(true_x_shifted_left - true_x_shifted_right)
generated_x_shift_right = generated[:,:,1:,:]# 32 x 3 x 255 x 256
generated_x_shift_left = generated[:,:,:-1,:]
generated_x_griednt = torch.abs(generated_x_shift_left - generated_x_shift_right)
difference_x = true_x_gradient - generated_x_griednt
loss_x_gradient = (torch.sum(difference_x)**2)/2 # tf.nn.l2_loss(true_x_gradient - generated_x_gradient)
true_y_shifted_right = real_image[:,:,:,1:]
true_y_shifted_left = real_image[:,:,:,:-1]
true_y_gradient = torch.abs(true_y_shifted_left - true_y_shifted_right)
generated_y_shift_right = generated[:,:,:,1:]
generated_y_shift_left = generated[:,:,:,:-1]
generated_y_griednt = torch.abs(generated_y_shift_left - generated_y_shift_right)
difference_y = true_y_gradient - generated_y_griednt
loss_y_gradient = (torch.sum(difference_y)**2)/2 # tf.nn.l2_loss(true_y_gradient - generated_y_gradient)
igdl = loss_x_gradient + loss_y_gradient
return igdl
后续新的损失函数在此继续补充。
参考网址:https://github.com/Zxl19990529/UGAN-pytorch/blob/40d4842c768e9aaea9f57c8c832c531a6a9d85bb/utils.py
2. Hinge loss 是对地球移动距离的一种拓展
Hinge loss 最初是SVM中的概念,其基本思想是让正例和负例之间的距离尽量大,后来在Geometric GAN中,被迁移到GAN:
对于D来说,只有当D(x) < 1 的正向样本,以及D(G(z)) > -1的负样本才会对结果产生影响。也就是说,只有一些没有被合理区分的样本,才会对梯度产生影响。这种方法可以使训练更加稳定。
def __call__(self, outputs, is_real, is_disc=None):
if self.type == 'hinge':
if is_disc:
if is_real:
outputs = -outputs
return self.criterion(1 + outputs).mean()
else:
return (-outputs).mean()
else:
labels = (self.real_label if is_real else self.fake_label).expand_as(outputs)
loss = self.criterion(outputs, labels)
return loss
3. 更多loss
https://github.com/tariktemur/RGAN-and-RaGAN-PyTorch/blob/master/utils.py
https://github.com/VainF/pytorch-msssim
total variation loss:
https://github.com/grorge123/project/blob/dc91f2a219/face2music/magenta/models/image_stylization/learning.py