pytorch学习笔记二:张量操作

一、张量拼接与切分

1.1 torch.cat()

在这里插入图片描述
功能:将张量按维度dim进行拼接

  • tensor:张量序列
  • dim:要拼接的维度

举例:

t = torch.ones(2,3)
t_0 = torch.cat([t,t],dim=0)
t_1 = torch.cat([t,t,t],dim=1)
print('t_0:{} shape:{}\nt_1:{} shape:{}'.format(t_0,t_0.shape,t_1,t_1.shape))

结果:

t_0:tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]]) shape:torch.Size([4, 3])
t_1:tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1.]]) shape:torch.Size([2, 9])
1.2 torch.stack()

在这里插入图片描述
功能:在新创建的维度dim进行拼接

t = torch.ones((2,3))
# t只有0,1两个维度,会在2上创建一个新的维度
t_stack = torch.stack([t,t],dim = 2)
print('\nt_stack:{} shape:{}'.format(t_stack,t_stack.shape))
t_stack:tensor([[[1., 1.],
         [1., 1.],
         [1., 1.]],

        [[1., 1.],
         [1., 1.],
         [1., 1.]]]) shape:torch.Size([2, 3, 2])

如果在dim=0上拼接:

t_stack = torch.stack([t,t],dim = 0)
t_stack:tensor([[[1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [1., 1., 1.]]]) shape:torch.Size([2, 2, 3])

总结:cat不会拓展张量维度,stack会拓展张量维度

1.3 torch.chunk()

在这里插入图片描述
功能:将张量按维度dim进行平均切分
返回值:张量列表
注意:若不能整除,最后一份张量小于其他张量

  • input:要切分的张量
  • chunks:要切分的份数
  • dim:要切分的维度

举例:

a = torch.ones((2,5))
list_of_tensors = torch.chunk(a,dim=1,chunks=2)
for idx,t in enumerate(list_of_tensors):
    print('第{}个张量:{},shape is {}'.format(idx+1,t,t.shape))

结果:

1个张量:tensor([[1., 1., 1.],
        [1., 1., 1.]]),shape is torch.Size([2, 3])2个张量:tensor([[1., 1.],
        [1., 1.]]),shape is torch.Size([2, 2])

chunks = 3,计算方式:7/3=2.x,向上取整得3,两组为3,剩下的为一组

a = torch.ones((2,7))
list_of_tensors = torch.chunk(a,dim=1,chunks=3)
1个张量:tensor([[1., 1., 1.],
        [1., 1., 1.]]),shape is torch.Size([2, 3])2个张量:tensor([[1., 1., 1.],
        [1., 1., 1.]]),shape is torch.Size([2, 3])3个张量:tensor([[1.],
        [1.]]),shape is torch.Size([2, 1])
1.4 torch.split()

在这里插入图片描述
功能:将张量按维度dim进行切分
返回值:张量列表

  • tensor: 要切分的张量
  • split_size_or_sections : 为int时,表示每一份的长度;为list时,按list元素切分
  • dim : 要切分的维度

举例:
split_size_or_sections为list时

t = torch.ones((2,5))
list_of_tensors = torch.split(t, [2,1,2], dim=1)#list之和要等于dim维度的长度
# list_of_tensors = torch.split(t,2,dim=1)
for idx, t in enumerate(t,2,dim=1):
    print('第{}个张量:{},shape is {}'.format(idx+1,t.shape))

结果:

1个张量:tensor([[1., 1.],
        [1., 1.]]),shape is torch.Size([2, 2])2个张量:tensor([[1.],
        [1.]]),shape is torch.Size([2, 1])3个张量:tensor([[1., 1.],
        [1., 1.]]),shape is torch.Size([2, 2])

split_size_or_sections为int时:

list_of_tensors = torch.split(t,2,dim=1)

结果:

1个张量:tensor([[1., 1.],
        [1., 1.]]),shape is torch.Size([2, 2])2个张量:tensor([[1., 1.],
        [1., 1.]]),shape is torch.Size([2, 2])3个张量:tensor([[1.],
        [1.]]),shape is torch.Size([2, 1])

二、张量索引

2.1 torch.split()

在这里插入图片描述
功能:在维度dim上,按index索引数据
返回值:依index索引数据拼接的张量

  • input: 要索引的张量
  • dim: 要索引的维度
  • index : 要索引数据的序号

举例:

t = torch.randint(0, 9, size=(3,3))
idx = torch.tensor([0,2], dtype=torch.long)#必须是long
t_select = torch.index_select(t,dim=0,index=idx)
print('t:\n{}\nt_select:\n{}'.format(t,t_select))

结果:

t:
tensor([[1, 8, 4],
        [6, 3, 4],
        [3, 3, 5]])
t_select:
tensor([[1, 8, 4],
        [3, 3, 5]])
2.2 torch.masked_select()

在这里插入图片描述
功能:按mask中的True进行索引
返回值:一维张量

  • input: 要索引的张量
  • mask: 与input同形状的布尔类型张量

举例:

t = torch.randint(0,9,size=(3,3))
#t>=5的元素返回true
mask = t.ge(5) #ge is mean greater than or equal/
t_select = torch.masked_select(t,mask)
print('t:\n{}\nmask:\n{}\nt_select:\n{}'.format(t,mask,t_select))

结果:

tensor([[5, 6, 1],
        [6, 3, 2],
        [0, 2, 8]])
mask:
tensor([[ True,  True, False],
        [ True, False, False],
        [False, False,  True]])
t_select:
tensor([5, 6, 6, 8])

三、张量变换

3.1 torch.reshape()

在这里插入图片描述
功能:变换张量形状
注意事项:当张量在内存中是连续时,新张量与input共享数据内存

  • input: 要变换的张量
  • shape: 新张量的形状

举例:

t = torch.randperm(8)#0 to n - 1的随机排列
t_reshape = torch.reshape(t,(2,4))
print('t:{}\nt_reshape:\n{}'.format(t,t_reshape))

t[0] = 1024
print('t:{}\nt_reshape:\n{}'.format(t,t_reshape))
#内存地址相同,为什么要.data???
print('t.data内存地址:{}'.format(id(t.data)))
print('t_reshape.data内存地址:{}'.format(id(t_reshape.data)))

结果:

t:tensor([4, 1, 7, 6, 5, 0, 2, 3])
t_reshape:
tensor([[4, 1, 7, 6],
        [5, 0, 2, 3]])
t:tensor([1024,    1,    7,    6,    5,    0,    2,    3])
t_reshape:
tensor([[1024,    1,    7,    6],
        [   5,    0,    2,    3]])
t.data内存地址:109526516584
t_reshape.data内存地址:109526516584
3.2 torch.transpose()

在这里插入图片描述
功能:交换张量的两个维度

  • input: 要变换的张量
  • dim0: 要交换的维度
  • dim1: 要交换的维度

举例:

t = torch.rand((2,3,4))
t_transpose = torch.transpose(t,dim0=1,dim1=2)#交换第一个和第二个维度
print('t shape:{}\nt_transpose shape:{}'.format(t.shape,t_transpose.shape))

结果:

t shape:torch.Size([2, 3, 4])
t_transpose shape:torch.Size([2, 4, 3])
3.3 torch.t()

功能: 2维张量转置,对矩阵而言,等价于
torch.transpose(input, 0, 1)

3.4 torch.transpose()

在这里插入图片描述
功能: 压缩长度为1的维度(轴)

  • dim: 若为None,移除所有长度为1的轴;若指定维度,当且仅当该轴长度为1时,可以被移除;

举例:

t = torch.rand((1,2,3,1))
t_sq = torch.squeeze(t)
t_0 = torch.squeeze(t,dim=0)
t_1 = torch.squeeze(t,dim=1)
print(t.shape)
print(t_sq.shape)
print(t_0)
print(t_1.shape)

结果:

torch.Size([1, 2, 3, 1])
torch.Size([2, 3])
tensor([[[0.8293],
         [0.0632],
         [0.1675]],

        [[0.9306],
         [0.4120],
         [0.4964]]])
torch.Size([1, 2, 3, 1])
3.5 torch.unsqueeze()

在这里插入图片描述
功能:依据dim扩展维度

  • dim: 扩展的维度

四、张量的数学运算

torch.add()
在这里插入图片描述
功能:逐元素计算 input+alpha×other
• input: 第一个张量
• alpha: 乘项因子
• other: 第二个张量

举例

t_0 = torch.randn((3,3))#标准正态分布
t_1 = torch.ones_like(t_0)
t_add = torch.add(t_0,10,t_1)#10*t_1+t_0

print('t_0:\n{}\nt_1:\n{}\nt_add_10:\n{}'.format(t_0,t_1,t_add))

结果:

t_0:
tensor([[ 1.1776, -1.2565,  0.5444],
        [-0.5204, -2.4183, -0.5783],
        [-0.9523,  0.7578,  0.5852]])
t_1:
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
t_add_10:
tensor([[11.1776,  8.7435, 10.5444],
        [ 9.4796,  7.5817,  9.4217],
        [ 9.0477, 10.7578, 10.5852]])
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值