参考链接: expand(*sizes)
参考链接: expand_as(other)
说明:
expand(*sizes)方法返回一个张量,
他自动将原来的张量所有长度为1的维度扩展成所需要的长度,
注意只能扩展长度为1的维度(singleton dimensions),
如果传入-1,表示该长度为1的维度不变.
此外,我们还可以对维度的数量进行扩增,
维度扩增的部分放在最前面,详情见如下代码实验
注意:扩展一个张量并不意味在内存中分配新的空间,
它只会返回一个view视图,因此会有多个引用指向同一个内存区域,
如果要做原地操作的话要注意,如果要对张量进行写入操作,
需要先对他们进行拷贝.
实验1: expand(*sizes)
(base) PS C:\Users\chenxuqi> python
Python 3.7.4 (default, Aug 9 2019, 18:34:13) [MSC v.1915 64 bit (AMD64)] :: Anaconda, Inc. on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> x = torch.tensor([[1], [2], [3]])
>>> x
tensor([[1],
[2],
[3]])
>>> x.size()
torch.Size([3, 1])
>>> x.stride()
(1, 1)
>>> y = x.expand(3, 4)
>>> y
tensor([[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3]])
>>> y.stride()
(1, 0)
>>>
>>>
>>> y = x.expand(-1, 4) # -1 means not changing the size of that dimension
>>> y
tensor([[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3]])
>>> y.stride()
(1, 0)
>>> y.shape
torch.Size([3, 4])
>>>
>>>
>>>
>>> x = torch.ones(1,2,3,4,5,6)
>>> x.shape
torch.Size([1, 2, 3, 4, 5, 6])
>>> x.stride()
(720, 360, 120, 30, 6, 1)
>>> x.expand(7, 2, 3, 4, 5, 6).shape
torch.Size([7, 2, 3, 4, 5, 6])
>>> x.expand(7, 2, 3, 4, 5, 6).stride()
(0, 360, 120, 30, 6, 1)
>>>
>>>
>>>
>>>
>>>
>>> x = torch.ones(1,2,1,4,1,6,1,1,1,8)
>>> x.shape
torch.Size([1, 2, 1, 4, 1, 6, 1, 1, 1, 8])
>>> x.stride()
(384, 192, 192, 48, 48, 8, 8, 8, 8, 1)
>>> x.expand(7, 2, 3, 4, 5, 6,9,10,11,8).shape
torch.Size([7, 2, 3, 4, 5, 6, 9, 10, 11, 8])
>>> x.expand(7, 2, 3, 4, 5, 6,9,10,11,8).stride()
(0, 192, 0, 48, 0, 8, 0, 0, 0, 1)
>>>
>>>
>>>
>>>
>>>
>>>
>>> x.shape
torch.Size([1, 2, 1, 4, 1, 6, 1, 1, 1, 8])
>>> x.stride()
(384, 192, 192, 48, 48, 8, 8, 8, 8, 1)
>>> x.expand(-1, 2, -1, 4, 5, 6,9,10,11,8).shape
torch.Size([1, 2, 1, 4, 5, 6, 9, 10, 11, 8])
>>> x.expand(-1, 2, -1, 4, 5, 6,9,10,11,8).stride()
(384, 192, 192, 48, 0, 8, 0, 0, 0, 1)
>>>
>>>
>>>
(base) PS C:\Users\chenxuqi> python
Python 3.7.4 (default, Aug 9 2019, 18:34:13) [MSC v.1915 64 bit (AMD64)] :: Anaconda, Inc. on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> x = torch.ones(3,1,5,6)
>>> x.shape
torch.Size([3, 1, 5, 6])
>>> x.stride()
(30, 30, 6, 1)
>>> y = x.expand(7, 2, 3, 4, 5, 6)
>>> y.shape
torch.Size([7, 2, 3, 4, 5, 6])
>>> y.stride()
(0, 0, 30, 0, 6, 1)
>>> x.shape
torch.Size([3, 1, 5, 6])
>>> x.stride()
(30, 30, 6, 1)
>>>
>>>
>>>
错误使用举例:
(base) PS C:\Users\chenxuqi> python
Python 3.7.4 (default, Aug 9 2019, 18:34:13) [MSC v.1915 64 bit (AMD64)] :: Anaconda, Inc. on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> x = torch.ones(3,1,5,6)
>>> x.shape
torch.Size([3, 1, 5, 6])
>>> x.stride()
(30, 30, 6, 1)
>>> y = x.expand(7, 2, 3, 4, 5, 6)
>>> y.shape
torch.Size([7, 2, 3, 4, 5, 6])
>>> y.stride()
(0, 0, 30, 0, 6, 1)
>>>
>>>
>>>
...
>>> x.shape
torch.Size([3, 1, 5, 6])
>>> z = x.expand(3, 1, 5, 6, 4)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: The expanded size of the tensor (4) must match the existing size (6) at non-singleton dimension 4. Target sizes: [3, 1, 5, 6, 4]. Tensor sizes: [3, 1, 5, 6]
>>>
>>>
>>>
说明:
expand_as(other)方法的作用是对原张量的维度进行扩展,
使其形状和给定的张量相同,
该方法类似于expand(*sizes),
self.expand_as(other) 等效于 self.expand(other.size())
实验2: expand_as(other)
(base) PS C:\Users\chenxuqi> python
Python 3.7.4 (default, Aug 9 2019, 18:34:13) [MSC v.1915 64 bit (AMD64)] :: Anaconda, Inc. on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>>
>>> x = torch.ones(4, 512, 38, 38)
>>> x.shape
torch.Size([4, 512, 38, 38])
>>> x.stride()
(739328, 1444, 38, 1)
>>>
>>>
>>> weight = torch.zeros(1, 512, 1, 1)
>>> weight.shape
torch.Size([1, 512, 1, 1])
>>> weight.stride()
(512, 1, 1, 1)
>>>
>>> y = weight.expand_as(x)
>>> y.shape
torch.Size([4, 512, 38, 38])
>>> y.stride()
(0, 1, 0, 0)
>>>
>>> weight.shape
torch.Size([1, 512, 1, 1])
>>> weight.stride()
(512, 1, 1, 1)
>>>
>>>