前言
最近这段时间需要使用clam模型进行分类训练,但是在特征提取部分,我不想重新训练特征提取器,我想将之前用小图训练出来的resnet模型直接进行替换,而根据模型载入部分的要求,要求使用jit保存的模型,我就想使用torch.save保存的pth模型能否直接转换成torch.jit.save的模型。
模型转换代码
import torch
model = torch.load("/model.pth",map_location=torch.device("cpu"))#之前保存的训练模型
model.eval()#切换到eval()
example = torch.rand(1, 3, 224, 224)#生成一个随机输入维度的输入
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("/model.pt")#转换成jit的模型