pytorch基础学习(五) 数据处理(二)

本篇主要介绍pytorch中tensor的基本操作如:对tensor进行flatten操作, 对tensor进行拼接,tensor的broadcast(广播)机制.

1. 对tensor进行flatten操作

flatten顾名思义就是平展,将一个tensor由高维(rank)变为1维,元素的数量保持不变,这在深度学习中很常用,当输入是一个图像时候,图像的维度(rank)是3,当网络的输入层只能输入一维的数据时(如全连接层),flatten操作就显得非常有用了.

下面我们说两种flatten的实现方式:使用上篇的squeeze,reshape函数间接实现;使用tensor自带的flatten函数实现.

1. 使用squeeze,reshape函数间接实现

实现代码如下,我们可以写一个flatten函数,该函数的功能是输入一个tensor,将它维度变为1输出.

def flatten(t):
    t = t.reshape(1, -1)
    t = t.squeeze()
    return t

函数的第一行先利用reshape函数将tensor变为2维度,具体可以参考上篇. 此时可以注意到,第一个axis的长度为1,因此第二行使用squeeze函数将长度为1的axis去掉,这时候tensor的维度自然而然就变成1啦!

举个栗子吧:

t = torch.tensor([
    [1, 1, 1, 1],
    [2, 2, 2, 2],
    [3, 3, 3, 3]
], dtype=torch.float32)
print(t.shape)
print(flatten(t))
print(flatten(t).shape)
output:
torch.Size([3, 4])
tensor([1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 3.])
torch.Size([12])

可以看到,使用我们自己写的flatten函数,将tensor变成了1维,但这很不方便欸 ,有自带的干嘛不用呢,用它!

2. 直接使用tensor的flatten函数

tensor是有flatten函数的,话不多说,直接举个栗子:

t1 = torch.ones(4, 4)
t2 = torch.ones(4, 4) * 2
t3 = torch.ones(4, 4) * 3
t = torch.stack((t1, t2, t3)) 
t = t.reshape(3, 1, 4, 4) 
print(t.flatten(start_dim=1))
print(t.flatten(start_dim=1).shape)
print(t.reshape(t.shape[0], -1))
print(t.flatten().shape)
output:
tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],
        [3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3.]])
torch.Size([3, 16])
tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],
        [3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3.]])
torch.Size([48])

可以看到flatten有个参数start_dim,表示从start_dim到最后一个维度都做平展操作,因此t.flatten(start_dim=1)后,就只有第0维保持不变,其它维度做了flatten,变成了一维,从而shape变为torch.Size([3, 16]),这在深度学习中,向全连接层输入时经常会使用到.

当flatten没有参数时,默认将整个tensor进行flatten操作,即start_dim=1,因此经过t.flatten(),其shape变为torch.Size([48]).

2. 对tensor进行拼接

将2个甚至更多的tensor进行拼接是经常要使用到的功能,pytorch中对tensor拼接常用的函数为cat,举个栗子吧:

t1 = torch.tensor([
    [1, 2],
    [3, 4]
])
t2 = torch.tensor([
    [5, 6],
    [7, 8]
])
print(torch.cat((t1, t2), dim=0))
print(torch.cat((t1, t2), dim=0).shape)
print(torch.cat((t1, t2), dim=1))
print(torch.cat((t1, t2), dim=1).shape)
output:
torch.Size([12])
tensor([[1, 2],
        [3, 4],
        [5, 6],
        [7, 8]])
torch.Size([4, 2])
tensor([[1, 2, 5, 6],
        [3, 4, 7, 8]])
torch.Size([2, 4])

可以看到,cat中的dim参数决定了拼接的维度.
当dim=0时,在第0维(第1个axis)进行拼接 ,其余维度长度不变,因此shape变为torch.Size([4, 2]).
当dim=1时,在第1维(第2个axis)进行拼接 ,其余维度长度不变,因此shape变为torch.Size([2, 4]).
为了达成拼接的目的,很容易我们可以看出,对于要拼接的tensor(可以不止2个),除了需要拼接的维度,其余维度的长度必须保持相同,否则会引起错误.

3. tensor的broadcast(广播)机制

在数学运算中,两个形状不同的矩阵进行加减运算显然是不行的,但对于tensor,在某些形况下是完全可以的,这得益于pytorch中tensor的broadcast机制.

举个栗子:

t1 = torch.tensor([
    [1, 2],
    [3, 4]
])
t2 = torch.tensor([
    [9, 8],
    [7, 6]
])
print(t1 + t2)
print(t1 + 2)
print(t1 + torch.tensor(
    np.broadcast_to(2, t1.shape),
    dtype=torch.int32
))  # equal to last line
output:
tensor([[10, 10],
        [10, 10]])
tensor([[3, 4],
        [5, 6]])
tensor([[3, 4],
        [5, 6]])

从倒数第二个print那里,我们惊讶的发现**print(t1 + 2)**竟然也可以运算!这得益于broadcast机制,等同于最后一个print中的语句,pytorch将2自动扩充成了与t1形状相同的tensor,这样当然就可以运算啦.

那么有个大胆的想法,除了常数,不同形状的tensor是否也可以这样操作,举个栗子试一试:

t1 = torch.tensor([
    [1, 2],
    [3, 4]
])
t3 = torch.tensor([2, 4])
print(t1 + t3)
print(t1 + torch.tensor(
    np.broadcast_to(t3.numpy(), t1.shape),
    dtype=torch.int32
))  # equal to last line
output:
tensor([[3, 6],
        [5, 8]])
tensor([[3, 6],
        [5, 8]])

哈哈,果然是可以的,pytorch将t3自动扩充成了与t1形状相同的tensor(将t3又复制了一行)再与之运算. 这个特性很有意思,很方便我们书写代码.

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值