To call the DeepLabV3+ model from C++, convert the .pth file produced by PyTorch training into a TorchScript .pt file with torch.jit.trace.
Reference: https://github.com/shanson123/ORB_SLAM2_DeeplabV3/blob/master/DeeplabV3/create_deeplabv3.py
In predict.py, add the last three lines below (the rest is the existing inference code):
with torch.no_grad():
    model = model.eval()
    for img_path in tqdm(image_files):
        ext = os.path.basename(img_path).split('.')[-1]
        img_name = os.path.basename(img_path)[:-len(ext)-1]
        img = Image.open(img_path).convert('RGB')
        img = transform(img).unsqueeze(0)  # To tensor of NCHW
        img = img.to(device)
        pred = model(img).max(1)[1].cpu().numpy()[0]  # HW
        colorized_preds = decode_fn(pred).astype('uint8')
        colorized_preds = Image.fromarray(colorized_preds)
        if opts.save_val_results_to:
            colorized_preds.save(os.path.join(opts.save_val_results_to, img_name + '.png'))
    # Convert .pth to .pt: trace once, using the last preprocessed image as the example input
    traced_model = torch.jit.trace(model.module, img.to(device))
    traced_model.save("DeeplabV3plus.pt")
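Before loading DeeplabV3plus.pt from C++, it may be worth checking in Python that the saved TorchScript module reproduces the eager model. The sketch below traces, saves, reloads with torch.jit.load, and compares outputs; it assumes a small hypothetical stand-in network (TinySegNet) in place of DeepLabV3+, and the function name trace_and_check is likewise hypothetical.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinySegNet(nn.Module):
    """Hypothetical stand-in for DeepLabV3+: maps an NCHW image batch to per-class logits."""

    def __init__(self, num_classes: int = 3):
        super().__init__()
        self.conv = nn.Conv2d(3, num_classes, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)


def trace_and_check(model: nn.Module, example: torch.Tensor, path: str, atol: float = 1e-5) -> bool:
    """Trace `model` on `example`, save it to `path`, reload it, and compare outputs."""
    model.eval()
    with torch.no_grad():
        traced = torch.jit.trace(model, example)
        traced.save(path)
        reloaded = torch.jit.load(path)  # the same .pt file a C++ program would load
        expected = model(example)
        actual = reloaded(example)
    return torch.allclose(expected, actual, atol=atol)


if __name__ == "__main__":
    torch.manual_seed(0)
    example = torch.randn(1, 3, 64, 64)
    path = os.path.join(tempfile.gettempdir(), "tiny_segnet.pt")
    print(trace_and_check(TinySegNet(), example, path))
```

With the real network, the equivalent call inside predict.py would be trace_and_check(model.module, img, "DeeplabV3plus.pt"), reusing the image tensor from the loop.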
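Note: if you write traced_model = torch.jit.trace(model, img.to(device)) instead, tracing fails with this error:
Could not export Python function call 'Scatter'.
This happens because in predict.py the model is wrapped in nn.DataParallel (which is why model.module exists). DataParallel's forward first scatters the input across GPUs through a Python autograd function, and the tracer cannot record Python function calls. Tracing model.module, the underlying network, bypasses the wrapper.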
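If the same export code has to work whether or not the model is wrapped, one option is to unwrap it explicitly before tracing. A minimal sketch; the helper names unwrap_parallel and export_torchscript are hypothetical, and DistributedDataParallel is handled the same way only because it also exposes the wrapped network as .module.

```python
import torch
import torch.nn as nn


def unwrap_parallel(model: nn.Module) -> nn.Module:
    """Return the underlying network if `model` is wrapped in (Distributed)DataParallel."""
    wrappers = (nn.DataParallel, nn.parallel.DistributedDataParallel)
    return model.module if isinstance(model, wrappers) else model


def export_torchscript(model: nn.Module, example: torch.Tensor, path: str) -> torch.jit.ScriptModule:
    """Trace the unwrapped network, so no DataParallel scatter call is recorded, then save it."""
    net = unwrap_parallel(model).eval()
    with torch.no_grad():
        traced = torch.jit.trace(net, example)
    traced.save(path)
    return traced


if __name__ == "__main__":
    inner = nn.Conv2d(3, 2, kernel_size=1)
    wrapped = nn.DataParallel(inner)
    print(unwrap_parallel(wrapped) is inner)  # True
```

In predict.py the call would be export_torchscript(model, img, "DeeplabV3plus.pt"), which traces model.module just like the working line in the code above.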