(TorchScript应用) PyTorch模型转换为Torch脚本的代码出错。
出现原因:想将pytorch训练的.pth文件转成C++能处理的.pt文件。用的TorchScript的方法。具体代码如下:
import argparse
import cv2
import torchvision
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
from unet import NestedUNet
from unet import UNet
from utils.dataset import BasicDataset
from config import UNetConfig
# Export a trained checkpoint (.pth) to a TorchScript archive (.pt) by tracing
# the model with one example input.
cfg = UNetConfig()
device = torch.device('cpu')

# NOTE(review): eval() instantiates whatever class name cfg.model holds
# (NestedUNet and UNet are the ones imported above). This is only acceptable
# while the config is trusted; prefer an explicit name -> class mapping if
# cfg.model can ever come from external input.
model = eval(cfg.model)(cfg)

path = 'data/checkpoints/epoch_9.pth'  # checkpoint produced by your own training run
model.load_state_dict(torch.load(path, map_location=device))
model.to(device=device)
# Tracing freezes the behaviour of mode-dependent layers (dropout, batch norm)
# to the mode the module is in at trace time, so switch to inference mode first.
model.eval()
print(model)

img_path = 'data/00000.jpg'  # sample image used as the example input for tracing
img = cv2.imread(img_path)
if img is None:
    # cv2.imread signals failure by returning None instead of raising.
    raise FileNotFoundError('cannot read image: ' + img_path)

# OpenCV loads BGR; convert to RGB before handing the image to PIL.
imgXX = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
img1 = torch.from_numpy(BasicDataset.preprocess(imgXX, cfg.scale))
img1 = img1.unsqueeze(0)  # add a leading batch dimension
img1 = img1.to(device=device, dtype=torch.float32)

# This is the call that raised
# "Only tensors or tuples of tensors can be output from traced functions".
traced_script_module = torch.jit.trace(model, img1)
traced_script_module.save("torch_script_eval.pt")
报错的内容为:Only tensors or tuples of tensors can be output from traced functions
这里模型返回的是列表(list);trace 只支持输出 tensor 或由 tensor 组成的元组(tuple),列表、字典等容器都是不支持的。想办法把返回值变成 tensor 或 tensor 的元组,就可以了。
怀疑model这个模型的返回值非法了。找到model模型定义的返回值位置
def forward(self, input):
    """Run the nested (UNet++-style) encoder/decoder on ``input``.

    Each node ``xI_J`` is produced by ``self.convI_J`` from the channel-wise
    concatenation (dim 1) of all earlier nodes on the same level ``I`` and the
    upsampled node from level ``I + 1``.

    Returns:
        When ``self.deepsupervision`` is truthy, a list with the four outputs
        of ``final1`` .. ``final4`` applied to ``x0_1`` .. ``x0_4``; otherwise
        the single output of ``self.final(x0_4)``.
    """
    x0_0 = self.conv0_0(input)
    x1_0 = self.conv1_0(self.pool(x0_0))
    x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], 1))

    x2_0 = self.conv2_0(self.pool(x1_0))
    x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0)], 1))
    x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], 1))

    x3_0 = self.conv3_0(self.pool(x2_0))
    x2_1 = self.conv2_1(torch.cat([x2_0, self.up(x3_0)], 1))
    x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up(x2_1)], 1))
    x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up(x1_2)], 1))

    x4_0 = self.conv4_0(self.pool(x3_0))
    x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))
    x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up(x3_1)], 1))
    x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up(x2_2)], 1))
    x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], 1))

    if self.deepsupervision:
        output1 = self.final1(x0_1)
        output2 = self.final2(x0_2)
        output3 = self.final3(x0_3)
        output4 = self.final4(x0_4)
        # This is the branch taken under the current configuration.
        # NOTE: torch.jit.trace rejects a list output ("Only tensors or tuples
        # of tensors can be output from traced functions"); the list is kept
        # here because existing callers may rely on it, so tracing needs a
        # tuple (or a single tensor) to be returned instead.
        return [output1, output2, output3, output4]
    else:
        output = self.final(x0_4)
        return output
将其改为 return (output1, output2, output3, output4) 之后,报错解决了。
总结:traced_script_module = torch.jit.trace(model, img1) 这个位置报这类错,就怀疑是模型 forward 的返回值类型有问题,然后一步步查。
上面这个改可能不符合代码原意了。再看看