4 | PyTorch张量操作:底层存储逻辑

关于张量的底层存储逻辑这一部分看的我有点头大,但是了解底层实现确实有助于理解tensor中的各种运算到底是怎么一个回事,当然大部分时间我们可以不太会用到这些存储操作,但是熟悉这些底层实现,我觉得一方面可以帮我屏蔽一些开发上的bug,或者说在查bug的时候会往这个方面思考;再一个就是如果真的有需要做比较硬核的优化的时候也能够有点想法。

张量的存储

前面我们说过,张量的存储空间是连续的,最开始我可能以为存储像张量的结构一样,
比如说像这样的方块区域

但是,实际上它是这样存储的

image.png


然后使用偏移量和步长来进行索引,关于这两个概念我们后面会讨论。

PyTorch提供了一个storage方法来访问内存,如下我们创建了一个三行二列的二维tensor,然后用storage()读取它的内存,我们可以看到结果,实际底层存储是一个size为6的连续数组,而我们的tensor方法所实现的就是怎么通过索引把数组转换成我们需要的张量以及各种运算的方法。

points = torch.tensor([[4.0, 1.0], [5.0, 3.0], [2.0, 1.0]])
points.storage()
outs:
 4.0
 1.0
 5.0
 3.0
 2.0
 1.0
[torch.FloatStorage of size 6]

我们可以使用索引来查询这个存储区,比如

points_storage = points.storage()
points_storage[0]
outs:
4.0

显而易见的是,我们不能用二维索引,因为这个存储区只是一个一维数组,同时,如果我们修改存储区的数据,那么tensor的数据自然而然会发生变化。

points = torch.tensor([[4.0, 1.0], [5.0, 3.0], [2.0, 1.0]])
points_storage = points.storage()
points_storage[0] = 2.0 #给存储区位置0赋值2
points
outs:tensor([[2., 1.],
        [5., 3.],
        [2., 1.]])

关于带下划线的操作

在tensor的操作中,有少量的方法是带下划线的,比如zero_(),这样的方法只作为tensor对象的方法,我们可以认为是原地操作的方法,也就是说这样的方法是直接修改输入然后返回结果,而对应不带下划线的方法不会去改变源tensor,而是返回一个新的tensor。
让我们看一段代码:

import torch

a = torch.ones(3, 2)
a
outs:tensor([[1., 1.],
        [1., 1.],
        [1., 1.]])

b = a.zero_()
b
outs:tensor([[0., 0.],
        [0., 0.],
        [0., 0.]])

a
outs:tensor([[0., 0.],
        [0., 0.],
        [0., 0.]])

可以看到使用了zero_()方法之后,虽然我们看起来赋值给了b,但实际上底层发生了变化,a的数值也都是0了。

元数据是如何计算的

既然我们已经知道了tensor的底层存储实际上是连续的一维数组,那么下面来了解一下tensor通过什么样的方式来把底层存储处理成上层实现。

大小、偏移量、步长

这里作者给了三个概念,就是张量的大小、偏移量和步长,作者手绘的图像如下

大小(size):大小这个概念很容易理解,比如说图中给的tensor在表现上来看是一个3*3的矩阵,tensor的大小就是一个元组,里面记录了每一个维度有多少元素。

偏移量(offset):偏移量指的是这个tensor的第一个元素在当前存储区上的位置索引。我理解是这样的,对于一个完整的tensor,offset都是0。但是在某些情况,比如说我们有一个4*4的tensor,我们从它的(1,1)的位置选取一个子tensor,这个时候这个子tensor的offset就不是0了,应该是5?

步长(stride):这个概念我在抽象层面能够理解,但是实际看了例子还是差了一点,花了好长时间才搞明白一些。为此还专门去查了stride的英文意思,stride有“跨过,步幅的意思”,在这里去理解它,是指的按照tensor的顺序,沿着一个维度获取下一个元素在实际存储区所需要跳过的元素数量。

比如说上面的例子里,沿着行这个维度获取下一个元素也就是5->1这个动作,在存储区需要跨过3个元素,而沿着列这个维度获取下一个元素5->7这个动作,只需要跨过1个元素就可以了。

我们可以通过代码来查看偏移量和步长。

points = torch.tensor([[4.0, 1.0, 3.0, 2.0], [5.0, 3.0, 7.0, 8.0], [2.0, 1.0, 9.0, 5.0],[3.0, 8.0, 4.0, 5.0]]) #先生成一个新的tensor
second_point = points[1:,1:] #从原始tensor中摘取一个子tensor
second_point #让我们看看截取的子tensor对不对
outs:tensor([[3., 7., 8.],
        [1., 9., 5.],
        [8., 4., 5.]])
