x=V(t.arange(0,3),requires_grad=True)
y=x**2+2*x
z=y.sum()
z.backward()
x.grad
运行上述代码会出现
Only Tensors of floating point and complex dtype can require gradients
问题,主要问题是x的类型.将t.arange(0,3)运行并打印得到
x=t.arange(0,3)
tensor([0,1,2])
从上面的结果可以看出tensor的类型是int类型,int类型在tensor中不能反向传播的。所以应该将上述问题的x的类型转变成torch.float类型即可。
即:
x=V(t.arange(0,3).float(),requires_grad=True)