三维图像旋转(基于Pytorch)

楼主本人有时需要对三维图像进行旋转,但是找了几天竟然没有发现合适的代码,试过scipy,可以对三维图像旋转,但是也太慢了,1000*1000*1000的数据绕一个维度要200秒左右。试过skimage,这个只能对2d图像旋转,结果绕某个轴一层层旋转,反而比scipy要快一点点,但是也太慢了。后来尝试了pytorch的F.affine_grid和F.grid_sample, 虽然能快速旋转,但是会把图像拉伸,造成失真,如下图所式(来源:Pytorch中的仿射变换(affine_grid) - 简书 (jianshu.com)):

原图像
仿射变换后的图像

最后找了半天,发现torchvion比较好用,虽然他原本是用来处理2d图像的,如果处理三维图像,会默认对第一个维度旋转,因此如果要旋转第二或第三个维度,需要先转置。这下基本思路就有了,可以直接放代码了:

import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision.transforms.functional import rotate
from torchvision.transforms import InterpolationMode

def rotation_3d(X, axis, theta, expand=False, fill=0.0):
    """
    The rotation is based on torchvision.transforms.functional.rotate, which is originally made for a 2d image rotation
    :param X: the data that should be rotated, a torch.tensor or an ndarray
    :param axis: the rotation axis based on the keynote request. 0 for x axis, 1 for y axis, and 2 for z axis.
    :param expand:  (bool, optional) – Optional expansion flag. If true, expands the output image to make it large enough to hold the entire rotated image. If false or omitted, make the output image the same size as the input image. Note that the expand flag assumes rotation around the center and no translation.
    :param fill:  (sequence or number, optional) –Pixel fill value for the area outside the transformed image. If given a number, the value is used for all bands respectively.
    :param theta: the rotation angle, Counter-clockwise rotation, [-180, 180] degrees.
    :return: rotated tensor.
    """
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    if type(X) is np.ndarray:
        X = torch.from_numpy(X)
        X = X.float()

    X = X.to(device)

    if axis == 0:
        X = rotate(X, interpolation=InterpolationMode.BILINEAR, angle=theta, expand=expand, fill=fill)
    elif axis == 1:
        X = X.permute((1, 0, 2))
        X = rotate(X, interpolation=InterpolationMode.BILINEAR, angle=theta, expand=expand, fill=fill)
        X = X.permute((1, 0, 2))
    elif axis == 2:
        X = X.permute((2, 1, 0))
        X = rotate(X, interpolation=InterpolationMode.BILINEAR, angle=-theta, expand=expand, fill=fill)
        X = X.permute((2, 1, 0))
    else:
        raise Exception('Not invalid axis')
    return X.squeeze(0)

if __name__ == "__main__":
    input_data = np.ones((300,300,300))
    input_data[0:250,0:250,0:250] = 0.75
    input_data[0:150,0:150,0:150] = 0.5
    input_data[0:50,0:50,0:50] = 0.25
    input_data = np.pad(input_data, ((100, 100), (100, 100), (100, 100)), 'constant', constant_values=((0.0, 0.0), (0.0, 0.0), (0.0, 0.0)))
    theta = 30
    
    output1 = rotation_3d(input_data, 0, theta)
    output2 = rotation_3d(output1, 1, theta)
    output3 = rotation_3d(output1, 2, theta)

    fig, axes = plt.subplots(2, 2, figsize=(8, 8))
    ax = axes.ravel()
    ax[0].imshow(input_data[120, :, :], cmap='gray')
    ax[0].set_title('Original image')
    ax[1].imshow(output1.cpu()[140, :, :], cmap='gray')
    ax[1].set_title('Rotated image around x axis')
    ax[2].imshow(output2.cpu()[140, :, :], cmap='gray')
    ax[2].set_title('Rotated image around y axis')
    ax[3].imshow(output3.cpu()[140, :, :], cmap='gray')
    ax[3].set_title('Rotated image around z axis')
    fig.suptitle('1')
    plt.show()

经过测试,在图像很小时,CPU更占优势,图像较大时,GPU远快于CPU,最终旋转结果如下。

 github地址:flashrouster/rotation_3d: 3D rotation around any axis based on Pytorch (github.com)icon-default.png?t=M5H6https://github.com/flashrouster/rotation_3d

  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
