'''
autoformer的auto-correlation的输入shape除了batch,其他维度与输出shape应该一样
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
# Toy input: query and key are built from the same length-3 series, so corr
# below is the circular autocorrelation of [1, 2, 3].
input_x = [1,2,3]
query = input_x
key = input_x
input_x = torch.tensor(input_x)
query = torch.tensor(query)
key = torch.tensor(key)
# Circular cross-correlation computed with the FFT:
#   corr[tau] = sum_t query[(t + tau) % L] * key[t]
#             = irfft(rfft(query) * conj(rfft(key)))
seq_len = query.shape[-1]
q_fft = torch.fft.rfft(query,dim=-1)
k_fft = torch.fft.rfft(key,dim=-1)
print(q_fft)  # tensor([ 6.0000+0.0000j, -1.5000+0.8660j])
print(k_fft)  # tensor([ 6.0000+0.0000j, -1.5000+0.8660j])
res = q_fft * torch.conj(k_fft)
print(res)  # tensor([36.+0.0000e+00j, 3.+2.9802e-08j])
# irfft must be told the original length. Without n it returns
# 2 * (res.shape[-1] - 1) samples, i.e. 2 instead of 3 for this odd length,
# which printed the meaningless tensor([19.5000, 16.5000]).
corr = torch.fft.irfft(res, n=seq_len, dim=-1)
print(corr)  # tensor([14., 11., 11.]) up to float rounding: 1*1+2*2+3*3, 1*2+2*3+3*1, 1*3+2*1+3*2
# Output of print(temporal_value) below: tensor([2, 4, 6])
'''
这里假设 top-k 恰好取到了全部 Corr 结果，因此不再用 torch.topk 排序筛选。
find top k: top_k = int(factor * ln(length)); with the default factor=1 and length=96: ln(96) ≈ 4.56, int(4.56) = 4
top_k = int(self.factor * math.log(length)) # c·log(L), the same form as Informer's sampling factor (as stated by the Autoformer authors)
'''
# Sanity check: element-wise sum of query and key. test_query and test_key
# are plain aliases of the same tensors.
test_query, test_key = query, key
temporal_value = torch.add(query, key)
print(temporal_value)
'''
agg聚合操作
'''
# Delay aggregation: for every lag tau covered by corr, shift the series left
# by tau (torch.roll(x, -tau, -1)) and weight that copy by softmax(corr)[tau].
# corr[tau] is the score of lag tau, so weight index tau must pair with shift
# -tau (lag 0 = the unshifted series).
num_lags = corr.shape[-1]
rolled_values = [torch.roll(input_x, -lag, -1) for lag in range(num_lags)]
for rolled in rolled_values:
    print(rolled)
tmp_corr = torch.softmax(corr, dim=-1)
print(corr)
print(tmp_corr)
final_ac_output = sum(tmp_corr[lag] * rolled_values[lag] for lag in range(num_lags))
print(final_ac_output)
# Recorded output of the cell above:
# tensor([2, 3, 1])
# tensor([3, 1, 2])
# tensor([19.5000, 16.5000])
# tensor([0.9526, 0.0474])
# tensor([2.0474, 2.9051, 1.0474])
# 修正 (corrected version)
'''
autoformer的auto-correlation的输入shape与输出shape一样
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
# Toy input of length 4. Any length works: rfft/irfft accept arbitrary sizes,
# provided irfft is given n=seq_len below. An odd length only went wrong
# because irfft without n returns 2 * (rfft_len - 1) samples, which equals
# seq_len - 1 when seq_len is odd.
input_x = [1,2,3,4]
query = input_x
key = input_x
input_x = torch.tensor(input_x)
query = torch.tensor(query)
key = torch.tensor(key)
# Circular cross-correlation via the FFT:
#   corr[tau] = sum_t query[(t + tau) % L] * key[t]
seq_len = query.shape[-1]
q_fft = torch.fft.rfft(query,dim=-1)
k_fft = torch.fft.rfft(key,dim=-1)
print(q_fft)  # tensor([10.+0.j, -2.+2.j, -2.+0.j])
print(k_fft)  # tensor([10.+0.j, -2.+2.j, -2.+0.j])
res = q_fft * torch.conj(k_fft)
print(res)  # tensor([100.+0.j, 8.+0.j, 4.+0.j])
# Pass the original length explicitly so the result always has seq_len lags.
corr = torch.fft.irfft(res, n=seq_len, dim=-1)
print(corr)  # tensor([30., 24., 22., 24.])
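# roll: torch.roll(x, -k, -1) for k = 0..3 circularly shifts [1, 2, 3, 4] left by k;
# corr[k] is the dot product of the series with its k-shifted copy:
# 1234
# 2341
# 3412
# 4123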
#30 = 1*1+2*2+3*3+4*4
#24 = 1*2+2*3+3*4+4*1
#22 = 1*3+2*4+3*1+4*2
#24 = 1*4+2*1+3*2+4*3
'''
agg聚合操作
'''
# Delay aggregation over all four lags: shift the series left by each lag and
# weight every shifted copy by the softmax of that lag's correlation score.
roll = [torch.roll(input_x, -lag, -1) for lag in range(4)]
roll_test_1, roll_test_2, roll_test_3, roll_test_4 = roll
for shifted in roll:
    print(shifted)
tmp_corr = torch.softmax(corr, dim=-1)
print(corr)
print(tmp_corr)
final_ac_output = sum(tmp_corr[lag] * roll[lag] for lag in range(4))
print(final_ac_output)
# Recorded output of the cell above:
# tensor([1, 2, 3, 4])
# tensor([2, 3, 4, 1])
# tensor([3, 4, 1, 2])
# tensor([4, 1, 2, 3])
# tensor([30., 24., 22., 24.])
# tensor([9.9474e-01, 2.4657e-03, 3.3370e-04, 2.4657e-03])
# tensor([1.0105, 2.0007, 2.9993, 3.9895])