PyTorch TORCH.AS_STRIDED

torch.as_strided(input, size, stride, storage_offset=0)→ Tensor

1. 官方文档

Create a view of an existing torch.Tensor input with specified size, stride and storage_offset.

parameters

  • input (Tensor) – the input tensor.
  • size (tuple or ints) – the shape of the output tensor
  • stride (tuple or ints) – the stride of the output tensor
  • storage_offset (int, optional) – the offset in the underlying storage of the output tensor

WARNING

More than one element of a created tensor may refer to a single memory location. As a result, in-place operations (especially ones that are vectorized) may result in incorrect behavior. If you need to write to the tensors, please clone them first.

Many PyTorch functions, which return a view of a tensor, are internally implemented with this function. Those functions, like torch.Tensor.expand(), are easier to read and are therefore more advisable to use.

e.g.

>>> x = torch.randn(3, 3)
>>> x
tensor([[ 0.9039,  0.6291,  1.0795],
        [ 0.1586,  2.1939, -0.4900],
        [-0.1909, -0.7503,  1.9355]])
>>> t = torch.as_strided(x, (2, 2), (1, 2))
>>> t
tensor([[0.9039, 1.0795],
        [0.6291, 0.1586]])
>>> t = torch.as_strided(x, (2, 2), (1, 2), 1)
tensor([[0.6291, 0.1586],
        [1.0795, 2.1939]])

2. 个人理解

基于指定的参数(sizestrideoffset)在原内存进行扩展生成的Tensor。因为按内存顺序读元素,会造成元素共享,因此如果对此过程生成的Tensor直接in_place修改可能会出现意想不到的错误,建议使用深拷贝。

关于其原理,不要对拘泥于原数据的行列关系,按照对应的stride从内存中读数即可。如下图所示,从4x4的源数据中按要求抽取对应的元素。

注意,这里起点为1,W,H方向上都是基于该起点进行。如果加了offset,则是将起点进行offset,见后续代码测试结果。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-8eiIiw8s-1633252597768)
在这里插入图片描述

测试代码:

import torch

a = torch.tensor([[1, 2, 3, 4],
                  [5, 6, 7, 8],
                  [9, 10, 11, 12],
                  [13, 14, 15, 16]])

b = torch.as_strided(a, (2, 2), (0, 0))
c = torch.as_strided(a, (2, 2), (0, 1))
d = torch.as_strided(a, (2, 2), (1, 0))
e = torch.as_strided(a, (2, 2), (2, 0))
f = torch.as_strided(a, (2, 2), (0, 2))
g = torch.as_strided(a, (2, 2), (1, 1))
h = torch.as_strided(a, (2, 2), (2, 1))
i = torch.as_strided(a, (2, 2), (1, 2))
j = torch.as_strided(a, (2, 2), (2, 1), 1)
k = torch.as_strided(a, (2, 2), (1, 2), 1)

print(' >> input data: {}'.format(a))
print(' >> Test stride: (0, 0), result: {}'.format(b))
print(' >> Test stride: (0, 1), result: {}'.format(c))
print(' >> Test stride: (1, 0), result: {}'.format(d))
print(' >> Test stride: (2, 0), result: {}'.format(e))
print(' >> Test stride: (0, 2), result: {}'.format(f))
print(' >> Test stride: (1, 1), result: {}'.format(g))
print(' >> Test stride: (2, 1), result: {}'.format(h))
print(' >> Test stride: (1, 2), result: {}'.format(i))
print(' >> Test stride: (2, 1), offset: 1, result: {}'.format(h))
print(' >> Test stride: (1, 2), offset: 1, result: {}'.format(i))

测试结果:

 >> input data: tensor([[ 1,  2,  3,  4],
                        [ 5,  6,  7,  8],
                        [ 9, 10, 11, 12],
                        [13, 14, 15, 16]])
 >> Test stride: (0, 0), result: tensor([[1, 1],
                                         [1, 1]])
 >> Test stride: (0, 1), result: tensor([[1, 2],
                                         [1, 2]])
 >> Test stride: (1, 0), result: tensor([[1, 1],
                                         [2, 2]])
 >> Test stride: (2, 0), result: tensor([[1, 1],
                                         [3, 3]])
 >> Test stride: (0, 2), result: tensor([[1, 3],
                                         [1, 3]])
 >> Test stride: (1, 1), result: tensor([[1, 2],
                                         [2, 3]])
 >> Test stride: (2, 1), result: tensor([[1, 2],
                                         [3, 4]])
 >> Test stride: (1, 2), result: tensor([[1, 3],
                                         [2, 4]])
 >> Test stride: (2, 1), offset: 1, result: tensor([[2, 3],
                                                    [4, 5]])
 >> Test stride: (1, 2), offset: 1, result: tensor([[2, 4],
                                                    [3, 5]])

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值