pytorch版本Deeplabv3+网络模型格式转换(pth转pt)

为了进一步使用c++调用deeplabv3+模型,使用trace将pytorch训练生成的.pth格式转为.pt

参考:https://github.com/shanson123/ORB_SLAM2_DeeplabV3/blob/master/DeeplabV3/create_deeplabv3.py

在predict.py文件中添加:

    with torch.no_grad():
        model = model.eval()
        for img_path in tqdm(image_files):
            ext = os.path.basename(img_path).split('.')[-1]
            img_name = os.path.basename(img_path)[:-len(ext)-1]
            img = Image.open(img_path).convert('RGB')
            img = transform(img).unsqueeze(0) # To tensor of NCHW
            img = img.to(device)
            
            pred = model(img).max(1)[1].cpu().numpy()[0] # HW
            colorized_preds = decode_fn(pred).astype('uint8')
            colorized_preds = Image.fromarray(colorized_preds)
            if opts.save_val_results_to:
                colorized_preds.save(os.path.join(opts.save_val_results_to, img_name+'.png'))

        #pth转pt
        traced_model = torch.jit.trace(model.module, img.to(device))
        traced_model.save("DeeplabV3plus.pt")

注意,如果写成 traced_model = torch.jit.trace(model, img.to(device)),会出现下图的报错:
Could not export Python function call ‘Scatter’.
请添加图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值