device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
model=model.to(device)
pretrain = torch.load('./weights/decom_net.pth')#默认加载到gpu
#pretrain = torch.load('./weights/decom_net.pth',map_location='cpu') #加载到cpu
model.load_state_dict(pretrain)
加载网络模型到cpu/gpu
最新推荐文章于 2023-06-15 22:45:56 发布