说明
我参考了一个开源的人像语义分割项目mobile_phone_human_matting,这个项目提供了预训练模型,我想要将该模型固化,然后转换格式后在嵌入式端使用。
该项目保存模型的代码如下:
lastest_out_path = "{}/ckpt_lastest.pth".format(self.save_dir_model)
torch.save({
'epoch': epoch,
'state_dict': model.state_dict(),
}, lastest_out_path)
转换代码
上面代码保存了state_dict, 所以保存的文件中是不含模型结构的,固化时需要从代码构造网络结构。好在项目是完全开源,将原项目下的model目录拷贝过来就行。
另外不能忘记调用eval() 来固化参数。
完整的转换代码如下:
import torch
from model import segnet
ckptfile="./ckpt_lastest.pth"
savedfile="./human_seg.pt"
model = segnet.SegMattingNet()
device = torch.device('cpu')
ckpt = torch.load(ckptfile, map_location=device )
model.load_state_dict(ckpt['state_dict'])
model.eval() #这一步会将参数固化,不能省。否则会报AssertionError('batchnorm with training is not support. Please set model.eval() before export.')
x = torch.rand(1,3,256,256)
ts = torch.jit.trace(model, x)
ts.save(savedfile)