Reference: https://blog.csdn.net/omnispace/article/details/54942668
The blog post above explains this very well!
PS:
(1) The weight clipping in WGAN was later upgraded to a gradient penalty (WGAN-GP); a sketch of the original clipping is shown after the reference below, for contrast.
Reference: http://www.sohu.com/a/138121777_494939
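For contrast, here is a minimal sketch of the original weight clipping (the clamp bound c = 0.01 follows the WGAN paper; model_D is assumed to be the critic module, as in the snippet further below):

# after each critic update, clamp every parameter into [-c, c]
c = 0.01
for p in model_D.parameters():
    p.data.clamp_(-c, c)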
Code:
from torch.autograd import grad

# gradient penalty (WGAN-GP), computed with autograd
LAMBDA_GRAD_PENALTY = 1.0  # the WGAN-GP paper recommends 10
alpha = torch.rand(BATCH_SIZE, 1, 1, 1).cuda()
# pred_penalty: samples from the generated distribution; D_gt_v: real samples
differences = pred_penalty - D_gt_v
# random points on the straight lines between real and generated samples;
# requires_grad_ so the critic output can be differentiated w.r.t. them
interpolates = (D_gt_v + alpha * differences).detach().requires_grad_(True)
D_interpolates = model_D(interpolates)
# create_graph=True so the penalty itself is differentiable w.r.t. D's weights
gradients = grad(outputs=D_interpolates, inputs=interpolates,
                 grad_outputs=torch.ones_like(D_interpolates),
                 create_graph=True, retain_graph=True, only_inputs=True)[0]
# penalize the deviation of each sample's gradient L2 norm from 1
grad_norm = gradients.view(gradients.size(0), -1).norm(2, dim=1)
gradient_penalty = ((grad_norm - 1) ** 2).mean() * LAMBDA_GRAD_PENALTY
loss_D += gradient_penalty
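Two details worth noting: create_graph=True is what lets the penalty term contribute to the critic's weight gradients when loss_D.backward() is called (with create_graph=False it would be a constant w.r.t. the weights), and the penalty is the batch mean of (||∇_x D(x̂)||₂ − 1)², i.e. the per-sample gradient norm is pushed toward 1, as in the WGAN-GP paper.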