torch.tensor.expand
先看招
import torch
x = torch.tensor([[1], [2], [3]])
print(x.size())
print(x.expand(3, 4))
print(x.expand(-1, 4)) # -1 means not changing the size of that dimension,所以原来是3,现在仍然是3,故和上述等价
说白了,就是复制!!!怎么复制呢?原来是[3,1],现在要变成[3,4],所以是对原tensor中第二个维度里面的数进行复制!!
要求:
被扩张的那个维度必须只有一个数!!也就是说size必须是1!!,所以原tensor必须size是[3,1],不可以是[3,2],否则报错。即:
tensor with singleton dimensions expanded to a larger size.
torch.tensor.repeat
同样都是复制,这个比上面这个好用。
上面这个功能可以如下实现:
x = torch.tensor([[1], [2], [3]])
print(x.size())
print(x.repeat(1, 4))#用法不一样的地方,不变的地方用1表示,而不是-1.
而且其不需要扩张的维度严格要求为1,例如可以是[3,2],例如:
x = torch.tensor([[1,1], [2,1], [3,0]])
print(x.size())
print(x.repeat(1, 4))
这才是真正的复制啊。