本次的代码笔记来自Time-Series-Library
之前对于处理多个block的方法可能是在res list中每个结果单独乘系数,如:
import torch
#结果1
a = torch.arange(6).reshape(2,3)
#结果2
b = torch.arange(6,12).reshape(2,3)
#结果列表
res = []
res.append(a)
res.append(b)
#权重
weight = torch.tensor([0.2,0.8])
res_final = [res[i]*weight[i] for i in range(len(weight))]
final = sum(res_final) # sum()可直接对嵌套列表求和
print("row_way_result:",final)
print("row_way_result_shape:",final.shape)
row_way_result: tensor([[4.8000, 5.8000, 6.8000],
[7.8000, 8.8000, 9.8000]])
row_way_result_shape: torch.Size([2, 3])
现在提供另外一种思路,即先用torch.stack()将这些结果拼接起来,然后再将权重升维成对应的矩阵,与结果相乘,即:
import torch
#结果1
a = torch.arange(6).reshape(2,3)
#结果2
b = torch.arange(6,12).reshape(2,3)
#结果列表
res = []
res.append(a)
res.append(b)
# stack叠加,变成2*3*2
c = torch.stack(res,dim=-1)
#生成权重矩阵
weight = torch.tensor([0.2,0.8]).unsqueeze(0).unsqueeze(0).repeat(2,3,1)
final = torch.sum(c*weight,dim=-1)
print("column_way_result:",final)
print("column_way_result_shape:",final.shape)
column_way_result: tensor([[4.8000, 5.8000, 6.8000],
[7.8000, 8.8000, 9.8000]])
column_way_result_shape: torch.Size([2, 3])
关于这里为什么不用torch.cat(),而用torch.stack()拼接的原因:
因为这里是需要在dim=0
这个维度对所有结果进行相加,如结果是有两个2*3
的矩阵,我们想通过两个矩阵对应的权重进行加和,形成一个2*3
的矩阵,那么需要用到torch.sum()
这个函数进加和,torch.sum()
加和需要指定维度,如果用了torch.cat()
,那么会变成一个2*6
的矩阵,torch.sum()
不好发挥作用