分类模型 本文将介绍如何使用PyTorch构建一个水果图像分类模型。我们将使用一个小型数据集,由3种水果组成:苹果,香蕉和橙子。我们将使用卷积神经网络(Convolutional Neural Network, CNN)来训练模型。 1. 准备数据 我们将使用一个小型数据集,由3种水果组成:苹果,香蕉和橙子。我们将从Kaggle下载该数据集,下载后将其放在本地目录下的/data/fruits/下。 接下来,我们需要将数据集分成训练集和测试集。我们将80%的数据用于训练,20%的数据用于测试。我们还将使用PyTorch中的ImageFolder类来加载数据集,该类将自动将图像与其相应的类别进行匹配。 以下是准备数据的代码: ``` import torch import torchvision import torchvision.transforms as transforms # 数据集路径 data_path = '/data/fruits/' # 定义训练集和测试集的转换 train_transform = transforms.Compose([ transforms.Resize((64, 64)), transforms.RandomHorizontalFlip(), transforms.RandomRotation(10), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) test_transform = transforms.Compose([ transforms.Resize((64, 64)), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) # 加载数据集 train_set = torchvision.datasets.ImageFolder(root=data_path + 'train', transform=train_transform) test_set = torchvision.datasets.ImageFolder(root=data_path + 'test', transform=test_transform) # 定义数据加载器 train_loader = torch.utils.data.DataLoader(train_set, batch_size=32, shuffle=True) test_loader = torch.utils.data.DataLoader(test_set, batch_size=32, shuffle=False) ``` 在上面的代码中,我们首先定义了数据集的路径。接下来,我们定义了训练集和测试集的转换。在这里,我们使用了一些数据增强技术,例如随机水平翻转和随机旋转。这些技术可以帮助模型更好地泛化。 我们还使用了归一化技术,将图像像素的值缩放到[-1,1]之间。这样做是为了使输入数据的分布更加均匀,从而加速模型的训练。 最后,我们使用ImageFolder类加载数据集,并定义数据加载器。数据加载器可以方便地将数据集分成小批次,以便我们能够更快地训练模型。 2. 构建模型 我们将使用一个简单的卷积神经网络(CNN)来训练模型。该模型由三个卷积层和三个全连接层组成。我们还将使用dropout技术来减少过拟合。 以下是构建模型的代码: ``` import torch.nn as nn import torch.nn.functional as F class FruitNet(nn.Module): def __init__(self): super(FruitNet, self).__init__() self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1) self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1) self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1) self.pool = nn.MaxPool2d(kernel_size=2, stride=2) self.fc1 = nn.Linear(128 * 8 * 8, 512) self.fc2 = nn.Linear(512, 128) self.fc3 = nn.Linear(128, 3) self.dropout = nn.Dropout(0.5) def forward(self, x): x = F.relu(self.conv1(x)) x = self.pool(x) x = F.relu(self.conv2(x)) x = self.pool(x) x = F.relu(self.conv3(x)) x = self.pool(x) x = x.view(-1, 128 * 8 * 8) x = F.relu(self.fc1(x)) x = self.dropout(x) x = F.relu(self.fc2(x)) x = self.dropout(x) x = self.fc3(x) return x ``` 在上面的代码中,我们首先定义了一个名为FruitNet的类,该类继承自nn.Module类。该类包含了三个卷积层和三个全连接层。在卷积层之间我们使用了max-pooling层。我们还使用了dropout技术来减少过拟合。 在forward方法中,我们首先将输入x通过卷积层和max-pooling层传递。接下来,我们将输入x展开成一维向量,并通过全连接层传递。最后,我们使用softmax函数将输出转换为概率分布。 3. 训练模型 现在我们已经准备好训练模型了。我们将使用交叉熵损失函数和随机梯度下降(SGD)优化器来训练模型。 以下是训练模型的代码: ``` import torch.optim as optim # 定义模型、损失函数和优化器 net = FruitNet() criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9) # 训练模型 for epoch in range(10): running_loss = 0.0 for i, data in enumerate(train_loader, 0): # 输入数据和标签 inputs, labels = data # 梯度清零 optimizer.zero_grad() # 前向传播、反向传播和优化 outputs = net(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() # 打印统计信息 running_loss += loss.item() if i % 100 == 99: # 每100个小批次打印一次统计信息 print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 100)) running_loss = 0.0 ``` 在上面的代码中,我们首先定义了模型、损失函数和优化器。在训练过程中,我们首先将梯度清零,然后将输入数据通过模型传递,并计算损失。接下来,我们执行反向传播和优化。最后,我们打印统计信息。 4. 测试模型 现在我们已经训练好了模型,我们需要测试它的性能。我们将使用测试集来测试模型的准确性。 以下是测试模型的代码: ``` # 测试模型 correct = 0 total = 0 with torch.no_grad(): for data in test_loader: # 输入数据和标签 images, labels = data # 前向传播 outputs = net(images) # 预测标签 _, predicted = torch.max(outputs.data, 1) # 统计信息 total += labels.size(0) correct += (predicted == labels).sum().item() print('Accuracy of the network on the test images: %d %%' % ( 100 * correct / total)) ``` 在上面的代码中,我们首先定义了正确分类的数量和总数。使用no_grad上下文管理器可以关闭autograd引擎,从而加速模型的运行。在测试集上,我们将输入数据通过模型传递,并获得预测标签。最后,我们统计了正确分类的数量和总数,并打印了模型的准确率。 总结 本文介绍了如何使用PyTorch构建一个水果图像分类模型。我们首先准备了数据集,然后构建了一个简单的卷积神经网络。我们还使用了交叉熵损失函数和随机梯度下降(SGD)优化器来训练模型。最后,我们使用测试集测试了模型的性能。
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值