Pytorch基础:Tensor的连续性

相关阅读

Pytorch基础icon-default.png?t=N7T8https://blog.csdn.net/weixin_45791458/category_12457644.html?spm=1001.2014.3001.5482


        在Pytorch中,一个连续的张量指的是张量中各数据元素在底层的存储顺序与其在张量中的位置一致。这意味着每一个元素的地址可以通过下面的线性映射公式来确定:

address(i_{0},i_{2},...,i_{n-1}) = base\_address+\sum_{k=0}^{n-1}(i_{k}\cdot stride(k)) 

        其中,i_{0}是第k维的索引,stride(k)是第k维的步长(就是第k维的数据在存储时,相邻数据在底层线性存储时相隔的数据数),base\_address是张量底层数据存储的起始地址。

        对于一个连续的张量,其stride应该符合从最内层维度(第0维度)到最外层维度递减的模式。更准确地说,如果一个张量有n个维度,并且每个维度的大小是s_{0},s_{2},...,s_{n-1},那么其连续性可以使用下面的方式判定:

stride(n-1)=1

 stride(k)=stride(k+1)\times s_{k+1} \ for\ k = n-2,n-1,...,0

        当你创建一个新的张量时,默认情况下,它是连续的,这意味着它的元素在内存中是按照顺序存储的。可以通过size()方法和stride()方法获得一个张量的形状和步长,利用公式判断是否连续;也可以使用storage()方法,直接获得一个张量在底层的线性存储结果;最方便的是使用.is_contiguous()方法,他会直接返回一个布尔值。

import torch

x = torch.tensor([[1, 2], [3, 4]])
print(list(x.size()))
print(x.stride())
print(x.storage())
print("Is x contiguous?", x.is_contiguous()) 

输出:
[2, 2]
(2, 1)
 1
 2
 3
 4
[torch.LongStorage of size 4]
Is x contiguous? True

        为什么会出现非连续张量呢?在PyTorch中,非连续张量的出现往往与张量视图(tensor views)的概念密切相关。张量视图允许一个新张量作为原张量的视图存在,其中新张量与其原张量共享相同的底层数据。这种设计旨在避免显式的数据复制,从而实现快速且内存高效的操作(例如切片、转置)。

        既然两个不同的张量共享底层存储,如果其中一个张量是连续的,另一个必然是不连续的,

import torch

x = torch.tensor([[1, 2], [3, 4]])
print(list(x.size()))
print(x.stride())
print(x.storage())
print("Is x contiguous?", x.is_contiguous()) 

y = x.t()
print(list(y.size()))
print(y.stride())
print(y.storage())
print("Is y contiguous?", y.is_contiguous())

输出:
[2, 2]
(2, 1)
 1
 2
 3
 4
[torch.LongStorage of size 4]
Is x contiguous? True
[2, 2]
(1, 2)
 1
 2
 3
 4
[torch.LongStorage of size 4]
Is y contiguous? False

        除了t(),Pytorch中有下面这些会返回视图的操作(因此可能出现非连续张量)。

基本的切片和索引, 例如tensor[0, 2:, 1:7:2]
adjoint()
as_strided()
detach()
diagonal()
expand()
expand_as()
movedim()
narrow()
permute()
select()
squeeze()
transpose()
t()
T
H
mT
mH
real
imag
view_as_real()
unflatten()
unfold()
unsqueeze()
view()
view_as()
unbind()
split()
hsplit()
vsplit()
tensor_split()
split_with_sizes()
swapaxes()
swapdims()
chunk()
indices() (仅限稀疏张量)
values() (仅限稀疏张量)

        除此之外,reshape(),reshape_as()和flatten() 既有可能返回张量的视图,也可能返回一个拥有独立存储空间的新张量。这取决于一些特定的条件,具体可见关于这些操作的文章。

Pytorch基础:Tensor的reshape方法_pytorch reshape-CSDN博客icon-default.png?t=N7T8https://chenzhang.blog.csdn.net/article/details/133445832Pytorch基础:Tensor的flatten方法_tensor.flatten-CSDN博客icon-default.png?t=N7T8https://chenzhang.blog.csdn.net/article/details/136570774        最后值得一提的是,contiguous()方法能返回一个连续的张量,如果原张量已连续,则会返回原张量。

import torch

x = torch.tensor([[1, 2], [3, 4]])
print(list(x.size()))
print(x.stride())
print(x.storage())
print("Is x contiguous?", x.is_contiguous())

y = x.t()

z = y.contiguous()
print(list(z.size()))
print(z.stride())
print(z.storage())
print("Is z contiguous?", z.is_contiguous())


输出:
[2, 2]
(2, 1)
 1
 2
 3
 4
[torch.LongStorage of size 4]
Is x contiguous? True
[2, 2]
(2, 1)
 1
 3
 2
 4
[torch.LongStorage of size 4]
Is z contiguous? True
评论 9
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

日晨难再

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值