Pytorch量处理中的一些小细节(reshape和view,cat和Unet中心剪裁以及警告:UserWarning: non-inplace resize is deprecated warn)

view与reshape

关于连续

 reshape和view区别的验证

张量不连续如何解决

关于cat()

 举几个例子

两个维度数不一样的张量进行cat拼接

两个二维张量进行沿行拼接,列不变

 两个二维张量进行沿列拼接,行不变

两个三位张量进行沿通道拼接, 其他维度数不变

 关于中心裁剪

关于UserWarning: non-inplace resize is deprecated warnings.warn("non-inplace resize is deprecated")


view与reshape

reshape方法会尝试创建一个新的张量,其元素与原始张量共享内存空间。如果原始张量在内存中是连续的,reshape将返回一个指向相同数据的新张量视图;如果原始张量在内存中不连续,reshape可能会先将其复制为连续的张量,然后再返回新形状的张量。这意味着,在某些情况下,reshape可能会导致额外的内存开销。

view方法则不会创建一个新的张量,而是直接返回一个与原始张量共享数据存储的新视图。如果原始张量和新的视图张量上的元素被修改,它们会互相影响,因为它们共享相同的数据。

注意:上述话的的意思其实就是,reshape方法会创建一个新张量,而view方法不会,他们两个返回的数据都是和原始张量在同一片空间,这一点我们可以使用resize_()方法进行验证,然后reshape会依据原始张量是否连续来进行相应的操作,但是view不行,如果原始数据不连续的话他就没办法操作里,(就会报错)。

这里提到了张量的连续性,那么这个连续性究竟是什么?

关于连续

如果原始张量是连续的,那么它的一维展开顺序与按行优先存储的顺序一致。

举个例子:        

首先我们要知道:在PyTorch中,张量(Tensor)的实际数据以一维数组(storage)的形式存储于连续的内存中,通常采用“行优先”的存储方式。这意味着,当张量被视为多维数组时,其元素在内存中的排列顺序是按照行(或称为“外层维度”)来确定的。

而转置(transpose)操作会改变张量的维度顺序,但不会改变其底层数据的物理位置。例如,一个形状为(m, n)的二维张量在转置后,其形状变为(n, m),但转置前后的张量共享同一个底层存储(storage)

例如:

假设有一个形状为(2, 3)的二维张量a,其内容为[[1, 2, 3], [4, 5, 6]]。在内存中,这个张量的一维展开顺序是[1, 2, 3, 4, 5, 6]。当对a进行转置操作得到张量b(形状为(3, 2))时,虽然b的内容仍然是[1, 2, 3, 4, 5, 6],但按照(3, 2)的形状来访问时,其元素在内存中的排列顺序就不再是[1, 2, 3, 4, 5, 6]了(实际上,由于共享同一个storage,物理上的排列顺序并未改变,但逻辑上的访问顺序变了)。因此,b被视为非连续张量。

所以我们就说,如果一个连续的数据转置后,他一维展开顺序发生变化,但是他在物理存储顺序还是和之前一样,所以就不连续。

判断是否连续的方法:.is_contiguous()

 reshape和view区别的验证

import torch

a=[[1,2,3],[1,2,3]]
b1=torch.tensor(a)
b2=torch.tensor(a)
b3=torch.tensor(a)
print('b1是否连续:',b1.is_contiguous())
print('b2是否连续:',b2.is_contiguous())
print('b3是否连续:',b2.is_contiguous())
print('b1.shape:',b1.shape)
print('b2.shape:',b1.shape)
print('b3.shape:',b2.shape)
print('*---------------------------*')
c1=b1.reshape(3,2)
c2=b2.view(3,2)
print('c1:',c1)
print('c1的地址:',id(c1))
c1.resize_(2,3)
print('c1',c1)
print('b1:',b1)
print('c1的地址:',id(c1))
print('c1是否连续:',c1.is_contiguous())
print('*---------------------------*')
print('c2:',c2)
print('c2的地址:',id(c2))
c2.resize_(2,3)
print('c2',c2)
print('b2:',b2)
print('c2的地址:',id(c2))
print('c2是否连续:',c2.is_contiguous())
print('*---------------------*')
b3=torch.t(b3)
print('b3:',b3)
print('b3是否连续:',b3.is_contiguous())
print('对b3使用view方法:')
print(b3.view(2,3))

 输出结果:

