推公式 + NTT
题意:
给 个正整数 ,求
解法:
首先考虑暴力的做法 ,显然超时,那我们给他化简一下
这个式子就是一个 的套路,一般这种套路题做法如下:
- (第一步)定义一个函数 ,这个函数的意思就是在 里面,有多少个数字等于 ,并且将该函数代入原式得到 (这里的 表示 的最小值和最大值),证明的话可以考虑反向枚举第一个式子的 的值作为第二个式子的 的值。
- (第二步)开始化简式子,考虑卷积形式 ,显然需要凑出一个减法,于是我们用平方差来表示 的次幂得到
- (假的第三步)展开我们化简得到的式子得到 ,但是,显然第二个求和符号后面的式子很想卷积,但是仔细对比发现了问题,第二个求和符号的上标不是 ,所以它并不是一个卷积,于是我们考虑能不能将它变成 (所以这是假的第三点)
- (第二步的后续)显然是可以的,我们知道 上述式子实际上是一个关于主对角线对称的正方形,我们可以使用一个带主对角线的三角形所有数的和 减去 主对角线的和。 所以我们继续第二步化简式子,所以我们最后得到的式子是
- (真的第三步)展开我们化简得到的式子得到 ,现在第二个求和符号后面的式子就是真的卷积了。
- (第四步)我们定义两个函数 是 的卷积,代入上面的式子可以得到
- 由于 ,所以我们可以使用 在 的时间求出卷积 ,并在 的时间内遍历剩下的式子得到最后答案,总复杂度
NTT代码:
#include <bits/stdc++.h>
#define sc scanf
#define pr printf
#define ll long long
using namespace std;
const int MAXN = 4 * 1000040;//4 * max
const ll mod = 998244353, G = 3, Gi = 332748118;//G是原根,Gi是inv(G)
int n, m;
int max_mi[2];//最高次幂
int r[MAXN];
ll a[MAXN], b[MAXN];//系数
ll res[MAXN];//ntt结果
int limit;//单个多项式的最高次幂
inline ll power(ll a, ll b)
{
ll res = 1;
while (b)
{
if (b & 1)
res = res * a % mod;
a = a * a % mod;
b >>= 1;
}
return res;
}
inline void NTT(ll* A, int lens, int op)
{
for (int i = 0; i < lens; i++)
if (i < r[i])
swap(A[i], A[r[i]]);
for (int mid = 1; mid < lens; mid <<= 1)
{
ll Wn = power(op == 1 ? G : Gi, (mod - 1) / (mid << 1));
for (int j = 0; j < lens; j += (mid << 1))
{
ll w = 1;
for (int k = 0; k < mid; k++, w = (w * Wn) % mod)
{
int x = A[j + k], y = w * A[j + k + mid] % mod;
A[j + k] = (x + y) % mod;
A[j + k + mid] = (x - y + mod) % mod;
}
}
}
if (op == -1)
{
ll inv = power(lens, mod - 2);
for (int i = 0; i < lens; i++)
A[i] = (A[i] * inv) % mod;
}
}
int run(ll numa[], ll numb[], int maxa, int maxb, ll sol[])
{
int len1 = maxa + 1;//最高项+1
int len2 = maxb + 1;//最高项+1
int lens = 1;//len是最高的项的二倍
int L = 0;
while (lens < 2 * len1 || lens < 2 * len2)
{
lens <<= 1;
L++;
}
//构造系数表达式
for (int i = 0; i < lens; i++)
r[i] = (r[i >> 1] >> 1) | ((i & 1) << (L - 1));
//系数表达式 转 点值表达式
NTT(numa, lens, 1);
NTT(numb, lens, 1);
//点值表达式相乘
for (int i = 0; i < lens; i++)
sol[i] = (numa[i] * numb[i]) % mod;
//点值表达式 转 系数表达式
NTT(sol, lens, -1);
return maxa + maxb;
}
ll aa[MAXN], num[MAXN];
ll tempa[MAXN];
int main()
{
ll inv2 = 116195171;
assert(inv2 * inv2 % mod == 2);
sc("%d", &n);
m = n;
for (int i = 1; i <= n; i++)
{
sc("%lld", &aa[i]);
num[aa[i]]++;
max_mi[0] = max(1LL * max_mi[0], aa[i]);
max_mi[1] = max(1LL * max_mi[1], aa[i]);
}
for (int i = 0; i <= 100000; i++)
{
a[i] = num[i] * power(inv2, 1LL * i * i) % mod;
tempa[i] = a[i];
b[i] = power(inv2, (-1LL * i * i) % (mod - 1) + (mod - 1));
}
limit = max(max_mi[0], max_mi[1]);
run(a, b, max_mi[0], max_mi[1], res);
ll ans = 0;
for (int i = 0; i <= 100000; i++)
ans = (ans + tempa[i] * res[i]) % mod;
ans = 2 * ans % mod;
for (int i = 0; i <= 100000; i++)
ans = (ans - tempa[i] * tempa[i] % mod + mod) % mod;
pr("%lld\n", ans);
}