points.storage_offset() #原tensor的偏移量
outs:0
second_point.storage_offset() #子tensor的偏移量
outs:5 #看起来跟我们猜测的一样

#再来看一下步长
points.stride() #原始tensor的步长
outs:(4,1)
second_point.stride() #子tensor的步长
outs:(4,1)

可以看到这里的原始tensor和子tensor的步长都是一样的,这是为什么呢,很容易理解啊,我们是从(1,1)开始截取的,在底层存储不变的情况下,子tensor要按维度跳到下一个元素位置所经过的元素跟原tensor是一样的!

因此,我们修改子tensor也会引起原tensor的变化。如果要开辟一块新的空间来存这个tensor可以使用clone方法,这时候second_point就在一个新的tensor存储空间,对其修改不会影响points

second_point = points[1:,1:].clone()
second_point[0,0] = 10.0
second_point
outs:tensor([[10.,  7.,  8.],
        [ 1.,  9.,  5.],
        [ 8.,  4.,  5.]])

points
outs:tensor([[4., 1., 3., 2.],
        [5., 3., 7., 8.],
        [2., 1., 9., 5.],
        [3., 8., 4., 5.]])

如果说在这里似乎还看不出这个存储方案有什么神奇之处,下面我们看看对tensor进行操作之后的情况。

转置之后发生了啥

我们重新构建一个tensor

points = torch.tensor([[4.0, 1.0], [5.0, 3.0], [2.0, 1.0]])
points
outs:tensor([[4., 1.],
        [5., 3.],
        [2., 1.]])

points_t = points.t() #t()方法是用于二维张量转置时对transpose()方法的简写
points_t
outs:tensor([[4., 5., 2.],
        [1., 3., 1.]])

转置之后发生了什么呢,其实什么都没有发生,存储区还是一个存储区,变的只是tensor对于存储区的索引结构

#验证这两个tensor是用的一个存储区
id(points.storage()) == id(points_t.storage())

我们来看一下步长的变化

points.stride()
outs:(2, 1)
points_t.stride()
out:(1, 2)

从上面的代码我们可以看出,转置之后不同维度的步长做了相应的调整,示例图如下(突然发现原图有问题,我重新画了一个),转置后的tensor按行维度找下一个元素也就是4->1,只需要跨过1个元素,同理,在列维度则需要跨过2个元素。

什么是连续张量

连续张量的概念貌似很拗口,反正我看翻译是没有看懂,所以我把原文放在下面了,大意是有这样一个张量,它的值以最右侧的维度开始按顺序在存储区间中排列,这种张量就是连续张量。虽然概念很拗口,但是理念是很简单的,这里举了一个例子:比如说一个二维tensor,沿着行移动。

A tensor whose values are laid out in the storage starting from the rightmost dimension onward (that is, moving along rows for a 2D tensor) is defined as contiguous

再来看实际的代码,就更容易理解了:

points.is_contiguous()
outs:True
points_t.is_contiguous()
outs:False

在tensor的顺序和存储区顺序一致的就是连续张量,否则就不是。在PyTorch中,有一些操作只针对连续张量起作用,如果我们对那些不是连续张量的张量实施这些操作就会报错。那么如果我们想用这些方法怎么办呢,PyTorch自然也给出了解决办法,那就是contiguous方法,使用这个方法会改变存储区存储顺序,使得存储区顺序符合当前tensor连续的要求。

points = torch.tensor([[4.0, 1.0], [5.0, 3.0], [2.0, 1.0]])
points_t = points.t()
points_t #查看转置后的tensor
outs:tensor([[4., 5., 2.],
        [1., 3., 1.]])
points_t.storage() #查看存储区顺序
 outs:
 4.0
 1.0
 5.0
 3.0
 2.0
 1.0
[torch.FloatStorage of size 6]

points_t.stride() #查看步长信息
outs:(1, 2)

points_t_cont = points_t.contiguous() #调用contiguous方法
points_t_cont #可以看到tensor的表示没有发生变化
outs:tensor([[4., 5., 2.],
        [1., 3., 1.]])

points_t_cont.stride() #但是步长信息变了
outs:(3, 1)

points_t_cont.storage() #再看一下存储区,已经发生了变化
outs:
 4.0
 5.0
 2.0
 1.0
 3.0
 1.0
[torch.FloatStorage of size 6]

今天就看这么多吧。

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值