torch.cat()函数,torch.stack()函数

相同点

都是沿着tensor指定维度进行拼接

不同点

直接通过实例进行展示吧…例如:现在分别输入三个tensor,a,b,c,如下图所示值

# input
import torch

a = torch.IntTensor([[1,2,3],[4,5,6],[7,8,9]])  # torch.Size([3, 3])
b = torch.IntTensor([[11,22,33],[44,55,66],[77,88,99]])  # torch.Size([3, 3])
c = torch.IntTensor([[111,222,333],[444,555,666],[777,888,999]])  # torch.Size([3, 3])
'''
a: tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]], dtype=torch.int32)
b: tensor([[11, 22, 33],
        [44, 55, 66],
        [77, 88, 99]], dtype=torch.int32)
c: tensor([[111, 222, 333],
        [444, 555, 666],
        [777, 888, 999]], dtype=torch.int32)
'''
  1. torch.cat():只能在已有维度上进行拼接!!——串联
# cat
cat_0 = torch.cat((a,b,c),dim=0)  # torch.Size([9, 3])
cat_1 = torch.cat((a,b,c),dim=1)  # torch.Size([3, 9])
cat_2 = torch.cat((a,b,c),dim=2)  # error
# cat_output
'''
cat_0: 
		tensor([[  1,   2,   3],
		        [  4,   5,   6],
		        [  7,   8,   9],
		        [ 11,  22,  33],
		        [ 44,  55,  66],
		        [ 77,  88,  99],
		        [111, 222, 333],
		        [444, 555, 666],
		        [777, 888, 999]], dtype=torch.int32)
cat_1: 
		tensor([[  1,   2,   3,  11,  22,  33, 111, 222, 333],
		        [  4,   5,   6,  44,  55,  66, 444, 555, 666],
		        [  7,   8,   9,  77,  88,  99, 777, 888, 999]], dtype=torch.int32)
cat_2: 
     Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)

'''
  1. torch.stack():增加一维,同时在指定进行拼接——并联
# stack
stack_0 = torch.stack((a,b,c),dim=0) # torch.Size([3, 3, 3])
stack_1 = torch.stack((a,b,c),dim=1)  # torch.Size([3, 3, 3])
stack_2 = torch.stack((a,b,c),dim=2)  # torch.Size([3, 3, 3])

# stack_output
'''
stack_0: 
		tensor([[[  1,   2,   3],
         [  4,   5,   6],
         [  7,   8,   9]],

        [[ 11,  22,  33],
         [ 44,  55,  66],
         [ 77,  88,  99]],

        [[111, 222, 333],
         [444, 555, 666],
         [777, 888, 999]]], dtype=torch.int32)
stack_1:
		tensor([[[  1,   2,   3],
         [ 11,  22,  33],
         [111, 222, 333]],

        [[  4,   5,   6],
         [ 44,  55,  66],
         [444, 555, 666]],

        [[  7,   8,   9],
         [ 77,  88,  99],
         [777, 888, 999]]], dtype=torch.int32)
stack_2:
		tensor([[[  1,  11, 111],
	         [  2,  22, 222],
	         [  3,  33, 333]],
	
	        [[  4,  44, 444],
	         [  5,  55, 555],
	         [  6,  66, 666]],
	
	        [[  7,  77, 777],
	         [  8,  88, 888],
	         [  9,  99, 999]]], dtype=torch.int32)
	
'''

总结

  • torch.cat()对tensors沿指定维度拼接,返回的tensor的维数不会变——串联
  • torch.stack()同样也是对tensors沿指定维度拼接,但返回的tensor会多一维。如:对两个“1X2“维的tensor在dim=0上stack,则会变为"2X1X2"的tensor;在dim=1上stack,则会变为"1X2X2"的tensor;在dim=2上stack,也会变为"1X2X2"的tensor,虽然与dim=1上stack最终产生的tensor维度相同,但是叠加的维度是不一样的,结果也是不一样的。——并联
  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值