向 FEDformer 作者请教了 MOE（Mixture of Experts，混合专家）分解模块的原理。
MOE 模块对应论文整体框架图中用黄色标出的部分。
class moving_avg(nn.Module):
    """
    Moving average block to highlight the trend of a time series.

    The series is edge-padded (its first / last time step is repeated) before
    average pooling, so that with stride=1 the output has the same length as
    the input.

    Args:
        kernel_size: window length of the moving average.
        stride: stride of the pooling window.
    """

    def __init__(self, kernel_size, stride):
        super(moving_avg, self).__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        """
        Args:
            x: 3-D tensor (batch, seq_len, channels); time is dim 1.

        Returns:
            Moving average of shape (batch, out_len, channels); out_len equals
            seq_len when stride == 1 (padded length seq_len + kernel_size - 1,
            pooled with an unpadded window of kernel_size).
        """
        # Total padding is kernel_size - 1. When that is odd (even kernel_size)
        # the extra step goes to the FRONT. The original notes flagged this
        # padding as "different" -- presumably from a symmetric
        # (kernel_size - 1) // 2 split on both ends (confirm against the
        # reference implementation). `//` already floors, so no math.floor.
        end_pad = (self.kernel_size - 1) // 2
        front_pad = self.kernel_size - 1 - end_pad
        front = x[:, 0:1, :].repeat(1, front_pad, 1)
        end = x[:, -1:, :].repeat(1, end_pad, 1)
        x = torch.cat([front, x, end], dim=1)
        # AvgPool1d pools over the last dim, so move time there and back.
        x = self.avg(x.permute(0, 2, 1))
        return x.permute(0, 2, 1)
class series_decomp(nn.Module):
    """
    Mixture-of-experts (MOE) series decomposition block.

    One moving average is computed per kernel size. The averages are blended
    with softmax weights produced by a linear layer applied to each scalar
    input value. The blend is returned as the trend (``moving_mean``), and
    ``x - moving_mean`` as the residual.

    Args:
        kernel_size: list of moving-average window lengths, e.g. [24].
            A single int is also accepted and treated as [kernel_size].
    """

    def __init__(self, kernel_size):
        super(series_decomp, self).__init__()
        if isinstance(kernel_size, int):
            kernel_size = [kernel_size]
        # nn.ModuleList (not a plain list) so the averaging blocks are
        # registered submodules: they appear in repr()/modules() and follow
        # .to()/.train()/.eval() with the parent. AvgPool1d holds no
        # parameters or buffers, so the state_dict is unchanged.
        self.moving_avg = nn.ModuleList(
            [moving_avg(kernel, stride=1) for kernel in kernel_size]
        )
        # Maps each scalar input value to one logit per kernel size.
        self.layer = torch.nn.Linear(1, len(kernel_size))

    def forward(self, x):
        """
        Args:
            x: 3-D tensor as expected by moving_avg, i.e.
               (batch, seq_len, channels).

        Returns:
            (res, moving_mean): the residual ``x - moving_mean`` and the
            softmax-weighted blend of the moving averages, both shaped like x.
        """
        # (batch, seq_len, channels, K): one moving average per kernel size.
        trends = torch.stack([avg(x) for avg in self.moving_avg], dim=-1)
        # (batch, seq_len, channels, K): per-element weights over the K kernels.
        weights = torch.softmax(self.layer(x.unsqueeze(-1)), dim=-1)
        moving_mean = torch.sum(trends * weights, dim=-1)
        res = x - moving_mean
        return res, moving_mean
理解这段代码时参考的资料：
Python中将函数作为列表元素_思考实践的博客-CSDN博客
参考资料