TensorRT (7): TensorRT Python in Practice

1. Export the model to ONNX

import torch
from UNetTrainer import UNetTrainer  # project-specific trainer module

dev_nb = 0
model_file = r'xxx.model'
trainer = UNetTrainer(dev_nb, dim=2)
trainer.load_model(model_file)

# Dummy input with the shape the network expects: (N, C, H, W)
device = torch.device('cuda', dev_nb)
dummy_input = torch.randn(1, 1, 256, 256).cuda(device, non_blocking=True)
torch.onnx.export(trainer.network, dummy_input, 'body_model.onnx',
                  input_names=['input'], output_names=['output'])
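
Before building an engine, it is worth a quick sanity check that the exported graph is well formed, and to list its output tensor names: outputs that were not explicitly named in torch.onnx.export keep numeric graph names, which we will need in step 3. A minimal check, assuming the onnx package is installed:

import onnx

onnx_model = onnx.load('body_model.onnx')
onnx.checker.check_model(onnx_model)  # raises if the graph is malformed
print([t.name for t in onnx_model.graph.output])  # e.g. ['output', '409', ...]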

2. Build the TensorRT engine

trtexec --onnx=body_model.onnx --saveEngine=body_model.engine

FP32 is trtexec's default precision (there is no --fp32 flag); pass --fp16 instead to build a half-precision engine.
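
The engine can also be built from Python instead of the trtexec CLI, which is handy when the build needs to happen inside a pipeline. A minimal sketch using the TensorRT builder API (file names match the ones above; the 1 GiB workspace limit is an arbitrary choice):

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)
with open('body_model.onnx', 'rb') as f:
    if not parser.parse(f.read()):
        for i in range(parser.num_errors):
            print(parser.get_error(i))
        raise RuntimeError('failed to parse body_model.onnx')
config = builder.create_builder_config()
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)  # 1 GiB workspace
engine_bytes = builder.build_serialized_network(network, config)
with open('body_model.engine', 'wb') as f:
    f.write(engine_bytes)

Either route produces an equivalent engine; the CLI is just quicker for one-off experiments.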

3. Run inference with the engine

from typing import Union, Optional, Sequence, Dict, Any

import torch
import tensorrt as trt

class TRTWrapper(torch.nn.Module):
    def __init__(self,
                 engine: Union[str, trt.ICudaEngine],
                 output_names: Optional[Sequence[str]] = None) -> None:
        super().__init__()
        self.engine = engine
        if isinstance(self.engine, str):
            # Deserialize a serialized engine from disk
            with trt.Logger() as logger, trt.Runtime(logger) as runtime:
                with open(self.engine, mode='rb') as f:
                    engine_bytes = f.read()
                self.engine = runtime.deserialize_cuda_engine(engine_bytes)
        self.context = self.engine.create_execution_context()
        # Iterating the engine yields its binding names
        names = [name for name in self.engine]
        input_names = list(filter(self.engine.binding_is_input, names))
        self._input_names = input_names
        self._output_names = output_names
        # If no output names were given, use every non-input binding
        if self._output_names is None:
            output_names = list(set(names) - set(input_names))
            self._output_names = output_names

    def forward(self, inputs: Dict[str, torch.Tensor]):
        assert self._input_names is not None
        assert self._output_names is not None
        bindings = [None] * (len(self._input_names) + len(self._output_names))
        profile_id = 0
        for input_name, input_tensor in inputs.items():
            # Check the input shape against the engine's optimization profile
            # (older TensorRT: self.engine.get_profile_shape(profile_id, input_name))
            profile = self.engine.get_tensor_profile_shape(input_name, profile_id)
            assert input_tensor.dim() == len(profile[0]), \
                'Input dim is different from engine profile.'
            for s_min, s_input, s_max in zip(profile[0], input_tensor.shape, profile[2]):
                assert s_min <= s_input <= s_max, \
                    f'Input shape should be between {profile[0]} and {profile[2]} ' \
                    f'but got {tuple(input_tensor.shape)}.'
            idx = int(self.engine.get_binding_index(input_name))
            # All input tensors must live on the GPU
            assert 'cuda' in input_tensor.device.type
            input_tensor = input_tensor.contiguous()
            if input_tensor.dtype == torch.long:
                # TensorRT 8 has no int64 bindings; downcast to int32
                input_tensor = input_tensor.int()
            self.context.set_binding_shape(idx, tuple(input_tensor.shape))
            bindings[idx] = input_tensor.data_ptr()

        # Allocate output tensors at the shapes the context resolved
        outputs = {}
        for output_name in self._output_names:
            idx = self.engine.get_binding_index(output_name)
            dtype = torch.float32
            shape = tuple(self.context.get_binding_shape(idx))
            device = torch.device('cuda')
            output = torch.empty(size=shape, dtype=dtype, device=device)
            outputs[output_name] = output
            bindings[idx] = output.data_ptr()

        # Enqueue inference on the current CUDA stream
        self.context.execute_async_v2(bindings, torch.cuda.current_stream().cuda_stream)
        return outputs
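
Note that this wrapper mixes the legacy binding API (binding_is_input, get_binding_index, set_binding_shape, execute_async_v2) with the newer name-based tensor API (get_tensor_profile_shape), so it targets TensorRT 8.5/8.6, where both coexist. On TensorRT 10 the binding calls are removed; below is a sketch of the equivalent I/O loop with the name-based API (same imports as above; forward_v3 is a hypothetical helper, not part of the original wrapper):

def forward_v3(context, engine, inputs: Dict[str, torch.Tensor]):
    # Name-based I/O for TensorRT 10+: shapes and addresses are set per tensor name
    for name, tensor in inputs.items():
        tensor = tensor.contiguous()
        context.set_input_shape(name, tuple(tensor.shape))
        context.set_tensor_address(name, tensor.data_ptr())
    outputs = {}
    for i in range(engine.num_io_tensors):
        name = engine.get_tensor_name(i)
        if engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT:
            # Assumes float32 outputs, as in the wrapper above
            out = torch.empty(tuple(context.get_tensor_shape(name)),
                              dtype=torch.float32, device='cuda')
            context.set_tensor_address(name, out.data_ptr())
            outputs[name] = out
    context.execute_async_v3(torch.cuda.current_stream().cuda_stream)
    return outputs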

import numpy as np
import datetime
import SimpleITK as sitk
from utilities.data_oi import load_cts_image
from UNetTrainer import UNetTrainer  # project-specific trainer module
from pytool.data_convert_2d import resampleImage
########################################################################################
input_fn = r'ct.cts'
ct_volume, origin, spacing = load_cts_image(input_fn, "True")
size = np.array(ct_volume.shape)
# numpy arrays are indexed (z, y, x); SimpleITK expects (x, y, z), hence [::-1]
itk_img = sitk.GetImageFromArray(ct_volume)
itk_img.SetSpacing(spacing[::-1])
itk_img.SetOrigin(origin[::-1])
resample_itk_img = resampleImage(itk_img, 256)
data = sitk.GetArrayFromImage(resample_itk_img)
sub_ct = data.astype(np.short)
########################################################################################
model_file = r'xxx.model'
trainer = UNetTrainer(0, dim=2)
trainer.load_model(model_file)
# Intensity statistics recorded during training (no trailing commas here:
# they would silently turn each value into a 1-tuple)
mean_intensity = trainer.intensity_properties[0]['mean']
std_intensity = trainer.intensity_properties[0]['sd']
lower_bound = trainer.intensity_properties[0]['percentile_00_5']
upper_bound = trainer.intensity_properties[0]['percentile_99_5']
########################################################################################
ctr_idx = int(ct_volume.shape[0] * 0.5)  # take the middle slice
slice_2d = sub_ct[ctr_idx]
slice_2d = np.clip(slice_2d, lower_bound, upper_bound)
np_input = (slice_2d - mean_intensity) / std_intensity
sitk.WriteImage(sitk.GetImageFromArray(np_input), "img.nii.gz")
# Add batch and channel dims: (H, W) -> (1, 1, H, W). Note: expand the
# normalized array, not the raw slice
np_input = np.expand_dims(np.expand_dims(np_input, axis=0), axis=0).astype(np.float32)
print('np_input', np_input.shape)
tr_input = torch.from_numpy(np_input).cuda()
print('input', tr_input.shape)
########################################################################################
import onnxruntime
start_time = datetime.datetime.now()

# Load the ONNX model into an ONNX Runtime inference session
ort_session = onnxruntime.InferenceSession("body_model.onnx")
inname = [t.name for t in ort_session.get_inputs()]
outname = [t.name for t in ort_session.get_outputs()]
print("inputs name:", inname, "|| outputs name:", outname)
# Input dict: key is the tensor name, value is a numpy array
ort_inputs = {inname[0]: np_input}
# run(output name list, input dict); the names must match those set in torch.onnx.export
ort_output = ort_session.run(outname, ort_inputs)[0]
# Transpose axes, take the channel-wise argmax, and cast to uint8
ort_output = ort_output.transpose((1, 0, 2, 3)).argmax(0).astype(np.uint8)
print(ort_output.shape)
end_time = datetime.datetime.now()
print("onnxruntime elapsed: {}".format(end_time - start_time))
sitk.WriteImage(sitk.GetImageFromArray(ort_output), "ort_output.nii.gz")
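
By default the session above may run on the CPU, and recent onnxruntime-gpu builds require the execution providers to be named explicitly. For a fairer comparison with the TensorRT engine, request the CUDA provider with a CPU fallback:

ort_session = onnxruntime.InferenceSession(
    "body_model.onnx",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
print(ort_session.get_providers())  # confirm which provider was actually selected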
########################################################################################
start_time = datetime.datetime.now()
###
# Outputs not explicitly named in torch.onnx.export keep numeric graph names
# such as '409'; list them with the onnx check shown in step 1
model = TRTWrapper('body_model.engine', ['409', '415', '421', '427', '433', 'output'])
trt_output = model(dict(input=tr_input))
###
end_time = datetime.datetime.now()
print("TensorRT elapsed: {}".format(end_time - start_time))
########################################################################################
start_time = datetime.datetime.now()
###
dev_nb = 0
model_file = r'xxx.model'
trainer = UNetTrainer(dev_nb, dim=2)
trainer.load_model(model_file)
# np_input = tr_input.detach().cpu().numpy()
ptr_output = trainer.predict(np_input)
###
end_time = datetime.datetime.now()
print("PyTorch elapsed: {}".format(end_time - start_time))
########################################################################################
ptr_output = ptr_output.argmax(0).astype(np.uint8)
trt_output = trt_output["output"].detach().cpu().numpy().transpose((1, 0, 2, 3)).argmax(0).astype(np.uint8)
print(trt_output.shape)
print(ptr_output.shape)
sitk.WriteImage(sitk.GetImageFromArray(ptr_output), "ptr_output.nii.gz")
sitk.WriteImage(sitk.GetImageFromArray(trt_output), "trt_output.nii.gz")
########################################################################################
from evaluation.metrics import ConfusionMatrix, dice, jaccard, hausdorff_distance, hausdorff_distance_95

confusion_matrix = ConfusionMatrix(ptr_output, trt_output)
dice_score = dice(confusion_matrix=confusion_matrix)
jaccard_score = jaccard(confusion_matrix=confusion_matrix)
hd_score = hausdorff_distance(confusion_matrix=confusion_matrix, voxel_spacing=spacing, connectivity=1)
hd95_score = hausdorff_distance_95(confusion_matrix=confusion_matrix, voxel_spacing=spacing, connectivity=1)
print("dice_score", dice_score)
print("jaccard_score", jaccard_score)
print("hd_score", hd_score)
print("hd95_score", hd95_score)
########################################################################################
assert np.allclose(ptr_output, trt_output)  # verify the PyTorch and TensorRT label maps agree element-wise
