一文捋清【reshape、view、rearrange、contiguous、transpose、squeeze、unsqueeze】——python & torch

一文捋清【reshape、view、rearrange、contiguous、transpose、squeeze、unsqueeze】

1. reshape

reshape() 函数: 用于在不更改数据的情况下为数组赋予新形状。
注意: 用于低维度转高维度

c = np.arange(6)
print("** ", c)
c1 = c.reshape(3, -1)
print("** ", c1)
c2 = c.reshape(-1, 6)
print("** ", c2)
**  [0 1 2 3 4 5]
**  [[0 1]
 [2 3]
 [4 5]]
**  [[0 1 2 3 4 5]]

2. view

torch中,view() 的作用相当于numpy中的reshape,重新定义矩阵的形状。

v1 = torch.range(1, 16) 
v2 = v1.view(-1, 4)  
print("** ", v1)
print("** ", v2)
**  tensor([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13., 14.,
        15., 16.])
**  tensor([[ 1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.],
        [ 9., 10., 11., 12.],
        [13., 14., 15., 16.]])

3. rearrange

rearrange是einops中的一个函数调用方法。

from einops import rearrange

image1 = torch.zeros(2, 224, 224, 3)
image2 = rearrange(image1, 'b w h c -> b (w h) c')
image3 = rearrange(image1, 'b w h c -> (b c) (w h)')

print("** ", image1.shape)
print("** ", image2.shape)
print("** ", image3.shape)
**  torch.Size([2, 224, 224, 3])
**  torch.Size([2, 50176, 3])
**  torch.Size([6, 50176])

4. transpose

torch.transpose(Tensor,dim0,dim1)是pytorch中的ndarray矩阵进行转置的操作。
注意: transpose()一次只能在两个维度间进行转置(也可以理解为维度转换)

x = torch.Tensor(2, 3, 4, 5)  # 这是一个4维的矩阵(只用空间位置,没有数据)
print(x.shape)
# 先转置0维和1维,之后在第2,3维间转置,之后在第1,3间转置
y = x.transpose(0, 1).transpose(3, 2).transpose(1, 3)
print(y.shape)
torch.Size([2, 3, 4, 5])
torch.Size([3, 4, 5, 2])

5. permute

注意: permute相当于可以同时操作于tensor的若干维度,transpose只能同时作用于tensor的两个维度,permute是transpose的进阶版。

print(torch.Tensor(2,3,4,5).permute(3,2,0,1).shape)
torch.Size([5, 4, 2, 3])

6. contiguous

x.is_contiguous() ——判断tensor是否连续
x.contiguous() ——把tensor变成在内存中连续分布的形式

需要变成连续分布的情况:

view只能用在contiguous的variable上。如果在view之前用了transpose, permute等,需要用contiguous()来返回一个contiguous copy。

x = torch.Tensor(5, 10)
print(x.is_contiguous())
print(x.transpose(0, 1).is_contiguous())
print(x.transpose(0, 1).contiguous().is_contiguous())
True
False
True

写代码时,一般没写contiguous()会报错提示,所以不用担心…

7. squeeze

squeeze()函数的功能是维度压缩。返回一个tensor(张量),其中 input 中大小为1的所有维都已删除。

x = torch.Tensor(2, 1, 2, 1, 2)
print(x.shape)
y = torch.squeeze(x) # 默认是把所有是1的维度都删掉
print(y.shape)
y = torch.squeeze(x, 1)
print(y.shape)
torch.Size([2, 1, 2, 1, 2])
torch.Size([2, 2, 2])
torch.Size([2, 2, 1, 2])

8. unsqueeze

squeeze()函数的功能是增加维度

