莫烦Pytorch conditional_GAN.py笔记与报错修改

406_conditional_GAN.py代码在pytorch1.5以上版本的报错

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [128, 1]], which is output 0 of AsStridedBackward0, is at version 2; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
报错原因:

上网搜索原因,有建议如下错误修改方法1:

 实际上如上修改是错误的,虽然程序能跑通,但是实际运行出的结果时错误的

实际上真实的错误原因是因为 :问题来源主要是opt_D.step()变动了参数,原本pytorch1.4没有这一步的in place 检查。在1.5 版本他们加入了这个检查,所以如果你是1.4版本不会报错,但是1.5版本会报错。

正确的修改方法如下:


for step in range(10000):
    artist_paintings, labels = artist_works_with_labels()           # real painting, label from artist
    G_ideas = torch.randn(BATCH_SIZE, N_IDEAS)                      # random ideas
    G_inputs = torch.cat((G_ideas, labels), 1)                      # ideas with labels
    G_paintings = G(G_inputs)                                       # fake painting w.r.t label from G
    D_inputs1 = torch.cat((G_paintings, labels), 1)
    prob_artist1 = D(D_inputs1)

    G_loss = torch.mean(torch.log(1. - prob_artist1))
    opt_G.zero_grad()
    G_loss.backward(retain_graph=True)
    opt_G.step()

    D_inputs0 = torch.cat((artist_paintings, labels), 1)            # all have their labels
    prob_artist0 = D(D_inputs0)                 # D try to increase this prob
    prob_artist1 = D(torch.cat((G_paintings, labels), 1).detach())  # D try to reduce this prob
    D_loss = - torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1))

    opt_D.zero_grad()
    D_loss.backward()      # reusing computational graph
    opt_D.step()

    if step % 200 == 0:  # plotting
        plt.cla()
        plt.plot(PAINT_POINTS[0], G_paintings.data.numpy()[0], c='#4AD631', lw=3, label='Generated painting',)
        bound = [0, 0.5] if labels.data[0, 0] == 0 else [0.5, 1]
        plt.plot(PAINT_POINTS[0], 2 * np.power(PAINT_POINTS[0], 2) + bound[1], c='#74BCFF', lw=3, label='upper bound')
        plt.plot(PAINT_POINTS[0], 1 * np.power(PAINT_POINTS[0], 2) + bound[0], c='#FF9359', lw=3, label='lower bound')
        plt.text(-.5, 2.3, 'D accuracy=%.2f (0.5 for D to converge)' % prob_artist0.data.numpy().mean(), fontdict={'size': 13})
        plt.text(-.5, 2, 'D score= %.2f (-1.38 for G to converge)' % -D_loss.data.numpy(), fontdict={'size': 13})
        plt.text(-.5, 1.7, 'Class = %i' % int(labels.data[0, 0]), fontdict={'size': 13})
        plt.ylim((0, 3));plt.legend(loc='upper right', fontsize=10);plt.draw();plt.pause(0.1)

plt.ioff()
plt.show()

运行后结果如下:

 

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

mooyuan天天

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值