有两种将PyTorch模型转换为Torch脚本的方法。第一种是通过trace转换,第二种是通过script转换,trace比script简单,但只适合结构固定的网络模型,即forward中没有控制流的情况,因为trace只会保存运行时实际走的路径。如果forward函数中有控制流(IfElse、Switch、While、DynamicRNN、StaticRNN),需要用script方式实现。 这里使用的是trace进行转换,第二种还没有开始研究,时间有限,后面如果用到有控制流的模型,应该要去研究一哈。
由于项目要求需要提供在windows和ubantu的库,所以两者都要分别.pt文件!这里需要尤其注意,划重点!ubantu下导出的.pt只能在ubantu下使用,windows下导出的模型只能windows下用!!!下附导出代码,好像也没啥不太一样哈哈,既然贴上了就不删了~
Ubantu
import torch
import torchvision
import torch.nn as nn
from torch.utils import tensorboard
from models import *
classes=7
device=torch.device("cuda")
checkpoint=torch.load("best_model_FCN8_0224_240.pth")
model=FCN8(num_classes=classes)
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint.keys():