自己的代码遇到了梯度回传为0的问题,检查了很久,最后发现是torch.round()对张量操作之后导致梯度断裂,因为round是返回的一个新的张量,之前没有注意到,记录一下。
函数详解:
torch.round(input, out=None)
说明:返回一个新张量,将输入input张量的每个元素舍入到最近的整数。
参数:
input(Tensor):输入张量
out(Tensor,可选):输出张量
自己的代码遇到了梯度回传为0的问题,检查了很久,最后发现是torch.round()对张量操作之后导致梯度断裂,因为round是返回的一个新的张量,之前没有注意到,记录一下。
函数详解:
torch.round(input, out=None)
说明:返回一个新张量,将输入input张量的每个元素舍入到最近的整数。
参数:
input(Tensor):输入张量
out(Tensor,可选):输出张量