pytorch中expand和repeat的区别

本文深入探讨了PyTorch中`expand`和`repeat`两个操作的区别。`expand`用于在不改变数据的情况下扩展张量的维度,而`repeat`则会复制张量的数据来创建新张量。`expand`在增加低维度时可以复制数据,但高维度参数需与原始张量匹配;`repeat`则可在任意维度复制数据,返回的张量拥有独立的存储区。这两个操作在处理张量维度扩展和复制时具有不同的适用场景。
摘要由CSDN通过智能技术生成
expand
import torch
a = torch.Tensor([[1,2,3]])
'''
tensor(
	[[1.,2.,3.]]
)
'''

aa = a.expand(4, 3) # 也可写为a.expand(4, -1) # 对于某一个维度上的值为1的维度,可以在该维度上进行tensor的复制,若大于1则不行
'''
tensor(
	[[1.,2.,3.],
	[1.,2.,3.],
	[1.,2.,3.],
	[1.,2.,3.]]
)
'''

a = torch.Tensor([[1,2,3], [4, 5, 6]])
'''
tensor(
	[[1.,2.,3.],
	 [4.,5.,6.]]
)
'''
a = a.expand(4,6) # 最高几个维度的参数必须和原始shape保持一致,否则报错
'''
RuntimeError: The expanded size of the tensor (6) must match the existing size (3) at non-singleton dimension 1.
'''

aa = a.expand(1,2,3) # 可以在tensor的低维增加更多维度
'''
tensor(
	[[[1.,2.,3.],
	 [4.,5.,6.]]]
)
'''
aaa = a.expand(2,2,3) # 可以在tensor的低维增加更多维度,同时在新增加的低维度上进行tensor的复制
'''
tensor(
	[[[1.,2.,3.],
	 [4.,5.,6.]],
	 [[1.,2.,3.],
	 [4.,5.,6.]]]
)
'''

aaa = a.expand(2,3,2) # 不可在更高维增加维度,否则报错
'''
RuntimeError: The expanded size of the tensor (2) must match the existing size (3) at non-singleton dimension 2.
'''

aaaa = a.expand(2, -1, -1) # 最高几个维度的参数可以用-1,表示和原始维度一致
'''
tensor(
	[[[1.,2.,3.],
	 [4.,5.,6.]],
	 [[1.,2.,3.],
	 [4.,5.,6.]]]
)
'''

print(aaaa.storage()) # 存储区的数据,说明expand后的a,aa,aaa,aaaa是共享storage的,只是tensor的头信息区设置了不同的数据展示格式,从而使得a,aa,aaa,aaaa呈现不同的tensor形式
'''
1.0
2.0
3.0
4.0
5.0
6.0
'''
repeat
import torch
a = torch.Tensor([[1,2,3]])
'''
tensor(
	[[1.,2.,3.]]
)
'''

aa = a.repeat(4, 3) # 维度不变,在各个维度上进行数据复制
'''
tensor(
	[[1.,2.,3.,1.,2.,3.,1.,2.,3.],
	 [1.,2.,3.,1.,2.,3.,1.,2.,3.],
	 [1.,2.,3.,1.,2.,3.,1.,2.,3.],
	 [1.,2.,3.,1.,2.,3.,1.,2.,3.]]
)
'''

a = torch.Tensor([[1,2,3], [4, 5, 6]])
'''
tensor(
	[[1.,2.,3.],
	 [4.,5.,6.]]
)
'''
aa = a.repeat(4,6) # 维度不变,在各个维度上进行数据复制
'''
tensor(
	[[1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.],
	 [4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.],
	 [1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.],
	 [4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.],
	 [1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.],
	 [4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.],
	 [1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.,1.,2.,3.],
	 [4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.,4.,5.,6.]]
)
'''

aaa = a.repeat(1,2,3) # 可以在tensor的低维增加更多维度,并在各维度上复制数据
'''
tensor(
	[[[1.,2.,3.,1.,2.,3.,1.,2.,3.],
	  [4.,5.,6.,4.,5.,6.,4.,5.,6.],
	  [1.,2.,3.,1.,2.,3.,1.,2.,3.],
	  [4.,5.,6.,4.,5.,6.,4.,5.,6.]]]
)
'''
aaaa = a.repeat(2,3,1) # 可以在tensor的高维增加更多维度,并在各维度上复制数据
'''
tensor(
	[[[1.,2.,3.],
	  [4.,5.,6.],
	  [1.,2.,3.],
	  [4.,5.,6.],
	  [1.,2.,3.],
	  [4.,5.,6.]],
	 [[1.,2.,3.],
	  [4.,5.,6.],
	  [1.,2.,3.],
	  [4.,5.,6.],
	  [1.,2.,3.],
	  [4.,5.,6.]]]
)
'''

aaaaa = a.repeat(2, 3, -1) 
'''
RuntimeError: Trying to create tensor with negative dimension -3: [2,6,-3]
'''

print(aaaa.storage()) # 存储区的数据,说明repeat后的a,aa,aaa,aaaa是有各自独立的storage的
'''
1.0
2.0
3.0
4.0
5.0
6.0
1.0
2.0
3.0
4.0
5.0
6.0
1.0
2.0
3.0
4.0
5.0
6.0
1.0
2.0
3.0
4.0
5.0
6.0
1.0
2.0
3.0
4.0
5.0
6.0
1.0
2.0
3.0
4.0
5.0
6.0
'''
总结

相同:
(1)都可以扩展维度,或在某个维度上进行tensor的复制

区别:
(1)参数意义不同,repeat的参数表示沿某维度的数据复制倍数,可为大于0的任何整数值;expand的参数表示tensor对应的维度上的值,且只有增加新的低维度时表示沿该低维度的数据复制倍数,其他参数必须和原始tensor保持一致
(2)返回的结果的存储区不同,repeat返回的tensor会重新拥有一个独立存储区,而expand返回的tensor则与原始tensor共享存储区

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值