expand 函数
expand(*sizes) -> Tensor
*sizes(torch.Size or int) - the desired expanded size
Returns a new view of the self tensor with singleton dimensions expanded to a larger size.
expand
用于扩展tensor数据。但有以下注意点:
- 该函数
不复制数据
- 扩展时只在能度数是1的维度上扩展
- 生成的对象与原对象共享内存
import torch
a = torch.tensor([1,2])
b = a.expand(2,-1) # -1 代表此维度不变
print('b : ', b)
print('\nafter modifying b')
b[0][0]=10
print('a : ', a)
print('b : ', b)
b : tensor([[1, 2],
[1, 2]])
after modifying b
a : tensor([10, 2])
b : tensor([[10, 2],
[10, 2]])
c = a.expand(2,4) # tesor 'a'最后一维维度是2,所以扩展时出错
RuntimeError: The expanded size of the tensor (4) must match the existing size (2) at non-singleton dimension 1. Target sizes: [2, 4]. Tensor sizes: [2]
repeat 函数
repeat(*sizes) -> Tensor
*size(torch.Size or int) - The number of times to repeat this tensor along each dimension.
Repeats this tensor along the specified dimensions.
返回tensor在某个维度上扩展后的张量.注意:
- 此函数会生成新的数据变量,和原tensor
不共享内存
d = a.repeat(2,2)
print('d: ', d)
d[0][0] = 10
print('\nafter modifying d ')
print('a: ', a)
print('d: ', d)
a.repeat(2,4) # 此参数 expand() 函数不通过
d: tensor([[1, 2, 1, 2],
[1, 2, 1, 2]])
after modifying d
a: tensor([1, 2])
d: tensor([[10, 2, 1, 2],
[ 1, 2, 1, 2]])
tensor([[1, 2, 1, 2, 1, 2, 1, 2],
[1, 2, 1, 2, 1, 2, 1, 2]])