记录一下pytorch加载模型的一些细节
正确过程
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet()
model.load_state_dict(torch.load('last.pth'))
model.to(device=device)
一些细节
1.注意模型的实例化UNet()
2.若在cpu加载模型,一般需要在torch.load中加上map_location=torch.device('cpu')
,即
model.load_state_dict(torch.load('last.pth', map_location=torch.device('cpu')))