Tensor的操作:拼接,切分,索引,变换

拼接

torch.cat()

# torch.cat joins tensors along an EXISTING dimension — the result has
# the same number of dims as the inputs.
tensor1 = torch.ones(size=(2, 3))
tensor_cat1 = torch.cat([tensor1, tensor1], dim=0)            # rows stacked  -> (4, 3)
tensor_cat2 = torch.cat([tensor1, tensor1], dim=1)            # columns joined -> (2, 6)
tensor_cat3 = torch.cat([tensor1, tensor1, tensor1], dim=1)   # three copies   -> (2, 9)
for result in (tensor_cat1, tensor_cat2, tensor_cat3):
    print(result)
    print(result.shape)

torch.stack()

# torch.stack joins tensors along a NEW dimension — the result has one
# more dim than each input.
tensor2 = torch.ones(size=(4, 3))
tensor2_stack1 = torch.stack([tensor2, tensor2], dim=2)            # new trailing dim -> (4, 3, 2)
tensor2_stack2 = torch.stack([tensor2, tensor2], dim=0)            # new leading dim  -> (2, 4, 3)
tensor2_stack3 = torch.stack([tensor2, tensor2, tensor2], dim=0)   # three inputs     -> (3, 4, 3)
print(tensor2_stack1)
for stacked in (tensor2_stack1, tensor2_stack2, tensor2_stack3):
    print(stacked.shape)

cat:不会扩张tensor的维度
stack:会扩张tensor的维度

切分

torch.chunk()

  • 功能:将tensor按维度dim进行平均切分
  • 返回值:tensor列表
  • 参数chunks:要切分的份数
# torch.chunk splits a tensor into `chunks` roughly-equal pieces along
# `dim` and returns them as a tuple.  Splitting width 5 into 2 chunks
# along dim=1 gives two tensors: shape (2, 3) then shape (2, 2) —
# when the size is not evenly divisible, the last piece is smaller.
tensor3 = torch.ones(size=(2, 5))
tensor3_chunk1 = torch.chunk(tensor3, chunks=2, dim=1)
for pos, piece in enumerate(tensor3_chunk1):
    print(pos, piece, piece.shape)

torch.split()

  • 参数split_size_or_sections: 可以为int,也可以为list
 tensor4 = torch.ones(size=(2,5))
 tensor4_split1 = torch.split(tensor=tensor4,split_size_or_sections=2,dim=1)
 for idx,tensor in enumerate(tensor4_split1):
     print(idx,tensor,tensor.shape) #[(2,2),(2,2),(2,1)]
 tensor4_split2 = torch.split(tensor=tensor4,split_size_or_sections=[2,1,2],dim=1)
 for idx,tensor in enumerate(tensor4_split2):
     print(idx,tensor,tensor.shape) # [(2,2),(2,1),(2,2)]

索引

torch.index_select()

# torch.index_select picks whole slices along `dim` by integer index.
# With dim=0 the selection is row-wise: rows 0 and 2 of the original
# tensor are returned.
tensor5 = torch.randint(low=0, high=9, size=(3, 3))
idx = torch.tensor([0, 2], dtype=torch.long)
tensor5_index_select1 = torch.index_select(tensor5, dim=0, index=idx)
print(tensor5)
print(tensor5_index_select1)

在这里插入图片描述

torch.masked_select()

  • 返回值:一维tensor
    tensor6 = torch.randint(low=0,high=9,size=(3,3))
    mask = tensor6.ge(5) # ge: 大于等于 gt: 大于 le lt
    tensor6_masked_select = torch.masked_select(input=tensor6,mask=mask)
    print(tensor6)
    print(mask)
    print(tensor6_masked_select)

在这里插入图片描述

变换

torch.reshape()

当tensor在内存中是连续时,新tensor与input共享内存(一个中的元素的值被改变,另一个中的这个元素的值也会被改变)

# torch.reshape: when the input is contiguous in memory, the result
# shares storage with it — mutating an element in one is visible in
# the other.  A -1 in the target shape is inferred from the rest.
tensor7 = torch.randperm(n=8)  # random permutation of 0..7
tensor7_reshape1 = torch.reshape(tensor7, (2, 4))
tensor7_reshape2 = torch.reshape(tensor7, (-1, 2, 2))  # -1 inferred -> (2, 2, 2)
print(tensor7)
print(tensor7_reshape1)
print(tensor7_reshape2)
print(tensor7_reshape2.shape)

torch.transpose()

# torch.transpose swaps two dimensions: (2, 3, 4) -> (2, 4, 3).
tensor8 = torch.rand(size=(2, 3, 4))
tensor8_transpose1 = torch.transpose(tensor8, 1, 2)
print(tensor8.shape)
print(tensor8_transpose1.shape)

torch.t()

针对二维tensor

 tensor9 = torch.rand(size=(2,3))
 tensor9_t1 = torch.t(input=tensor9)
 print(tensor9.shape)
 print(tensor9_t1.shape)

torch.squeeze()

如果不指定dim,那么input中所有为1的dim都会被移除
如果指定了dim,并且指定的dim为1,那么移除
如果指定了dim,但是指定的dim不为1,那么不移除

# torch.squeeze removes size-1 dimensions:
#   no dim given            -> every size-1 dim is dropped
#   dim given, size is 1    -> only that dim is dropped
#   dim given, size is != 1 -> the tensor is returned unchanged
tensor10 = torch.rand(size=(1, 2, 3, 1))
tensor10_squeeze1 = torch.squeeze(tensor10)          # -> (2, 3)
tensor10_squeeze2 = torch.squeeze(tensor10, dim=0)   # -> (2, 3, 1)
tensor10_squeeze3 = torch.squeeze(tensor10, dim=1)   # unchanged: (1, 2, 3, 1)
print(tensor10_squeeze1.shape, tensor10_squeeze2.shape, tensor10_squeeze3.shape)

torch.unsqueeze()

 tensor11 = torch.rand(size=(2,3))
 tensor11_unsqueeze1 = torch.unsqueeze(input=tensor11,dim=0)
 tensor11_unsqueeze2 = torch.unsqueeze(input=tensor11,dim=1)
 tensor11_unsqueeze3 = torch.unsqueeze(input=tensor11,dim=2)
 print(tensor11_unsqueeze1.shape,tensor11_unsqueeze2.shape,tensor11_unsqueeze3.shape)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值