题目
https://ac.nowcoder.com/acm/contest/11260/C
分析
(截图题解的,在推了……)
答案就是这三部分乘起来,前两部分好说,最后那部分,注意到值域比较小,于是统计一下在数组 a[]
中每种数的个数,然后正着一个生成函数,倒着一个生成函数,相乘,得到的系数就是差是某个数对应的 a[i]-a[j]
数量,然后乘出来就行。
时间复杂度是 O(n lgn)
的。
代码
#include <bits/stdc++.h>
using namespace std;
#define Ha 998244353
#define MAXN 500005
#define MAXVAL 1000005
int n,t;
int a[MAXN];
long long ans,ans1,ans2,ans3;
long long ksm(long long x, long long k)
{
long long ret=1;
for (; k; x=x*x%Ha,k>>=1)
if (k&1) ret=ret*x%Ha;
return ret;
}
long long inv3;
long long pow3[MAXVAL*6];
long long powinv3[MAXVAL*6];
void pre()
{
inv3=ksm(3,Ha-2);
for (int i=1; i<=t; i<<=1)
pow3[i]=ksm(3,(Ha-1)/i);
for (int i=1; i<=t; i<<=1)
powinv3[i]=ksm(inv3,(Ha-1)/i);
}
void bit_reverse(int n, vector<long long> &r)
{
for (int i=0,j=0; i<n; i++) {
if (i>j) swap(r[i],r[j]);
for (int l=n>>1; (j^=l)<l; l>>=1);
}
}
void NTT(int n, vector<long long> &r, long long f)
{
bit_reverse(n, r);
for (int i=2; i<=n; i<<=1) {
int m=i>>1;
for (int j=0; j<n; j+=i) {
long long w=1,wn=pow3[i];
if (f==-1) wn=powinv3[i];
for (int k=0; k<m; k++) {
long long z=r[j+m+k]*w%Ha;
r[j+m+k]=(r[j+k]-z+Ha)%Ha;
r[j+k]=(r[j+k]+z)%Ha;
w=w*wn%Ha;
}
}
}
if (f==-1) {
long long inv=ksm(n, Ha-2);
for (int i=0; i<n; i++)
r[i]=r[i]*inv%Ha;
}
}
void solve()
{
int m=MAXVAL*2;
t=1;
while (t<=m) t<<=1;
long long tmp=1;
vector<long long> A(MAXVAL*6),B(MAXVAL*6);
ans1=ans2=ans3=1;
scanf("%d",&n);
for (int i=1; i<=n; i++) {
scanf("%d",&a[i]);
ans2=(ans2*(a[i]+1))%Ha;
tmp=tmp*i%Ha;
ans1=ans1*tmp%Ha;
A[a[i]]++;
B[MAXVAL-a[i]]++;
}
ans1=ksm(ans1,Ha-2);
pre();
NTT(t, A, 1);
NTT(t, B, 1);
for (int i=0; i<t; i++)
A[i]=A[i]*B[i]%Ha;
NTT(t, A, -1);
for (int i=2; i<MAXVAL; i++)
ans3=ans3*ksm(i,A[MAXVAL+i])%Ha;
ans=ans1*ans2%Ha*ans3%Ha;
printf("%lld\n",ans);
}
int main()
{
solve();
return 0;
}