一、张量合并
1、cat:合并
bg:有一个学校每个班级有32名学生,每个班的学生都有8门课程
[3,32,8]表示1到3班里同学8门课的成绩情况
[5,32,8]表示4到8班里同学8门课的成绩情况
那么如果我们想把这8个班的同学成绩情况统计为一张表怎么办呢? 这里我们就用到了cat()
a = torch.rand(3, 32, 8)
#b:5个班里32名同学8们课的成绩
b = torch.rand(5, 32, 8)
#res:总共8个班里32名同学8们课的成绩
#参数表示将a, b在0维度上合并
res = torch.cat([a, b], dim=0)
print(res.shape)
2、stack:新增维度
如果我们有班级1的成绩表[32,8] 和班级2的成绩表[32,8] 我们又该如何合并呢?根据上面的理解,显然是应该再建一个维度来表示班级信息的。这时就用到了stack()
a = torch.rand(32, 8)
b = torch.rand(32, 8)
#新增的维度维度0表示 班级1的成绩,维度2表示班级2的成绩
res = torch.stack([a, b], dim=0)
print(res.shape)
二、张量切分
1、split(len, dim= )
按照每段长度为len切分,最后一段可以不满足len
a = torch.rand(5, 32, 8)
aa, b