Recently, while working on ONNX deployment of a deep learning model for medical image segmentation, I found that the resampled output image and the results were wrong. After checking, it turned out I had forgotten to handle the difference in dimension ordering between ITK images and tensors. This post records the fix.
1. Torch Tensor Dimension Order
In PyTorch (torch::Tensor), the dimension order of an image tensor is usually:
- N: batch size (optional)
- C: channels
- D/H/W: depth, height, width (depth only appears for 3D volumes)
So the typical layouts are:
- NCHW: Batch, Channels, Height, Width (2D images)
- NCDHW: Batch, Channels, Depth, Height, Width (3D images; see the sketch below)
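As a minimal sketch (assumption: the segmentation model expects a single-channel NCDHW float input; the 309 x 512 x 512 shape is only an illustrative volume size), a 3D volume in (D, H, W) order gets its batch and channel dimensions added before inference:
#include <torch/torch.h>
#include <iostream>

int main() {
    torch::Tensor volume = torch::rand({309, 512, 512});    // (D, H, W)
    torch::Tensor input = volume.unsqueeze(0).unsqueeze(0); // (N, C, D, H, W)
    std::cout << input.sizes() << std::endl;                 // [1, 1, 309, 512, 512]
    return 0;
}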
2. ITK Dimension Order
In ITK, image sizes and indices are ordered as:
- XYZ: X (width), Y (height), Z (depth)
In memory, however, the pixel buffer is laid out with X varying fastest, then Y, then Z, which is exactly what the dimension swap in the code below has to account for (see the sketch after this list).
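For illustration, a hypothetical helper (assuming the ImageType alias defined with the code in section 3.1, i.e. itk::Image<float, 3>) that reads the same pixel through the index API and through the raw buffer:
float PixelAt(const ImageType::Pointer& image, long x, long y, long z) {
    ImageType::SizeType size = image->GetLargestPossibleRegion().GetSize();
    // Through the index API: index order is (X, Y, Z)
    ImageType::IndexType idx = { x, y, z };
    float viaIndex = image->GetPixel(idx);
    // Through the raw buffer: X varies fastest, then Y, then Z,
    // so (x, y, z) maps to the linear offset x + sizeX * (y + sizeY * z)
    float viaBuffer = image->GetBufferPointer()[x + size[0] * (y + size[1] * z)];
    std::cout << viaIndex << " == " << viaBuffer << std::endl; // same pixel, same value
    return viaIndex;
}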
3. Code
3.1 Reading a NIfTI image with ITK
// Includes and the image type shared by all snippets below
#include <itkImage.h>
#include <itkImageFileReader.h>
#include <itkImageFileWriter.h>
#include <itkNiftiImageIO.h>
#include <itkNiftiImageIOFactory.h>
#include <torch/torch.h>
#include <algorithm>
#include <cstring>
#include <iostream>
#include <string>
#include <tuple>
#include <vector>

using ImageType = itk::Image<float, 3>; // 3D float image

std::tuple<std::vector<float>, ImageType::SpacingType, ImageType::PointType, ImageType::SizeType, ImageType::DirectionType> ITKLoadNiftiImage(const std::string& filename) {
    auto reader = itk::ImageFileReader<ImageType>::New();
    auto imageIO = itk::NiftiImageIO::New();
    reader->SetImageIO(imageIO);
    reader->SetFileName(filename);
    try {
        reader->Update();
        ImageType::Pointer image = reader->GetOutput();
        ImageType::SizeType size = image->GetLargestPossibleRegion().GetSize();
        ImageType::SpacingType spacing = image->GetSpacing();
        ImageType::PointType origin = image->GetOrigin();
        ImageType::DirectionType direction = image->GetDirection();
        // Copy the pixel buffer (X fastest, then Y, then Z) into a flat vector
        std::vector<float> imageData(size[0] * size[1] * size[2]);
        std::copy(image->GetBufferPointer(), image->GetBufferPointer() + imageData.size(), imageData.begin());
        // Print spacing, origin and size (size is in X, Y, Z order)
        std::cout << "Spacing: " << spacing[0] << " " << spacing[1] << " " << spacing[2] << std::endl;
        std::cout << "Origin: " << origin[0] << " " << origin[1] << " " << origin[2] << std::endl;
        std::cout << "Dims: " << " size 0 " << size[0] << " size 1 " << size[1] << " size 2 " << size[2] << std::endl;
        return std::make_tuple(imageData, spacing, origin, size, direction);
    }
    catch (itk::ExceptionObject& excp) {
        std::cerr << "ExceptionObject caught: " << excp << std::endl;
        throw;
    }
}
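A quick usage sketch ("ct.nii.gz" is just a placeholder path):
// size is in (X, Y, Z) order, e.g. 512 x 512 x 309 (C++17 structured bindings)
auto [imageData, spacing, origin, size, direction] = ITKLoadNiftiImage("ct.nii.gz");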
3.2 Converting the ITK image data to a Tensor (note the dimension order)
torch::Tensor ConvertToTensor(const std::vector<float>& imageData, const itk::Size<3>& dims) {
    torch::TensorOptions options = torch::TensorOptions().dtype(torch::kFloat32);
    // View the ITK buffer as a (Z, Y, X) tensor: the buffer is row-major with X fastest,
    // so the last tensor dimension must be X. from_blob does not copy the data.
    torch::Tensor image_tensor = torch::from_blob(
        (void*)imageData.data(),
        {static_cast<int>(dims[2]), static_cast<int>(dims[1]), static_cast<int>(dims[0])},
        options
    );
    // Permute to (X, Y, Z); clone() materializes the permuted layout and detaches
    // the tensor from imageData's lifetime.
    torch::Tensor permuted_tensor = image_tensor.permute({2, 1, 0}).clone();
    std::cout << "Tensor Shape: " << permuted_tensor.sizes() << std::endl;
    std::cout << std::endl;
    return permuted_tensor;
}
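Continuing the usage sketch from 3.1:
torch::Tensor volume = ConvertToTensor(imageData, size);
std::cout << volume.sizes() << std::endl; // e.g. [512, 512, 309] for a 512 x 512 x 309 input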
3.3 Saving a Tensor as a NIfTI file
void SaveTensorAsNii(const torch::Tensor& tensor, const std::string& fileName, const ImageType::SpacingType& spacing, const ImageType::PointType& origin, const ImageType::DirectionType& direction) {
    itk::NiftiImageIOFactory::RegisterOneFactory();
    // The tensor arrives in (X, Y, Z) order; permute back to (Z, Y, X) so that,
    // once contiguous, the memory layout matches ITK's buffer (X fastest)
    torch::Tensor permuted_tensor = tensor.permute({ 2, 1, 0 });
    // Sizes are read from the original (X, Y, Z) tensor
    auto sizes = tensor.sizes();
    unsigned int sizeX = sizes[0];
    unsigned int sizeY = sizes[1];
    unsigned int sizeZ = sizes[2];
    ImageType::Pointer itkImage = ImageType::New();
    ImageType::RegionType region;
    ImageType::IndexType start = { 0, 0, 0 };
    ImageType::SizeType size = { sizeX, sizeY, sizeZ };
    region.SetIndex(start);
    region.SetSize(size);
    itkImage->SetRegions(region);
    itkImage->Allocate();
    itkImage->FillBuffer(0);
    itkImage->SetSpacing(spacing);
    itkImage->SetOrigin(origin);
    //itkImage->SetDirection(direction); // direction is currently not written to the output
    auto* itkImageBuffer = itkImage->GetBufferPointer();
    torch::Tensor contiguousTensor = permuted_tensor.contiguous();
    std::memcpy(itkImageBuffer, contiguousTensor.data_ptr<float>(), contiguousTensor.numel() * sizeof(float));
    using WriterType = itk::ImageFileWriter<ImageType>;
    WriterType::Pointer writer = WriterType::New();
    writer->SetFileName(fileName);
    writer->SetInput(itkImage);
    writer->Update();
}
4. Notes
For example, if the input NIfTI image is [512, 512, 309] (X, Y, Z), the correctly loaded tensor will be [309, 512, 512] (Z, Y, X), and after permute it becomes [512, 512, 309] again.
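Putting the three functions together, a minimal end-to-end sketch (the file paths are placeholders, and the model-inference step is omitted because it depends on the actual ONNX/libtorch deployment):
int main() {
    // Load the NIfTI volume and its metadata
    auto [imageData, spacing, origin, size, direction] = ITKLoadNiftiImage("input.nii.gz");
    // Convert to a tensor in (X, Y, Z) order
    torch::Tensor volume = ConvertToTensor(imageData, size);
    // ... run the segmentation model here; assume the output has the same (X, Y, Z) shape ...
    torch::Tensor result = volume; // placeholder for the model output
    // Write the result back with the original spacing and origin
    SaveTensorAsNii(result, "output.nii.gz", spacing, origin, direction);
    return 0;
}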
