Pytorch医学图像分割网络预测部署

pytorch预测

import numpy as np
import SimpleITK as sitk
import torch
import torch.nn as nn

class ToyNet(nn.Module):
    def __init__(self):
        super(ToyNet, self).__init__()
        self.conv1 = torch.nn.Sequential(
            torch.nn.Conv3d(in_channels=1,out_channels=4,kernel_size=3,),
            torch.nn.ReLU()
        )
        self.conv2 = torch.nn.Sequential(
            torch.nn.ConvTranspose3d(in_channels=4,out_channels=2,kernel_size=3,),
        )
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x

# 读取文件
image = sitk.ReadImage('sample.mha')
inputTensor = torch.from_numpy(sitk.GetArrayFromImage(image)).float()
inputTensor = inputTensor[None, None,:, :, :]
print(inputTensor.shape)
print(inputTensor)

# 预测
net = ToyNet()
# net.load_state_dict(torch.load("xx.pth"))
net.eval()

outputTensor = net(inputTensor)
outputTensor = torch.squeeze(torch.argmax(torch.nn.functional.softmax(outputTensor, 1),1))
print(outputTensor.shape)
print(outputTensor)

# 写文件
outputImage = sitk.GetImageFromArray(outputTensor.detach().numpy().astype(np.uint8))
sitk.WriteImage(outputImage,'PyOutput.mha')

# 模型转换
traced_script_module = torch.jit.trace(net, inputTensor)
traced_script_module.save("ToyNet.pt")

libtorch预测

#include <iostream>
#include <torch/script.h>
#include <itkMetaImageIO.h>
#include <itkImageFileReader.h>
#include <itkImageFileWriter.h>

using namespace std;


void main()
{
	using ImageType = itk::Image<float, 3>;

	typedef itk::ImageFileReader< ImageType >  ReaderType;
	auto reader = ReaderType::New();
	reader->SetFileName("sample.mha");
	reader->SetImageIO(itk::MetaImageIO::New());
	reader->Update();

	auto inputImage = reader->GetOutput();

	auto module = torch::jit::load("ToyNet.pt");

	// itk image to tensor
	torch::Tensor inputTensor;
	inputTensor = torch::from_blob(
		inputImage->GetBufferPointer(),
		{
			1,  1, 6, 5, 4
		}).toType(torch::kFloat32);


	// image tensor to batch
	std::vector<torch::jit::IValue> imageTensorBatch;
	imageTensorBatch.push_back(inputTensor);

	torch::Tensor outputTensorBatch = torch::argmax(torch::softmax(module.forward(imageTensorBatch).toTensor(), 1), 1);
	outputTensorBatch = outputTensorBatch.squeeze(0);
	cout << outputTensorBatch.sizes() << endl;
	cout << outputTensorBatch << endl;

	ImageType::Pointer duplicate = ImageType::New();
	duplicate->SetRegions(inputImage->GetBufferedRegion());
	duplicate->Allocate();

	//float* outBuffer = output.data<float>();
	__int64* outBuffer = outputTensorBatch.data<__int64>();
	itk::ImageRegionIterator<ImageType> outIt(duplicate, duplicate->GetRequestedRegion());
	outIt.GoToBegin();
	while (!outIt.IsAtEnd())
	{
		outIt.Set((*outBuffer));
		++outIt;
		outBuffer++;
	}

	typedef itk::ImageFileWriter< ImageType >  WriterType;
	auto writer = WriterType::New();

	writer->SetFileName("CppOutput.mha");
	writer->SetInput(duplicate);
	writer->SetImageIO(itk::MetaImageIO::New());
	writer->Write();
}

结果对比

  • pytorch在这里插入图片描述
  • libtorch在这里插入图片描述
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值