首先我们需要想出n^2的算法,
我们考虑这样一组样例
5
1 2 3 4 5
计算期望的时候我们需要这样考虑
当计算到a1,a5的时候他的概率是这样的,(2!*3!)/ 5!
因此我们可以看出这个概率计算仅仅与区间长度有关系。
我们枚举出所有要计算的区间
格式: 【 区间长度:区间 】
5 : (1,5)
4 :(1,4)(2,5)
3:(1,3)(2,4)(3,5)
2:(1,2)(2,3)(3,4)(4,5)
因此我们将区间倒置
1 2 3 4 5
5 4 3 2 1
卷积之后答案
1 2 3 4 5 6 7 8 9 10
这样长度为5的区间答案正好累计在了ans[1]上
长度为4的区间答案正好累计在了ans[2]上
……
由于用到了取模,因此用NTT进行一下卷积即可
最后只取从1-(n-2)的答案即可。
code:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int G=3;
const int NUM=20;
const int P = 998244353;
const int MOD = 998244353;
int wn[20];
int mul(int x,int y,int P) {
return (ll)x*y%P;
}
int PowMod(int a,int b) {
int res=1;
a%=P;
while(b) {
if(b&1)res=mul(res,a,P);
a=mul(a,a,P);
b>>=1;
}
return res;
}
void GetWn() {
for(int i=0; i<NUM; i++) {
int t=1<<i;
wn[i]=PowMod(G,(P-1)/t);
}
}
void Change(int a[],int len) {
int i,j,k;
for(i=1,j=len/2; i<len-1; i++) {
if(i<j)swap(a[i],a[j]);
k=len/2;
while(j>=k) {
j-=k;
k/=2;
}
if(j<k)j+=k;
}
}
void NTT(int a[],int len,int on) {
Change(a,len);
int id=0;
for(int h=2; h<=len; h<<=1) {
id++;
for(int j=0; j<len; j+=h) {
int w=1;
for(int k=j; k<j+h/2; k++) {
int u=a[k]%P;
int t=mul(a[k+h/2],w,P);
a[k]=(u+t)%P;
a[k+h/2]=((u-t)%P+P)%P;
w=mul(w,wn[id],P);
}
}
}
if(on==-1) {
for(int i=1; i<len/2; i++)swap(a[i],a[len-i]);
int inv=PowMod(len,P-2);
for(int i=0; i<len; i++)a[i]=mul(a[i],inv,P);
}
}
const int N = 100005;
int n, len, w[N<<2], v[N<<2];
int F[N], Finv[N], inv[N];
void init() {
inv[1] = 1;
for (int i = 2; i < N; i++) {
inv[i] = (P-P/i) * 1ll * inv[P%i] % MOD;
}
F[0] = Finv[0] = 1;
for (int i = 1; i < N; i++) {
F[i] = F[i-1] * 1ll * i % MOD;
Finv[i] = Finv[i-1] * 1ll * inv[i] % MOD;
}
}
int main() {
GetWn();
init();
scanf("%d", &n);
for (int i = 0; i < n; i++) scanf("%d", &w[i]);
for (int i = 0; i < n; i++) v[i] = w[n-1-i];
len = 1;
while (len < (n<<1)) len <<= 1;
for (int i = n; i < len; i++) w[i] = v[i] = 0;
NTT(w, len, 1);
NTT(v, len, 1);
for (int i = 0; i < len; i++) w[i] = mul(w[i], v[i], P);
NTT(w, len, -1);
ll ans = 0;
for (int i = 1; i <= n-2; i++) {
ans = (ans + 1ll * w[n-2-i] * F[i] % MOD * Finv[i+2] % MOD) % MOD;
}
printf("%lld\n", 2*ans%MOD);
}