python基本语法——数据拼接操作(numpy & torch)
一、数组拼接
Torch
1. torch.cat
import torch
a = torch.ones(10)
b = torch.zeros(10)
a,b
(tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]),
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]))
c = torch.cat((a,b),0)
d = torch.cat((a,b),-1)
c,d
(tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0.]),
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0.]))
m = torch.rand(2,3)
n = torch.rand(2,3)
m,n
(tensor([[0.4305, 0.5858, 0.6041],
[0.3689, 0.7339, 0.3004]]),
tensor([[0.3952, 0.5579, 0.6647],
[0.3820, 0.1805, 0.4863]]))
s1 = torch.cat((m,n),0) #dim = 0 按列拼接
s1
tensor([[0.4305, 0.5858, 0.6041],
[0.3689, 0.7339, 0.3004],
[0.3952, 0.5579, 0.6647],
[0.3820, 0.1805, 0.4863]])
s2 = torch.cat((m,n),dim=1) # dim=1, 按行拼接
s2
tensor([[0.4305, 0.5858, 0.6041, 0.3952, 0.5579, 0.6647],
[0.3689, 0.7339, 0.3004, 0.3820, 0.1805, 0.4863]])
s3 = torch.cat((s2,m),dim=1)
s3
tensor([[0.4305, 0.5858, 0.6041, 0.3952, 0.5579, 0.6647, 0.4305, 0.5858, 0.6041],
[0.3689, 0.7339, 0.3004, 0.3820, 0.1805, 0.4863, 0.3689, 0.7339, 0.3004]])
示例分析:
a是[10]的列表,a一直在变化,想把这些变化都拼接到一个数组中
a_all = []
for i in range(5):
a = torch.rand(3)
if i==0: a_all = a
else: a_all = torch.cat((a_all,a),dim=0)
a_all
tensor([0.2426, 0.6465, 0.6805, 0.3972, 0.0939, 0.9206, 0.3735, 0.6619, 0.4411,
0.6900, 0.4391, 0.4374, 0.4946, 0.9330, 0.0059])
a.shape
torch.Size([3])
a.reshape(-1