x = torch.arange(0,6)
print(x.shape)
y = x.unsqueeze(0)
print(y.shape)
z = x.unsqueeze(1)
print(z.shape)
w = x.unsqueeze(2)
print(w.shape)
torch.Size([6])
torch.Size([1, 6])
torch.Size([6, 1])
IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)
  • 10
    点赞
  • 49
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
class _PointnetSAModuleBase(nn.Module): def init(self): super().init() self.npoint = None self.groupers = None self.mlps = None self.pool_method = 'max_pool' def forward(self, xyz: torch.Tensor, features: torch.Tensor = None, new_xyz=None) -> (torch.Tensor, torch.Tensor): """ :param xyz: (B, N, 3) tensor of the xyz coordinates of the features :param features: (B, N, C) tensor of the descriptors of the the features :param new_xyz: :return: new_xyz: (B, npoint, 3) tensor of the new features' xyz new_features: (B, npoint, \sum_k(mlps[k][-1])) tensor of the new_features descriptors """ new_features_list = [] xyz_flipped = xyz.transpose(1, 2).contiguous() if new_xyz is None: new_xyz = pointnet2_utils.gather_operation( xyz_flipped, pointnet2_utils.furthest_point_sample(xyz, self.npoint) ).transpose(1, 2).contiguous() if self.npoint is not None else None for i in range(len(self.groupers)): new_features = self.groupers[i](xyz, new_xyz, features) # (B, C, npoint, nsample) new_features = self.mlpsi # (B, mlp[-1], npoint, nsample) if self.pool_method == 'max_pool': new_features = F.max_pool2d( new_features, kernel_size=[1, new_features.size(3)] ) # (B, mlp[-1], npoint, 1) elif self.pool_method == 'avg_pool': new_features = F.avg_pool2d( new_features, kernel_size=[1, new_features.size(3)] ) # (B, mlp[-1], npoint, 1) else: raise NotImplementedError new_features = new_features.squeeze(-1) # (B, mlp[-1], npoint) new_features_list.append(new_features) return new_xyz, torch.cat(new_features_list, dim=1)你可以给我详细讲解一下这个模块吗,一个语句一个语句的来讲解
05-24
这是一个名为_PointnetSAModuleBase的PyTorch模块,是PointNet++中的一个重要组成部分。它是一个基类,用于定义PointNet++中的两种不同的SA(Sampling and Aggregation)模块。接下来,我会一行一行地解释这个模块的代码。 ``` class _PointnetSAModuleBase(nn.Module): ``` 这里定义了一个名为_PointnetSAModuleBase的PyTorch模块,并继承了nn.Module类。 ``` def __init__(self): super().__init__() self.npoint = None self.groupers = None self.mlps = None self.pool_method = 'max_pool' ``` 这里定义了_PointnetSAModuleBase类的构造函数,并初始化了四个实例变量:self.npoint、self.groupers、self.mlps和self.pool_method。其中,self.npoint是采样点数,self.groupers是对每个采样点进行聚合的模块,self.mlps是一个包含多个MLP(Multi-Layer Perceptron)层的列表,self.pool_method是池化方法,具体可以是最大池化或平均池化。 ``` def forward(self, xyz: torch.Tensor, features: torch.Tensor = None, new_xyz=None) -> (torch.Tensor, torch.Tensor): ``` 这里定义了_PointnetSAModuleBase类的前向传播函数,输入包括xyz点云坐标张量、features特征张量和new_xyz新的采样点云坐标张量。返回值是一个包含new_xyz和new_features的元组。其中,new_features是根据new_xyz和features计算得到的新特征张量。 ``` new_features_list = [] xyz_flipped = xyz.transpose(1, 2).contiguous() ``` 这里定义了一个空列表new_features_list和一个翻转了xyz张量维度的张量xyz_flipped。 ``` if new_xyz is None: new_xyz = pointnet2_utils.gather_operation( xyz_flipped, pointnet2_utils.furthest_point_sample(xyz, self.npoint) ).transpose(1, 2).contiguous() if self.npoint is not None else None ``` 这里判断new_xyz是否为空,如果是,则使用furthest_point_sample函数进行采样,得到一个新的采样点云坐标张量new_xyz。如果self.npoint为空,则将new_xyz设为None。 ``` for i in range(len(self.groupers)): new_features = self.groupers[i](xyz, new_xyz, features) # (B, C, npoint, nsample) new_features = self.mlpsi # (B, mlp[-1], npoint, nsample) if self.pool_method == 'max_pool': new_features = F.max_pool2d( new_features, kernel_size=[1, new_features.size(3)] ) # (B, mlp[-1], npoint, 1) elif self.pool_method == 'avg_pool': new_features = F.avg_pool2d( new_features, kernel_size=[1, new_features.size(3)] ) # (B, mlp[-1], npoint, 1) else: raise NotImplementedError new_features = new_features.squeeze(-1) # (B, mlp[-1], npoint) new_features_list.append(new_features) ``` 这里遍历self.groupers列表,并对每个采样点进行聚合。对于每个聚合模块,首先将xyz、new_xyz和features传递给它,得到新的new_features张量。然后,将new_features传递给一个包含多个MLP层的列表self.mlps,得到新的new_features张量。接着,根据self.pool_method的值,对new_features张量进行最大池化或平均池化。最后,将new_features张量的最后一个维度压缩掉,并将结果添加到new_features_list列表中。 ``` return new_xyz, torch.cat(new_features_list, dim=1) ``` 这里返回new_xyz和new_features_list的拼接结果。其中,new_features_list的维度为(B, \sum_k(mlps[k][-1]), npoint),表示每个采样点的特征向量。最后,使用torch.cat函数在第二个维度上进行拼接,得到最终的new_features张量,维度为(B, \sum_k(mlps[k][-1]), npoint)。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

青春是首不老歌丶

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值