AMD 的显卡问题
amd的老显卡,linux平台早不更新了,建议win平台下面用 torch_directml 搞定
pip install torch_directml
import torch_directml
device = torch_directml.device()
剩下的内容和cuda一样了。
没有梯度的错误
element 0 of tensors does not require grad and does not have a grad_fn
这个问题太恶心了,不过问题还是在于代码不规范,没有理解到torch的工作机制。错误出现在保存后,再load,结果就出错了。翻阅了很多地方,才知道是to(device)不能乱用,会把require_grad 这个参数给丢掉。仔细检查代码,发现,load 的时候,正确的代码应该是。
checkpoint = torch.load(model_path, map_location=device)
model = MyNet(input_dim)
model &