b1是否连续: True
b2是否连续: True
b3是否连续: True
b1.shape: torch.Size([2, 3])
b2.shape: torch.Size([2, 3])
b3.shape: torch.Size([2, 3])
*---------------------------*
c1: tensor([[1, 2],
        [3, 1],
        [2, 3]])
c1的地址: 2769190187152
c1 tensor([[1, 2, 3],
        [1, 2, 3]])
b1: tensor([[1, 2, 3],
        [1, 2, 3]])
c1的地址: 2769190187152
c1是否连续: True
*---------------------------*
c2: tensor([[1, 2],
        [3, 1],
        [2, 3]])
c2的地址: 2769190187232
c2 tensor([[1, 2, 3],
        [1, 2, 3]])
b2: tensor([[1, 2, 3],
        [1, 2, 3]])
c2的地址: 2769190187232
c2是否连续: True
*---------------------*
b3: tensor([[1, 1],
        [2, 2],
        [3, 3]])
b3是否连续: False
对b3使用view方法:
Traceback (most recent call last):
  File "D:\论文\test.py", line 36, in <module>
    print(b3.view(2,3))
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

从这里可以看出,他们两个返回的数据都是和原始张量在同一片空间,且张量不连续,使用view会报错。

张量不连续如何解决

 可以使用.contiguous()方法确保张量是内存连续的。但是,请注意,.contiguous()会返回张量的一个副本,如果原始张量很大,这可能会消耗额外的内存

使用方式:

x = torch.randn(2, 3, 4).transpose(1, 2).contiguous()  

或者更换处理方式使用reshape()

如果以上都不行,就要好好考虑自身的算法设计了。

关于cat()

torch.cat(tensors, dim=0, *, out=None) -> Tensor
  • tensors (sequence of tensors) – 要拼接的张量序列。这些张量必须具有相同的形状,除了指定要拼接的维度 dim
  • dim (int, 可选) – 要拼接的维度。默认为 0。(沿着第一个维度(dim=0)拼接,沿第二个维度(dim=1)拼接,依次类推)
  • out (Tensor, 可选) – 输出张量。如果指定,则结果将被写入此张量。
  • 注意:这里要注意的就是,对于我们的pytorch张量来说,沿第一个维度拼接就是沿着行拼接,对于二维张量而言,沿行拼接列不变,沿列拼接行不变。但是拼接的双方的维度数必须一样。

返回值

返回一个新的张量,它是输入张量在指定维度上的拼接结果。

 举几个例子

两个维度数不一样的张量进行cat拼接

import torch

a=[[1,2,3],[1,2,3]]
b=[1,2,3]
b1=torch.tensor(a)
b2=torch.tensor(b)

print('b1.shape:',b1.shape)
print('b2.shape:',b1.shape)

print('*---------------------------*')
c=torch.cat((b1,b2),dim=0)
print(c)

结果:

b1.shape: torch.Size([2, 3])
b2.shape: torch.Size([2, 3])
*---------------------------*
Traceback (most recent call last):
  File "D:\论文\test.py", line 12, in <module>
    c=torch.cat((b1,b2),dim=0)
RuntimeError: Tensors must have same number of dimensions: got 2 and 1

 这里就是报错:两个张量的维度数目不一样。

两个二维张量进行沿行拼接,列不变

这句话的意思是进行行拼接的时候,列数目要一致,行数倒是无所谓了。

import torch

a=[[1,2,3],[1,2,3]]
b=[[1,1,1]]
b1=torch.tensor(a)
b2=torch.tensor(b)

print('b1.shape:',b1.shape)
print('b2.shape:',b2.shape)

print('*---------------------------*')
c=torch.cat((b1,b2),dim=0)
print(c)

结果:

b1.shape: torch.Size([2, 3])
b2.shape: torch.Size([1, 3])
*---------------------------*
tensor([[1, 2, 3],
        [1, 2, 3],
        [1, 1, 1]])

 两个二维张量进行沿列拼接,行不变

