问题描述
在处理数据的时候遇到一个for循环中生成多个【max_len*max_len】的二维矩阵,现需要将这些矩阵在第一维上进行堆叠,形成一个新的【batch * max_len * max_len】三维矩阵
实现过程
a = torch.ones(3, 3) # 假设生成的矩阵形状为3*3
c = [] # 定义一个空列表用于存储矩阵
for i in range(3):
a = a
c.append(a.unsqueeze(0))
# 使用cat方法可之间实现该操作
c = torch.cat(c, dim=0)
print(c.size())
输出c的形状:
torch.Size([3, 3, 3])