torch.as_strided(input, size, stride, storage_offset=0) → Tensor
1. Official documentation
Create a view of an existing torch.Tensor input with specified size, stride and storage_offset.
Parameters
- input (Tensor) – the input tensor.
- size (tuple or ints) – the shape of the output tensor
- stride (tuple or ints) – the stride of the output tensor
- storage_offset (int, optional) – the offset in the underlying storage of the output tensor
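As a minimal sketch of how the three parameters show up on the result, the snippet below builds a view and reads its metadata back. The tensor `x` and the chosen values are illustrative assumptions, not part of the official documentation.

```python
import torch

# Illustrative input: contiguous storage holds 0..8
x = torch.arange(9).reshape(3, 3)

# size=(2, 2), stride=(1, 2), storage_offset=1
t = torch.as_strided(x, (2, 2), (1, 2), 1)

print(t.size())            # torch.Size([2, 2])
print(t.stride())          # (1, 2)
print(t.storage_offset())  # 1
print(t)                   # tensor([[1, 3],
                           #         [2, 4]])

# The view starts storage_offset elements past the start of x's memory
print(t.data_ptr() == x.data_ptr() + 1 * x.element_size())  # True
```

The result owns no data of its own: it only records a shape, a stride and a starting position into the memory that `x` already holds.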
WARNING
More than one element of a created tensor may refer to a single memory location. As a result, in-place operations (especially ones that are vectorized) may result in incorrect behavior. If you need to write to the tensors, please clone them first.
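The overlap described in this warning can be made visible with a small sketch. The tensor and strides below are illustrative assumptions; the write goes to the base tensor so that the outcome is deterministic.

```python
import torch

x = torch.arange(4)                      # storage: [0, 1, 2, 3]
t = torch.as_strided(x, (3, 2), (1, 1))  # rows overlap: [[0, 1], [1, 2], [2, 3]]

x[1] = 100                               # a single write to the base tensor
print(t[0, 1].item(), t[1, 0].item())    # 100 100 (both read the same location)

safe = t.clone()                         # copy with its own memory
safe[0, 1] = -1                          # changes the copy only
print(safe[1, 0].item())                 # 100 (no longer aliased with safe[0, 1])
print(x[1].item())                       # 100 (the original is untouched)
```

After `clone()`, every element of `safe` has its own memory location, so writes to it no longer leak into other positions.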
Many PyTorch functions, which return a view of a tensor, are internally implemented with this function. Those functions, like torch.Tensor.expand(), are easier to read and are therefore more advisable to use.
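To illustrate that remark, here is a rough comparison of two view functions with hand-written as_strided calls that produce the same result. The inputs are assumed examples; the snippet only checks that the outputs match and says nothing about how PyTorch implements these functions internally.

```python
import torch

# expand: a stride of 0 along the new dimension repeats the same row
v = torch.tensor([1, 2, 3])
expanded = v.expand(2, 3)
manual = torch.as_strided(v, (2, 3), (0, 1))
print(expanded.stride())              # (0, 1)
print(torch.equal(expanded, manual))  # True

# transpose: swapping the two strides swaps rows and columns
m = torch.arange(6).reshape(2, 3)     # stride (3, 1)
transposed = m.t()
manual_t = torch.as_strided(m, (3, 2), (1, 3))
print(transposed.stride())            # (1, 3)
print(torch.equal(transposed, manual_t))  # True
```

The named functions state the intent directly, which is why they are preferable to spelling out sizes and strides by hand.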
e.g.
>>> x = torch.randn(3, 3)
>>> x
tensor([[ 0.9039,  0.6291,  1.0795],
        [ 0.1586,  2.1939, -0.4900],
        [-0.1909, -0.7503,  1.9355]])
>>> t = torch.as_strided(x, (2, 2), (1, 2))
>>> t
tensor([[0.9039, 1.0795],
        [0.6291, 0.1586]])
>>> t = torch.as_strided(x, (2, 2), (1, 2), 1)
>>> t
tensor([[0.6291, 0.1586],
        [1.0795, 2.1939]])
2. Personal understanding
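as_strided generates a Tensor by expanding over the original memory according to the specified parameters (size, stride, offset). Because elements are read in memory order, several elements of the result can share the same memory location. Modifying such a Tensor in place can therefore produce unexpected errors, so it is advisable to make a deep copy (clone) first.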
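As for how it works: do not get hung up on the row and column layout of the original data; simply read values from memory according to the given stride. As the figure below shows, the requested elements are extracted from the 4x4 source data.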
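Note that the starting point here is the element 1, and stepping in both the W and H directions proceeds from that starting point. If an offset is given, the starting point is shifted by offset; see the test results further below.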
[Figure: elements extracted from the 4x4 source data according to the given stride; image not available]
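The reading rule described above can be written down as a short pure-Python sketch: element [i][j] of a 2D result is taken from position offset + i * stride[0] + j * stride[1] of the flat memory. The helper name `strided_view` is hypothetical and only models the indexing for the 2D case; it is not part of PyTorch.

```python
def strided_view(storage, size, stride, offset=0):
    """Model a 2D as_strided read over a flat list (hypothetical helper)."""
    rows, cols = size
    s0, s1 = stride
    return [[storage[offset + i * s0 + j * s1] for j in range(cols)]
            for i in range(rows)]

# Flat memory of a contiguous 4x4 tensor holding 1..16
storage = list(range(1, 17))

print(strided_view(storage, (2, 2), (2, 1)))     # [[1, 2], [3, 4]]
print(strided_view(storage, (2, 2), (1, 2)))     # [[1, 3], [2, 4]]
print(strided_view(storage, (2, 2), (1, 2), 1))  # [[2, 4], [3, 5]]
```

Nothing in the formula refers to the 4x4 shape of the source, which is why the row and column layout of the original data can be ignored.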
Test code:
import torch

# 4x4 source tensor; its contiguous storage holds 1..16 in row-major order
a = torch.tensor([[1, 2, 3, 4],
                  [5, 6, 7, 8],
                  [9, 10, 11, 12],
                  [13, 14, 15, 16]])

# The output size is always (2, 2); only the stride varies
b = torch.as_strided(a, (2, 2), (0, 0))
c = torch.as_strided(a, (2, 2), (0, 1))
d = torch.as_strided(a, (2, 2), (1, 0))
e = torch.as_strided(a, (2, 2), (2, 0))
f = torch.as_strided(a, (2, 2), (0, 2))
g = torch.as_strided(a, (2, 2), (1, 1))
h = torch.as_strided(a, (2, 2), (2, 1))
i = torch.as_strided(a, (2, 2), (1, 2))

# Same strides as h and i, but starting one element later (storage_offset=1)
j = torch.as_strided(a, (2, 2), (2, 1), 1)
k = torch.as_strided(a, (2, 2), (1, 2), 1)

print(' >> input data: {}'.format(a))
print(' >> Test stride: (0, 0), result: {}'.format(b))
print(' >> Test stride: (0, 1), result: {}'.format(c))
print(' >> Test stride: (1, 0), result: {}'.format(d))
print(' >> Test stride: (2, 0), result: {}'.format(e))
print(' >> Test stride: (0, 2), result: {}'.format(f))
print(' >> Test stride: (1, 1), result: {}'.format(g))
print(' >> Test stride: (2, 1), result: {}'.format(h))
print(' >> Test stride: (1, 2), result: {}'.format(i))
print(' >> Test stride: (2, 1), offset: 1, result: {}'.format(j))
print(' >> Test stride: (1, 2), offset: 1, result: {}'.format(k))
Test results:
 >> input data: tensor([[ 1,  2,  3,  4],
        [ 5,  6,  7,  8],
        [ 9, 10, 11, 12],
        [13, 14, 15, 16]])
 >> Test stride: (0, 0), result: tensor([[1, 1],
        [1, 1]])
 >> Test stride: (0, 1), result: tensor([[1, 2],
        [1, 2]])
 >> Test stride: (1, 0), result: tensor([[1, 1],
        [2, 2]])
 >> Test stride: (2, 0), result: tensor([[1, 1],
        [3, 3]])
 >> Test stride: (0, 2), result: tensor([[1, 3],
        [1, 3]])
 >> Test stride: (1, 1), result: tensor([[1, 2],
        [2, 3]])
 >> Test stride: (2, 1), result: tensor([[1, 2],
        [3, 4]])
 >> Test stride: (1, 2), result: tensor([[1, 3],
        [2, 4]])
 >> Test stride: (2, 1), offset: 1, result: tensor([[2, 3],
        [4, 5]])
 >> Test stride: (1, 2), offset: 1, result: tensor([[2, 4],
        [3, 5]])