PyTorch | tensor 的存储(storage、offset 和 stride)
1. Storage
1.1 Storage 概述
- t e n s o r tensor tensor 分为头信息区 ( T e n s o r ) (Tensor) (Tensor)和存储区 ( S t o r a g e ) (Storage) (Storage)。
- 头信息区 ( T e n s o r ) (Tensor) (Tensor)主要保存着 t e n s o r tensor tensor 的形状 ( s i z e ) (size) (size)、步长 ( s t r i d e ) (stride) (stride)、数据类型 ( t y p e ) (type) (type)等信息,而真正的数据则保存成连续数组,存储在存储区 ( S t o r a g e ) (Storage) (Storage)。
- 因为数据动辄成千上万,因此
信息区
元素占用内存较少,主要内存占用取决于 t e n s o r tensor tensor 中元素的数目,即存储区
的大小。 - 不同的
t
e
n
s
o
r
tensor
tensor 的头信息一般不同,但是可能使用相同的
s
t
o
r
a
g
e
storage
storage。
1.2 Storage 操作
-
S
t
o
r
a
g
e
Storage
Storage 查看
import torch t = torch.tensor([[1,2,3],[4,5,6],[7,8,9]]) t.storage()
-
S
t
o
r
a
g
e
Storage
Storage 索引
import torch t = torch.tensor([[1,2,3],[4,5,6],[7,8,9]]) t.storage()[1]
-
S
t
o
r
a
g
e
Storage
Storage 修改
import torch t = torch.tensor([[1,2,3],[4,5,6],[7,8,9]]) t.storage()[1]=10 t.storage()
-
S
t
o
r
a
g
e
Storage
Storage 判断
t o r c h . i s _ s t o r a g e ( o b j ) : torch.is\_storage(obj): torch.is_storage(obj):如果 o b j obj obj 是一个 p y t o r c h s t o r a g e pytorch\ storage pytorch storage 对象,则返回 T r u e True True。
2. offset
storage_offset
: t e n s o r tensor tensor 的第一个元素在 s t o r a g e storage storage 中的索引。
3. stride
- s t r i d e stride stride 是 s t o r a g e storage storage 中对应于 t e n s o r tensor tensor 的相邻维度间第一个索引的跨度,也叫步长。
- 当我们根据下标索引查找 t e n s o r tensor tensor 中的任意元素时,将某维度的下标索引和对应的步长相乘,然后将所有维度乘积相加就可以了。
- 根据
t
e
n
s
e
r
tenser
tenser 中的索引
i
,
j
i,j
i,j 查找
s
t
o
r
a
g
e
storage
storage 中对应索引的公式是
s
t
o
r
a
g
e
_
o
f
f
s
e
t
+
s
t
r
i
d
e
[
0
]
∗
i
+
s
t
r
i
d
e
[
1
]
∗
j
storage\_offset+stride[0]*i+stride[1]*j
storage_offset+stride[0]∗i+stride[1]∗j,因为是从
s
t
o
r
a
g
e
storage
storage 的开头查找,所以
s
t
o
r
a
g
e
_
o
f
f
s
e
t
=
0
storage\_offset=0
storage_offset=0。
示例
:上图是一个 s t o r a g e storage storage,与它对应的 t e n s o r ( [ [ 1.0 , 2.0 , 3.0 ] , [ 4.0 , 5.0 , 6.0 ] ] ) tensor([[1.0,2.0,3.0], [4.0,5.0,6.0]]) tensor([[1.0,2.0,3.0],[4.0,5.0,6.0]]) 如下图所示:
\qquad 那么 t e n s o r tensor tensor 的 s t r i d e = ( 3 , 1 ) stride=(3,1) stride=(3,1),因为从第一行的第一个索引到第二行第一个索引跨度是 3 3 3,从第一列到第二列的跨度是 1 1 1。