Autoformer:自相关机制简单版代码理解

本文介绍了Autoformer中auto-correlation模块的关键步骤,包括输入与输出形状的要求,使用FFT进行频域计算,以及softmax后的聚合操作。通过实例展示了如何处理不同长度输入并利用rolling和softmax进行注意力聚合。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

'''
autoformer的auto-correlation的输入shape除了batch,其他维度与输出shape应该一样
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np

input_x = [1,2,3]
query = input_x
key = input_x
input_x = torch.tensor(input_x)
query = torch.tensor(query)
key = torch.tensor(key)

tensor([ 6.0000+0.0000j, -1.5000+0.8660j])
tensor([ 6.0000+0.0000j, -1.5000+0.8660j])
tensor([36.+0.0000e+00j,  3.+2.9802e-08j])
tensor([19.5000, 16.5000])

q_fft = torch.fft.rfft(query,dim=-1)
k_fft = torch.fft.rfft(key,dim=-1)
print(q_fft)
print(k_fft)
res = q_fft * torch.conj(k_fft)
print(res)
corr = torch.fft.irfft(res,dim=-1)
print(corr)

tensor([2, 4, 6])

'''
我这里就假设咱们的topk正好全部就是咱们Corr结果就不去toch.topk排序取出来了
find top k top_k:default 1*log(length)=log96 =4.56 int(4.56)=4
top_k = int(self.factor * math.log(length)) #like informer param=cxlogL said by author Autoformer
'''
test_query = query
test_key = key

temporal_value = query+key
print(temporal_value)

'''
agg聚合操作
'''
roll_test_1 = torch.roll(input_x,-1, -1)
print(roll_test_1)
roll_test_2 = torch.roll(input_x,-2, -1)
print(roll_test_2)

tmp_corr = torch.softmax(corr, dim=-1)
print(corr)
print(tmp_corr)

final_ac_output = tmp_corr[0]*roll_test_1+tmp_corr[1]*roll_test_2
print(final_ac_output)

tensor([2, 3, 1])
tensor([3, 1, 2])
tensor([19.5000, 16.5000])
tensor([0.9526, 0.0474])
tensor([2.0474, 2.9051, 1.0474])

修正

'''
autoformer的auto-correlation的输入shape与输出shape一样
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np

input_x = [1,2,3,4]#输入长度必须是偶数,不然在x喂给fft计算的时候会出现偏差,这与fft操作的序列长度一般是2^n,具体自行查找资料
query = input_x
key = input_x
input_x = torch.tensor(input_x)
query = torch.tensor(query)
key = torch.tensor(key)

q_fft = torch.fft.rfft(query,dim=-1)
k_fft = torch.fft.rfft(key,dim=-1)
print(q_fft)
print(k_fft)
res = q_fft * torch.conj(k_fft)
print(res)
corr = torch.fft.irfft(res,dim=-1)
print(corr)

tensor([10.+0.j, -2.+2.j, -2.+0.j])
tensor([10.+0.j, -2.+2.j, -2.+0.j])
tensor([100.+0.j,   8.+0.j,   4.+0.j])
tensor([30., 24., 22., 24.])

roll
1234
2341
3412
4123

#30 = 1*1+2*2+3*3+4*4
#24 = 1*2+2*3+3*4+4*1
#22 = 1*3+2*4+3*1+4*2
#24 = 1*4+2*1+3*2+4*3

'''
agg聚合操作
'''
roll_test_1 = torch.roll(input_x,0, -1)
print(roll_test_1)
roll_test_2 = torch.roll(input_x,-1, -1)
print(roll_test_2)
roll_test_3 = torch.roll(input_x,-2, -1)
print(roll_test_3)
roll_test_4 = torch.roll(input_x,-3, -1)
print(roll_test_4)
roll = [roll_test_1,roll_test_2,roll_test_3,roll_test_4]

tmp_corr = torch.softmax(corr, dim=-1)
print(corr)
print(tmp_corr)

final_ac_output = tmp_corr[0]*roll[0]+tmp_corr[1]*roll[1]+tmp_corr[2]*roll[2]+tmp_corr[3]*roll[3]
print(final_ac_output)

tensor([1, 2, 3, 4])
tensor([2, 3, 4, 1])
tensor([3, 4, 1, 2])
tensor([4, 1, 2, 3])
tensor([30., 24., 22., 24.])
tensor([9.9474e-01, 2.4657e-03, 3.3370e-04, 2.4657e-03])
tensor([1.0105, 2.0007, 2.9993, 3.9895])

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值