torch.contiguous()函数用法

torch.contiguous()函数用于确保张量在内存中是连续存储的,这对于需要连续内存空间的操作如view至关重要。当张量经过transpose等操作导致非连续存储时,必须先调用contiguous(),否则view()等操作会失败。例如,在尝试将非连续张量t2.view(12,1)时会报错,而转换为连续张量t3后则可以成功执行view操作。
摘要由CSDN通过智能技术生成

在看代码的时候发现了torch.contiguous()这个函数,那么它有什么用途呢?
1)背景知识
首先得知道一个tensor的shape和stride的区别。以二维矩阵为例,shape = [row, column]是指几行乘几列,stride = [stride1, stride2]分别是指到下一行需要跳过几个元素,到下一列需要跳过几个元素。由于python的底层是C实现的,遵从行优先的原则。
2)什么时候要用到contiguous
那么什么时候能够用得上contiguous操作呢?
在使用view操作的时候需要连续内存空间的tensor,如果当前的tensor经过了transpose等改变stride的操作,那么需要对这个tensor进行contiguous,然后才能执行view操作。

In [40]: t = torch.arange(1,13).reshape(3,4)

In [41]: t
Out[41]:
tensor([[ 1,  2,  3,  4],
        [ 5,  6,  7,  8],
        [ 9, 10, 11, 12]])
# t.stride() = (4,1)每隔4个元素到下一行,每隔1个元素到下一列
In [43]: t.stride()
Out[43]: (4, 1)
# 将行和列交换位置,transpose
In [44]: t2 = t.transpose(1,0)

In [45]: t2
Out[45]:
tensor([[ 1,  5,  9],
        [ 2,  6, 10],
        [ 3,  7, 11],
        [ 4,  8, 12]])
# t2.stride() = (1,4)每隔1个元素到下一行,每隔4个元素到下一列(在原来t的基础上)
In [46]: t2.stride()
Out[46]: (1, 4)

In [47]: t.is_contiguous()
Out[47]: True

In [48]: t2.is_contiguous()
Out[48]: False
# 无论是t还是t2还是t3,摊平展开之后的结果都是一样的,flatten
In [49]: t.flatten()
Out[49]: tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12])

In [50]: t2.flatten()
Out[50]: tensor([ 1,  5,  9,  2,  6, 10,  3,  7, 11,  4,  8, 12])

In [51]: t3 = t2.contiguous()

In [52]: t3.flatten()
Out[52]: tensor([ 1,  5,  9,  2,  6, 10,  3,  7, 11,  4,  8, 12])
# 非连续位置存储的元素在进行view操作的时候会报错,t2.view(12,1)报错
In [54]: t2.view(12,1)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-54-52b9424f2a69> in <module>
----> 1 t2.view(12,1)

RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

In [55]: t.view(12,1)
Out[55]:
tensor([[ 1],
        [ 2],
        [ 3],
        [ 4],
        [ 5],
        [ 6],
        [ 7],
        [ 8],
        [ 9],
        [10],
        [11],
        [12]])
# t2.contiguous()变为连续顺序存储,结果就不会报错,可以进行view操作
In [56]: t3.view(12,1)
Out[56]:
tensor([[ 1],
        [ 5],
        [ 9],
        [ 2],
        [ 6],
        [10],
        [ 3],
        [ 7],
        [11],
        [ 4],
        [ 8],
        [12]])
  • 6
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值