repeat
pytorch的张量与numpy的数组有很多相似之处,但repeat有明显差异。
1、方向
repeat(repeats,axis)的两个参数,分别表示重复的次数和方向。
对于二维数组,axis = 1表示垂直方向,axis = 2表示水平方向,
与torch.sum和np.repeat都不一样。
2、重复的方式
np的重复方式是各向量分别重复,如012变为 001122
torch的重复方式为整块复制,如 012变为012012
import numpy as np
a = np.arange(6).reshape(2,3)
b = a.repeat(2,0)
print(b)
b = a.repeat(2,1)
print(b)
运行结果
[[0 1 2]
[0 1 2]
[3 4 5]
[3 4 5]]
[[0 0 1 1 2 2]
[3 3 4 4 5 5]]
类似的代码,把np换成torch,方向也变成1和2
import torch
a = torch.arange(6).reshape(2,3)
b = a.repeat(2,1)
print(b)
b = a.repeat(2,2)
print(b)
运行结果
tensor([[0, 1, 2],
[3, 4, 5],
[0, 1, 2],
[3, 4, 5]])
tensor([[0, 1, 2, 0, 1, 2],
[3, 4, 5, 3, 4, 5],
[0, 1, 2, 0, 1, 2],
[3, 4, 5, 3, 4, 5]])