这句话的意思是进行列拼接的时候,行数目要一致,列数倒是无所谓了。

import torch

a=[[1,2,3],[1,2,3]]
b=[[1,1],[1,1]]
b1=torch.tensor(a)
b2=torch.tensor(b)

print('b1.shape:',b1.shape)
print('b2.shape:',b2.shape)

print('*---------------------------*')
c=torch.cat((b1,b2),dim=1)
print(c)

结果:

b1.shape: torch.Size([2, 3])
b2.shape: torch.Size([2, 2])
*---------------------------*
tensor([[1, 2, 3, 1, 1],
        [1, 2, 3, 1, 1]])

两个三位张量进行沿通道拼接, 其他维度数不变

这句话的意思是进行沿通道拼接的时候,其他维度数目要一致,通道数倒是无所谓了。

import torch

a=[[[1,2,3]],[[1,2,3]]]
b=[[[1,1,1]],[[1,1,1]]]
b1=torch.tensor(a)
b2=torch.tensor(b)

print('b1.shape:',b1.shape)#通道数,行数,列数
print('b2.shape:',b2.shape)#通道数,行数,列数

print('*---------------------------*')
c=torch.cat((b1,b2),dim=0)
print(c)
print(c.shape)

结果:

b1.shape: torch.Size([2, 1, 3])
b2.shape: torch.Size([2, 1, 3])
*---------------------------*
tensor([[[1, 2, 3]],

        [[1, 2, 3]],

        [[1, 1, 1]],

        [[1, 1, 1]]])
torch.Size([4, 1, 3])

 

 关于中心裁剪

中心裁剪的原理其实很简单:

图像中心裁剪是指在保持图像宽高比不变的情况下,将图像的边缘部分删除,从而得到一个更小的图像。中心裁剪通常用于图像预处理、调整图像大小、去除图像噪声等应用场景。

图像中心裁剪的原理很简单,即计算出需要裁剪的区域的起始坐标和结束坐标,然后使用这些坐标对图像进行切割。具体步骤如下:

  1. 获取图像的宽度和高度。
  2. 根据需要裁剪的尺寸,计算出裁剪区域的起始坐标和结束坐标。
  3. 使用起始坐标和结束坐标,对图像进行切割。

总而言之就是:裁剪图像,并且保证宽高比不变。

这里为什么要提起中心裁剪和cat,是因为本人在复现语义分割中的一个经典模型——Unet模型,遇到了的一个小问题 。

Unet模型的网络架构如下:

这里面很关键的一点就是copy_crop操作,该操作是先把对应步骤的张量复制,然后在某一步骤开始前进行中心裁剪和另一张量进行cat拼接,从而得到一个新的张量,用以后续步骤的进行。

中心裁剪其实可以手动实现,首要的就是要找到“中心”,可以通过如下方式找到:

【目标张量的某维度数和原来张量的对应维度数的差值整除以2:目标张量某维度数-1】

以下是本人手动实现中心裁剪的代码:

    def copy_crop(self,tensor,target_tensor):
        root_shape=tensor.shape[2]
        target_shape=target_tensor[2]
        mid_different=(root_shape-target_shape)//2

        return tensor[:,:,mid_different:mid_different+(target_shape-1),mid_different:mid_different+(target_shape-1)]

 

关于UserWarning: non-inplace resize is deprecated warnings.warn("non-inplace resize is deprecated")

这条警告信息的意思是:“非原地(non-inplace)的 resize 操作已被弃用。” 在 PyTorch 中,resize_(注意末尾的下划线)是原地修改张量形状的方法,它不会创建新的张量对象,而是直接在原张量上修改其形状。而 resize 方法(不带下划线)虽然也能改变张量的形状,但它不是原地操作,可能会引起一些不必要的内存分配和性能问题,因此被标记为弃用。

简单来说,你应该使用 resize_ 方法来原地修改张量的形状,而不是 resize 方法,以避免这条警告信息。如果确实需要保持原有的张量不变,可以考虑先复制(例如使用 .clone())张量,然后再在复制后的张量上使用 resize_ 或其他方法来修改形状。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值