1.需要第三方库torch2trt;pip install torch2trt;
2.from torch2trt import torch2trt 导入api
3.代码示例
import torch
from torch2trt import torch2trt
from torch2trt import TRTModule
import datetime
model = torch.load('best.pt').cuda()
trx = torch.ones((1, 3, 608, 608)).cuda()
model_trt = torch2trt(model , [trx])
torch.save(model_trt.state_dict(), 'best_trt.pth')
model_trt = TRTModule()
model_trt.load_state_dict(torch.load('best_trt.pth'))
y_trt = model_trt(trx)
本人yolov5模型转换成功(转trt文件forward方法中最好别出现非pytorch方法,否则可能会转换失败,博主这里进行了预处理,转换成功);转换为int8格式的模型文件;并且成功运行,在detect的时候,通过外部编写detect处理代码,代码如下:
def _make_grid(nx=20, ny=20):
yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
def detect(pre)