Only tensors or tuples of tensors can be output from traced functions错误解决

(TorchScript应用) PyTorch模型转换为Torch脚本的代码出错。

出现原因:想将pytorch训练的.pth文件转成C++能处理的.pt文件。用的TorchScript的方法。具体代码如下:

import argparse
import cv2
import torchvision

import numpy as np
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
from unet import NestedUNet
from unet import UNet
from utils.dataset import BasicDataset
from config import UNetConfig

cfg = UNetConfig()
device = torch.device('cpu')
model = eval(cfg.model)(cfg)
path = 'data/checkpoints/epoch_9.pth'#自己模型训练的结果
model.load_state_dict(torch.load(path, map_location=device))
model.to(device=device)
print(model)img_path = 'data/00000.jpg' #自己模型准备用的数据
img = cv2.imread(img_path)
imgXX = Image.fromarray(cv2.cvtColor(img,cv2.COLOR_BGR2RGB))
img1 = torch.from_numpy(BasicDataset.preprocess(imgXX, cfg.scale))
img1 = img1.unsqueeze(0)
img1 = img1.to(device=device, dtype=torch.float32)
traced_script_module = torch.jit.trace(model,img1)#报错的代码位置

traced_script_module.save("torch_script_eval.pt")

报错的内容为:Only tensors or tuples of tensors can be output from traced functions

返回的是字典,是不支持的。想办法把返回的字典变成 tensors,就可以了。

怀疑model这个模型的返回值非法了。找到model模型定义的返回值位置

def forward(self, input):
        x0_0 = self.conv0_0(input)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], 1))

        x2_0 = self.conv2_0(self.pool(x1_0))
        x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0)], 1))
        x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], 1))

        x3_0 = self.conv3_0(self.pool(x2_0))
        x2_1 = self.conv2_1(torch.cat([x2_0, self.up(x3_0)], 1))
        x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up(x2_1)], 1))
        x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up(x1_2)], 1))

        x4_0 = self.conv4_0(self.pool(x3_0))
        x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))
        x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up(x3_1)], 1))
        x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up(x2_2)], 1))
        x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], 1))

        if self.deepsupervision:
            output1 = self.final1(x0_1)
            output2 = self.final2(x0_2)
            output3 = self.final3(x0_3)
            output4 = self.final4(x0_4)
            return [output1, output2, output3, output4)] #按程序设定 走的是这个分支。 显然[ ]不符合。

        else:
            output = self.final(x0_4)
            return output

将其改为return (output1, output2, output3, output4)) #报错解决了。

总结:raced_script_module = torch.jit.trace(model,img1)这个位置报这类错,就怀疑返回值问题,然后一步步查。

上面这个改可能不符合代码原意了。再看看

评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值