错误代码:反向传播后输出factor参数的grad为None
factor = torch.ones(num, requires_grad=True)
self.factor = torch.nn.Parameter(factor).cuda()
错误原因:self.factor经过一次.cuda()操作后就不再是叶子结点了。
修改后代码:
factor = torch.ones(num, requires_grad=True).cuda()
self.factor = torch.nn.Parameter(factor)