WGAN讲解

参考:https://blog.csdn.net/omnispace/article/details/54942668

上面这篇博客讲的很好!

PS:

(1)wgan中的weight cliping后面又被升级为gradient penalty;

参考:http://www.sohu.com/a/138121777_494939

代码:

from torch.autograd import grad
#gradient penalty , autograd way
LAMBDA_GRAD_PENALTY = 1.0
alpha = torch.rand(BATCH_SIZE, 1, 1, 1).cuda()
#pred_penalty是生成的分布,D_gt_v是真实分布
differences = pred_penalty - D_gt_v 
interpolates = D_gt_v + (alpha * differences)
D_interpolates = model_D(interpolates)
gradients = grad(outputs=D_interpolates, inputs=interpolates, grad_outputs=torch.ones(D_interpolates.size()).cuda(), create_graph=False, retain_graph=True, only_inputs=True)[0]
gradient_penalty = torch.mean(torch.sqrt(torch.sum((gradients - 1) ** 2 , dim = (1 , 2 , 3)))) * LAMBDA_GRAD_PENALTY
loss_D += gradient_penalty

 

转载于:https://www.cnblogs.com/zf-blog/p/10571722.html

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值