pytorch中expand和repeat的区别

expand
import torch
a = torch.Tensor([[1,2,3]])
'''
tensor(
	[[1.,2.,3.]]
)
'''

aa = a.expand(4, 3) # 也可写为a.expand(4, -1) # 对于某一个维度上的值为1的维度,可以在该维度上进行tensor的复制,若大于1则不行
'''
tensor(
	[[1.,2.,3.],
	[1.,2.,3.],
	[1.,2.,3.],
	[1.,2.,3.]]
)
'''

a = torch.Tensor([[1,2,3], [4, 5, 6]])
'''
tensor(
	[[1.,2.,3.],
	 [4.,5.,6.]]
)
'''
a = a.expand(4,6) # 最高几个维度的参数必须和原始shape保持一致,否则报错
'''
RuntimeError: The expanded size of the tensor (6) must match the existing size (3) at non-singleton dimension 1.
'''

aa = a.expand(1,2,3) # 可以在tensor的低维增加更多维度
'''
tensor(
	[[[1.,2.,3.],
	 [4.,5.,6.]]]
)
'''
aaa = a.expand(2,2,3) # 可以在tensor的低维增加更多维度,同时在新增加的低维度上进行tensor的复制
'''
tensor(
	[[[1.,2.,3.],
	 [4.,5.,6.]],
	 [[1.,2.,3.],
	 [4.,5.,6.]]]
)
'''

aaa = a.expand(2,3,2) # 不可在更高维增加维度,否则报错
'''
RuntimeError: The expanded size of the tensor (2) must match the existing size (3) at non-singleton dimension 2.
'''

aaaa = a.expand(2, -1, -1) # 最高几个维度的参数可以用-1,表示和原始维度一致
'''
tensor(
	[[[1.,2.,3.],
	 [4.,5.,6.]],
	 [[1.,2.,3.],
	 [4.,5.,6.]]]
)
'''

print(aaaa.storage()) # 存储区的数据,说明expand后的a,aa,aaa,aaaa是共享storage的,只是tensor的头信息区设置了不同的数据展示格式,从而使得a,aa,aaa,aaaa呈现不同的tensor形式
'''
1.0
2.0
3.0
4.0
5.0
6.0
'''
repeat
import torch
a = torch.Tensor([[1,2,3]])
'''
tensor(
	[[1.,2.,3.]]
)
'''

aa = a.repeat(4, 3) # 维度不变,在各个维度上进行数据复制
'''
tensor(
	[[1.,2.,3.,1.,2.,3.,1.,2.,3.],
	 [1.,2.,3.,1.,2.,3.,1.,2.,3.],
	 [1.,2.,3.,1.,2.,3.,1.,2.,3.],
	 [1.,2.,3.,1.,2.,3.,1.,2.,3.]]
)
'''

a = torch.Tensor([[1,2,3], [4, 5, 6]])
'''
tensor(
	[[1.,2.,3.],
	 [4.,5.,6.]]
)
'''
aa = a.repeat(4,6) # 维度不变,在各个维度上进行数据复制
'''
tensor(
	[[1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.],
	 [4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.],
	 [1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.],
	 [4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.],
	 [1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.],
	 [4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.],
	 [1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.],
	 [4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.]]
)
'''

aaa = a.repeat(1,2,3) # 可以在tensor的低维增加更多维度,并在各维度上复制数据
'''
tensor(
	[[[1.,2.,3.,1.,2.,3.,1.,2.,3.],
	  [4.,5.,6.,4.,5.,6.,4.,5.,6.],
	  [1.,2.,3.,1.,2.,3.,1.,2.,3.],
	  [4.,5.,6.,4.,5.,6.,4.,5.,6.]]]
)
'''
aaaa = a.repeat(2,3,1) # 可以在tensor的高维增加更多维度,并在各维度上复制数据
'''
tensor(
	[[[1.,2.,3.],
	  [4.,5.,6.],
	  [1.,2.,3.],
	  [4.,5.,6.],
	  [1.,2.,3.],
	  [4.,5.,6.]],
	 [[1.,2.,3.],
	  [4.,5.,6.],
	  [1.,2.,3.],
	  [4.,5.,6.],
	  [1.,2.,3.],
	  [4.,5.,6.]]]
)
'''

aaaaa = a.repeat(2, 3, -1) 
'''
RuntimeError: Trying to create tensor with negative dimension -3: [2,6,-3]
'''

print(aaaa.storage()) # 存储区的数据,说明repeat后的a,aa,aaa,aaaa是有各自独立的storage的
'''
1.0
2.0
3.0
4.0
5.0
6.0
1.0
2.0
3.0
4.0
5.0
6.0
1.0
2.0
3.0
4.0
5.0
6.0
1.0
2.0
3.0
4.0
5.0
6.0
1.0
2.0
3.0
4.0
5.0
6.0
1.0
2.0
3.0
4.0
5.0
6.0
'''
总结

相同:
(1)都可以扩展维度,或在某个维度上进行tensor的复制

区别:
(1)参数意义不同,repeat的参数表示沿某维度的数据复制倍数,可为大于0的任何整数值;expand的参数表示tensor对应的维度上的值,且只有增加新的低维度时表示沿该低维度的数据复制倍数,其他参数必须和原始tensor保持一致
(2)返回的结果的存储区不同,repeat返回的tensor会重新拥有一个独立存储区,而expand返回的tensor则与原始tensor共享存储区

  • 2
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值