PyTorch 预测
import numpy as np
import SimpleITK as sitk
import torch
import torch.nn as nn
class ToyNet(nn.Module):
    """Minimal 3D network used to compare PyTorch and LibTorch predictions.

    A 3x3x3 convolution (1 -> 4 channels) followed by ReLU, then a 3x3x3
    transposed convolution (4 -> 2 channels). With the default stride and
    padding the transposed convolution restores the spatial extent removed
    by the first convolution, so the output has the same spatial size as
    the input and two channels.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and Sequential wrapping are kept so that the
        # state_dict keys ("conv1.0.weight", ...) stay unchanged.
        self.conv1 = nn.Sequential(
            nn.Conv3d(in_channels=1, out_channels=4, kernel_size=3),
            nn.ReLU(),
        )
        self.conv2 = nn.Sequential(
            nn.ConvTranspose3d(in_channels=4, out_channels=2, kernel_size=3),
        )

    def forward(self, x):
        """Return the two-channel output for a (N, 1, D, H, W) input tensor."""
        x = self.conv1(x)
        x = self.conv2(x)
        return x
# Read the sample volume and wrap its voxel array as a float tensor.
image = sitk.ReadImage('sample.mha')
inputTensor = torch.from_numpy(sitk.GetArrayFromImage(image)).float()
# Prepend batch and channel axes: (D, H, W) -> (1, 1, D, H, W).
inputTensor = inputTensor[None, None, :, :, :]
print(inputTensor.shape)
print(inputTensor)

net = ToyNet()
net.eval()

# Inference only: no autograd graph is needed for the forward pass.
with torch.no_grad():
    outputTensor = net(inputTensor)

# Per-voxel class index. Only the batch axis is squeezed so that a volume
# with a size-1 spatial axis keeps its three dimensions (this matches the
# squeeze(0) used by the LibTorch version).
outputTensor = torch.squeeze(
    torch.argmax(torch.nn.functional.softmax(outputTensor, 1), 1), 0
)
print(outputTensor.shape)
print(outputTensor)

outputImage = sitk.GetImageFromArray(outputTensor.detach().numpy().astype(np.uint8))
# Carry the input's origin, spacing and direction over to the label map;
# GetImageFromArray alone produces an image with default geometry.
outputImage.CopyInformation(image)
sitk.WriteImage(outputImage, 'PyOutput.mha')

# Export the network (with the weights used above) for the LibTorch side.
traced_script_module = torch.jit.trace(net, inputTensor)
traced_script_module.save("ToyNet.pt")
LibTorch 预测
#include <cstdint>
#include <exception>
#include <iostream>
#include <vector>
#include <torch/script.h>
#include <itkMetaImageIO.h>
#include <itkImageFileReader.h>
#include <itkImageFileWriter.h>
using namespace std;
void main()
{
using ImageType = itk::Image<float, 3>;
typedef itk::ImageFileReader< ImageType > ReaderType;
auto reader = ReaderType::New();
reader->SetFileName("sample.mha");
reader->SetImageIO(itk::MetaImageIO::New());
reader->Update();
auto inputImage = reader->GetOutput();
auto module = torch::jit::load("ToyNet.pt");
torch::Tensor inputTensor;
inputTensor = torch::from_blob(
inputImage->GetBufferPointer(),
{
1, 1, 6, 5, 4
}).toType(torch::kFloat32);
std::vector<torch::jit::IValue> imageTensorBatch;
imageTensorBatch.push_back(inputTensor);
torch::Tensor outputTensorBatch = torch::argmax(torch::softmax(module.forward(imageTensorBatch).toTensor(), 1), 1);
outputTensorBatch = outputTensorBatch.squeeze(0);
cout << outputTensorBatch.sizes() << endl;
cout << outputTensorBatch << endl;
ImageType::Pointer duplicate = ImageType::New();
duplicate->SetRegions(inputImage->GetBufferedRegion());
duplicate->Allocate();
__int64* outBuffer = outputTensorBatch.data<__int64>();
itk::ImageRegionIterator<ImageType> outIt(duplicate, duplicate->GetRequestedRegion());
outIt.GoToBegin();
while (!outIt.IsAtEnd())
{
outIt.Set((*outBuffer));
++outIt;
outBuffer++;
}
typedef itk::ImageFileWriter< ImageType > WriterType;
auto writer = WriterType::New();
writer->SetFileName("CppOutput.mha");
writer->SetInput(duplicate);
writer->SetImageIO(itk::MetaImageIO::New());
writer->Write();
}
结果对比
- PyTorch
- LibTorch