406_conditional_GAN.py raises a RuntimeError on PyTorch 1.5 and above
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [128, 1]], which is output 0 of AsStridedBackward0, is at version 2; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
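As the hint in the message says, anomaly detection makes the traceback point at the forward operation whose saved result was later modified in place. A minimal standalone sketch (a toy tensor example, unrelated to the GAN code, just to show the mechanism):

```python
import torch

# With anomaly mode on, the RuntimeError for an in-place modification also
# prints the forward-pass traceback of the op that saved the tensor.
torch.autograd.set_detect_anomaly(True)

w = torch.ones(3, requires_grad=True)
y = w * 2
loss = (y ** 2).sum()   # pow saves `y` for its backward pass
y.add_(1.)              # in-place change bumps y's version counter

caught = None
try:
    loss.backward()     # saved `y` is at a newer version -> RuntimeError
except RuntimeError as e:
    caught = type(e).__name__
print("caught:", caught)
```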
Cause of the error:
Searching online, you will find suggestions such as the following fix (method 1):
That fix is actually wrong: the program runs, but the results it produces are incorrect.
The real cause is that opt_D.step() modifies D's parameters in place while the computation graph used by G_loss still references them. PyTorch 1.4 did not perform this in-place check; it was added in 1.5, which is why the same code runs under 1.4 but raises the error under 1.5.
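The failure can be reproduced in isolation. The sketch below uses hypothetical toy one-layer G and D (the loss values are irrelevant, only the update order matters): D_loss and G_loss share one forward pass through D, opt_D.step() then updates D's weights in place, so the later G_loss.backward() sees saved tensors at a newer version and raises on PyTorch >= 1.5.

```python
import torch

# Toy stand-ins for the tutorial's G and D networks.
G = torch.nn.Linear(2, 2)
D = torch.nn.Linear(2, 1)
opt_G = torch.optim.Adam(G.parameters(), lr=1e-3)
opt_D = torch.optim.Adam(D.parameters(), lr=1e-3)

prob_fake = torch.sigmoid(D(G(torch.randn(4, 2))))   # one shared forward pass
D_loss = -torch.mean(torch.log(1. - prob_fake))
G_loss = torch.mean(torch.log(1. - prob_fake))

opt_D.zero_grad()
D_loss.backward(retain_graph=True)
opt_D.step()             # in-place update of D's weights (version counter bump)

caught = None
try:
    opt_G.zero_grad()
    G_loss.backward()    # backward through D needs the old weights -> error on >= 1.5
except RuntimeError as e:
    caught = type(e).__name__
print("caught:", caught)
```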
The correct fix is as follows:
for step in range(10000):
    artist_paintings, labels = artist_works_with_labels()       # real paintings with labels from the artist
    G_ideas = torch.randn(BATCH_SIZE, N_IDEAS)                  # random ideas
    G_inputs = torch.cat((G_ideas, labels), 1)                  # ideas with labels
    G_paintings = G(G_inputs)                                   # fake paintings w.r.t. labels from G

    # Train G first, before opt_D.step() can modify D's parameters in place
    D_inputs1 = torch.cat((G_paintings, labels), 1)
    prob_artist1 = D(D_inputs1)
    G_loss = torch.mean(torch.log(1. - prob_artist1))
    opt_G.zero_grad()
    G_loss.backward(retain_graph=True)
    opt_G.step()

    # Train D on a fresh forward pass; .detach() stops gradients from flowing back into G
    D_inputs0 = torch.cat((artist_paintings, labels), 1)        # all have their labels
    prob_artist0 = D(D_inputs0)                                 # D tries to increase this prob
    prob_artist1 = D(torch.cat((G_paintings, labels), 1).detach())  # D tries to reduce this prob
    D_loss = - torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1))
    opt_D.zero_grad()
    D_loss.backward()
    opt_D.step()

    if step % 200 == 0:  # plotting
        plt.cla()
        plt.plot(PAINT_POINTS[0], G_paintings.data.numpy()[0], c='#4AD631', lw=3, label='Generated painting',)
        bound = [0, 0.5] if labels.data[0, 0] == 0 else [0.5, 1]
        plt.plot(PAINT_POINTS[0], 2 * np.power(PAINT_POINTS[0], 2) + bound[1], c='#74BCFF', lw=3, label='upper bound')
        plt.plot(PAINT_POINTS[0], 1 * np.power(PAINT_POINTS[0], 2) + bound[0], c='#FF9359', lw=3, label='lower bound')
        plt.text(-.5, 2.3, 'D accuracy=%.2f (0.5 for D to converge)' % prob_artist0.data.numpy().mean(), fontdict={'size': 13})
        plt.text(-.5, 2, 'D score= %.2f (-1.38 for G to converge)' % -D_loss.data.numpy(), fontdict={'size': 13})
        plt.text(-.5, 1.7, 'Class = %i' % int(labels.data[0, 0]), fontdict={'size': 13})
        plt.ylim((0, 3)); plt.legend(loc='upper right', fontsize=10); plt.draw(); plt.pause(0.1)

plt.ioff()
plt.show()
The result after running is as follows: