目录
数据类型转换（用 `type_as` 将张量转换为另一张量的数据类型）
import torch

# Tensor.type_as(other) returns this tensor cast to other's dtype; only the
# dtype is taken from `other`, so the two shapes need not match.
# The legacy constructors torch.Tensor(3, 5) / torch.IntTensor(2, 3) return
# uninitialized memory, so explicit factories are used to get defined values.
aaa = torch.rand(3, 5) * 10                  # float tensor, values in [0, 10)
bbb = torch.zeros(2, 3, dtype=torch.int32)   # same dtype as torch.IntTensor
# print() as a function: the Python 2 `print x` statement is a SyntaxError
# in Python 3.
print(aaa.type_as(bbb))
模型转换
转换为 JIT（TorchScript）模式：
# Export the trained UNet to TorchScript by tracing it with example inputs.
model = UNet(3, 1)  # presumably (in_channels, out_channels) — confirm against UNet
modelname = 'ckpt_e_50.pth'
# opt.pretrain is concatenated directly with the file name, so it must already
# end with a path separator (or be an intentional file-name prefix).
# map_location='cpu' lets a checkpoint saved from GPU load on a CPU-only
# machine; the example inputs traced below are CPU tensors as well.
ckpt = torch.load(opt.pretrain + modelname, map_location='cpu')
# strict=False silently skips mismatched keys; report them so a partially
# loaded (partly untrained) model is not exported unnoticed.
load_result = model.load_state_dict(ckpt['state_dict'], strict=False)
if load_result.missing_keys:
    print('Missing keys (left at initial values):', load_result.missing_keys)
if load_result.unexpected_keys:
    print('Unexpected keys (ignored):', load_result.unexpected_keys)
# Switch to inference mode before tracing so any dropout/batch-norm layers
# are recorded with their eval-time behavior.
model.eval()
# Dummy inputs fix the shapes recorded by the tracer. forward() takes two
# tensors; the second is 1/8 the spatial size of the first (240/8 = 30,
# 320/8 = 40) — its meaning is not visible here, confirm with the model code.
example = torch.rand(1, 3, 240, 320)
example1 = torch.rand(1, 1, 30, 40)
# trace() records the ops executed for these inputs; data-dependent control
# flow inside forward() is frozen to the path taken here.
traced_script_module = torch.jit.trace(model, (example, example1))
traced_script_module.save("model.pt")