题目链接:
题意:给你一个数列a1,a2,....an,让你求出下面这个结果在模998244353情况下的答案。
数据范围
1<=n,ai<=100000
Input
2
1 2
Output
26
题解:首先我们考虑暴力的做法,算算时间复杂度肯定过不去,然后我们考虑怎么优化,一个很明显的做法就是,我们可以单独算出i=j时的答案,即,这个可以在O(n)的时间复杂度范围内直接求出来,然后我们很容易发现我们只需要算出这个值,最后我们乘于2,在减去i=j时算出的答案就是结果。
那么问题来了,我们怎么去快速求上面这个东西(比赛时就挂在了这个关键步骤上)
这是就需要一个很牛逼的东西去处理了——NTT(快速数论变换)
前言:如果不知道NTT或者FFT的话,可以看这两篇文章了解一下:FFT 和 NTT
然后我们可以这样考虑令f(i)表示,即表示数组的值等于i的数的个数
下一步我们进行化简操作,以便进行卷积操作。
化简操作我就不写了,太麻烦了,盗用了一下ICPC南京网络赛C题的标程
我们只需要把最后式子里面的i和j删掉就行了,因为这里没有乘于值(注意:我们枚举的不再是数组的每一个元素,而是最大范围的所有值,就是一个枚举策略的变换)
最后答案是这样的
注意没有乘于i和j。
对于绿色部分的我们可以直接O(n)的范围内求出,红色范围的我们可以利用NTT卷积求出
下一步我们就可以进行愉快的卷积了。
我们很容易发现:有一个多项式Ai,它的系数为
另一个多项式Bi,它的系数为
然后这两个多项式相乘的系数为 (注意j=0时没有什么影响)
因此我们进行直接套用NTT计算,算出答案后,外面的维护一下答案就行了。
代码如下(NTT模板参考自csl的板板):
#pragma GCC optimize(2) #include<iostream> #include<algorithm> #include<cmath> #include<cstring> #include<cstdio> #include<cstdlib> #include<vector> #include<map> #include<set> #include<stack> #include<queue> #define PI atan(1.0)*4 #define E 2.718281828 #define rp(i,s,t) for (i = (s); i <= (t); i++) #define RP(i,s,t) for (i = (t); i >= (s); i--) #define ll long long #define ull unsigned long long #define mst(a,b) memset(a,b,sizeof(a)) #define lson l,m,rt<<1 #define rson m+1,r,rt<<1|1 using namespace std; inline int read() { int a=0,b=1; char c=getchar(); while(c<'0'||c>'9') { if(c=='-') b=-1; c=getchar(); } while(c>='0'&&c<='9') { a=(a<<3)+(a<<1)+c-'0'; c=getchar(); } return a*b; } const ll mod = 998244353; ll quick_pow(ll a, ll b, int mod) { ll res = 1; while(b){ if(b&1) res=(res*a)%mod; b>>=1; a=(a*a)%mod; } return res%mod; } ll inv(ll a, int mod) { return quick_pow(a, mod - 2, mod); } ll sqr2 = 116195171; const int N = 1 << 17; ll wn[N << 2], rev[N << 2]; int NTT_init(int n_) { int step = 0; int n = 1; for (; n < n_; n <<= 1) ++step; for (int i = 1; i < n; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (step - 1)); int g = quick_pow(3, (mod - 1) / n, mod); wn[0] = 1; for (int i = 1; i <= n; ++i) wn[i] = wn[i - 1] * g % mod; return n; } void NTT(ll a[], int n, int f) { for (int i = 0; i < n; i++) if (i < rev[i]) std::swap(a[i], a[rev[i]]); for (int k = 1; k < n; k <<= 1){ for (int i = 0; i < n; i += (k << 1)){ int t = n / (k << 1); for (int j = 0; j < k; j++){ ll w = f == 1 ? wn[t * j] : wn[n - t * j]; ll x = a[i + j]; ll y = a[i + j + k] * w % mod; a[i + j] = (x + y) % mod; a[i + j + k] = (x - y + mod) % mod; } } } if (f == -1) { ll Inv = inv(n, mod); for (int i = 0; i < n; i++) a[i] = a[i] * Inv % mod; } } ll aa[N << 2], bb[N << 2], conv[N << 2]; const int MAX = 1 << 17; ll f[MAX];int a[MAX]; int main(){ int n=read(),i; rp(i,0,n-1) scanf("%d",&a[i]),f[a[i]]++; for(int i=0;i<100001;i++){ aa[i]=f[i]*quick_pow(sqr2,(ll)i*i,mod)%mod; bb[i]=quick_pow(sqr2,(-(ll)i * i) % (mod - 1) + mod - 1,mod); } int m=NTT_init(1<<18); NTT(aa,m,1);NTT(bb,m,1);//把a数组和b数组从系数表示法转换成点值表示法 for(int i=0;i<m;i++)//多项式相乘 conv[i]=1ll*aa[i]*bb[i]%mod; NTT(conv,m,-1);//再把多项式相乘之后的点值表示法转换成系数表示法,方便进行后面的计算 ll ans=0; for(int i=0;i<100001;i++) ans=(ans+1ll*f[i]*quick_pow(sqr2,1ll*i*i,mod)%mod*conv[i]%mod)%mod; ans=(ans*2)%mod; for(int i=0;i<100001;i++) ans=(ans-1ll*f[i]*f[i]%mod*quick_pow(2,1ll*i*i,mod)%mod+mod)%mod; cout<<(ans+mod)%mod<<endl; return 0; }