torch.reshape理解

torch官方文档说明

Returns a tensor with the same data and number of elements as input, but with the specified shape. When possible, the returned tensor will be a view of input. Otherwise, it will be a copy. Contiguous inputs and inputs with compatible strides can be reshaped without copying, but you should not depend on the copying vs. viewing behavior.
返回一个张量,数据和元素的数量和输入相同,但是具有指定的形状

See torch.Tensor.view() on when it is possible to return a view.

A single dimension may be -1, in which case it’s inferred from the remaining dimensions and the number of elements in input.
单个维度可能是-1,这个情况下,是根据剩余维度和输入元素推断。也就是说shape不一定明确指定,可能是函数推断的

Parameters:
input (Tensor) – the tensor to be reshaped
shape (tuple of python:int) – the new shape

参数

Example:

a = torch.arange(4.)
torch.reshape(a, (2, 2))
tensor([[ 0., 1.],
[ 2., 3.]])
b = torch.tensor([[0, 1], [2, 3]])
torch.reshape(b, (-1,))
tensor([ 0, 1, 2, 3])

自己的理解

不得不说官方文档真是一个字都不愿意多解释
找了另一个博主的文章 torch.reshape用法_江南汪的博客-CSDN博客

import torch
a=torch.tensor([[[1,2,3],[4,5,6]],
                [[7,8,9],[10,11,12]]])
print("a的shape:",a.shape)
b=torch.reshape(a,((4,3,1)))
print("b:",b)
print("b的shape:",b.shape)


a的shape: torch.Size([2, 2, 3])
b: tensor([[[ 1],
         [ 2],
         [ 3]],

        [[ 4],
         [ 5],
         [ 6]],

        [[ 7],
         [ 8],
         [ 9]],

        [[10],
         [11],
         [12]]])
b的shape: torch.Size([4, 3, 1])

这不是就一目了然了吗,223 shape的张量,reshape成431的形状,可以看到总共都是12个元素,数据和元素数量不变。具体是怎么组织得到的呢,可以看到无论怎么reshape这些元素的顺序是不变的,也就是1-12,所以其实就是递归一个序列,按照形状截断、重新拼起来成为一个张量
(说起来怎么这么简单,但它就是不给你好好说)

shape参数-1的问题

shape参数有时候不肯给你好好写,也就是会出现-1
比如

out = out.reshape(-1, seq_len * hidden_size) 

这是啥意思,还是看一些具体的例子就懂了

import torch
a=torch.tensor([[[1,2,3],[4,5,6]],
                [[7,8,9],[10,11,12]]])
b=torch.reshape(a,(-1,))
c=torch.reshape(a,(-1,1))
d=torch.reshape(a,((-1,1,1)))
e=torch.reshape(a,((-1,1,1,1)))
f=torch.reshape(a,((-1,3,4)))
a的shape: torch.Size([2, 2, 3])
b的shape: torch.Size([12])
c的shape: torch.Size([12, 1])
d的shape: torch.Size([12, 1, 1])
e的shape: torch.Size([12, 1, 1, 1])
f的shape: torch.Size([1, 3, 4])

也就是说shape的形状有时候不会给你写明,那么-1就是一个占位符,表示这个地方你先别看,从其他维度来推断:it’s inferred from the remaining dimensions and the number of elements in input.
比如说2 2 3的张量reshape成-1 1 1,那就先看后面这个1 1是明确的,一共12个元素,说明第一个维度上-1就是12,也就是说新张量的维度是12 1 1

回到我那个例子,我是给一个(batch_size, seq_len, hidden_size)的张量在做一个展平,于是使用了reshape,那么很显然,我这里的-1就代表batch_size,你说你直接写batch_size不就行了……当然这样写肯定有这样写的好处,那就是绝对不会出错

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值