由于欧拉函数的值域范围比较小（$\varphi(i) \le n$），考虑直接枚举欧拉函数值：

$$\sum_{i=1}^{n}\sum_{j=1}^{n} p[i]\cdot p[j]\cdot i\cdot j\cdot 2^{ij}$$

其中 $p[i]$ 表示欧拉函数值 $i$ 在 $\varphi(1),\dots,\varphi(n)$ 中的出现次数。
对 $2^{ij}$ 做一个变换（利用 $(i+j)^2 - i^2 - j^2 = 2ij$）：

$$2^{ij} = 2^{\frac{(i+j)^2-i^2-j^2}{2}} = \sqrt{2}^{\,(i+j)^2-i^2-j^2}$$
式子转化为：

$$\sum_{i=1}^{n} p[i]\cdot i\cdot \sqrt{2}^{\,-i^2}\sum_{j=1}^{n} p[j]\cdot j\cdot \sqrt{2}^{\,-j^2}\cdot \sqrt{2}^{\,(i+j)^2}$$
令 $g(i) = p[i]\cdot i\cdot \sqrt{2}^{\,-i^2}$，则式子变为：

$$\sum_{i=1}^{n} g(i)\sum_{j=1}^{n} g(j)\cdot \sqrt{2}^{\,(i+j)^2}$$
枚举 $i + j$，令 $k = i + j$，定义

$$f(k) = \sqrt{2}^{\,k^2}\sum_{i=1}^{k} g(i)\cdot g(k-i)$$

（约定 $g(0) = 0$，且 $i > n$ 时 $g(i) = 0$），则答案为 $\sum_{k=2}^{2n} f(k)$。
右边的求和显然是 $g$ 与自身的卷积。套用 NTT 求出卷积后，逐项乘上 $\sqrt{2}^{\,k^2}$ 并求和，得到的即是答案。
$\sqrt{2}$ 可以提前求出模 $998244353$ 意义下的二次剩余直接使用：$x = 116195171$ 满足 $x^2 \equiv 2 \pmod{998244353}$。
代码:
//116195171 根号2的二次剩余
#include<bits/stdc++.h>
using namespace std;
using ll = long long;

// A modular square root of 2: 116195171^2 ≡ 2 (mod 998244353).
const int sqr2 = 116195171;
// NTT-friendly prime modulus: 998244353 = 119 * 2^23 + 1.
const int mod = 998244353;
// Bound for the global arrays below.
const int maxn = 2e6 + 10;

int t, n;                   // number of test cases, current n
ll p[maxn], g[maxn];        // per-value counts (p) and a convolution buffer (g)
bool ispri[maxn];           // NOTE: true means "NOT prime" (filled by sieve)
int pri[maxn], phi[maxn];   // pri[0] = prime count, pri[1..] = primes; Euler phi
// Linear (Euler) sieve over [0, n].
//  - pri[1..pri[0]] receives the primes <= n; pri[0] holds their count.
//  - phi[x] receives Euler's totient for 1 <= x <= n.
//  - ispri[x] is set to true for every x <= n that is NOT prime
//    (0, 1 and all composites) -- the name's meaning is inverted.
void sieve(int n) {
    ispri[0] = true;
    ispri[1] = true;
    pri[0] = 0;
    phi[1] = 1;
    for (int x = 2; x <= n; ++x) {
        if (!ispri[x]) {
            pri[++pri[0]] = x;
            phi[x] = x - 1;
        }
        for (int k = 1; k <= pri[0]; ++k) {
            const int q = pri[k];
            if (x * q > n) break;
            ispri[x * q] = true;
            if (x % q == 0) {
                // q is the smallest prime factor of x: phi gains a plain factor q.
                phi[x * q] = phi[x] * q;
                break;
            }
            // q does not divide x: phi is multiplicative here.
            phi[x * q] = phi[x] * (q - 1);
        }
    }
}
// Binary modular exponentiation: returns a^b reduced modulo `mod`.
// Assumes 0 <= a < mod so that base * base fits in a long long.
ll fpow(ll a,ll b) {
    ll result = 1;
    for (ll base = a; b != 0; b >>= 1) {
        if (b & 1) result = result * base % mod;
        base = base * base % mod;
    }
    return result;
}
// Permutes t[0..len) into bit-reversed index order, the input order the
// iterative NTT expects. Assumes len is a power of two.
void change(ll t[],int len) {
    int rev = len >> 1;  // bit-reverse of index 1
    for (int idx = 1; idx + 1 < len; ++idx) {
        if (idx < rev) swap(t[idx], t[rev]);
        // Advance rev as a "reversed" binary counter: clear the leading run
        // of set bits (from the top), then set the next bit.
        int bit = len >> 1;
        for (; rev >= bit; bit >>= 1) rev -= bit;
        rev += bit;
    }
}
// In-place iterative number-theoretic transform modulo 998244353, using 3
// (a primitive root of that prime) as the generator.
//  type == -1 : inverse transform, including the final 1/len scaling.
//  otherwise  : forward transform.
// Assumes len is a power of two not exceeding 2^23.
void NTT(ll t[],int len,int type) {
    change(t, len);
    const bool inverse = (type == -1);
    for (int step = 2; step <= len; step <<= 1) {
        const int half = step >> 1;
        ll root = fpow(3, (mod - 1) / step);
        if (inverse) root = fpow(root, mod - 2);
        for (int base = 0; base < len; base += step) {
            ll w = 1;
            for (int off = 0; off < half; ++off, w = w * root % mod) {
                ll& lo = t[base + off];
                ll& hi = t[base + off + half];
                const ll x = lo;
                const ll y = hi * w % mod;
                lo = (x + y) % mod;
                hi = (x - y + mod) % mod;
            }
        }
    }
    if (!inverse) return;
    const ll scale = fpow(len, mod - 2);
    for (int idx = 0; idx < len; ++idx) t[idx] = t[idx] * scale % mod;
}
int main() {
sieve(maxn - 10);
scanf("%d",&t);
while(t--) {
memset(p,0,sizeof p);
memset(g,0,sizeof g);
scanf("%d",&n);
for(int i = 0; i <= n; i++)
p[phi[i]]++;
for(int i = 0; i <= n; i++) {
ll pi = mod - 1;
ll t = fpow(sqr2,1ll * i * i);
t = fpow(t,mod - 2);
g[i] = 1ll * p[i] * i % mod * t % mod;
}
int len = 1;
while(len <= 2 * n) len <<= 1;
NTT(g,len,1);
for(int i = 0; i <= len; i++)
g[i] = g[i] * g[i] % mod;
NTT(g,len,-1);
for(int i = 0; i <= 2 * n; i++)
g[i] = g[i] * fpow(sqr2,1ll * i * i % (mod - 1)) % mod;
ll sum = 0;
for(int i = 0; i <= 2 * n; i++)
sum = (sum + g[i]) % mod;
printf("%lld\n",sum);
}
return 0;
}