题目链接
题目解法
如果要表示每一维的状态的话,显然装不下
考虑用简洁的方式表示状态,且可转移
有一个很巧妙的状态表示,首先考虑逆转操作,把计数器归零变为变成目标数字
令
d
p
s
dp_s
dps 表示
n
n
n 个计数器中当前数字离目标数字的最大值
考虑
a
r
<
s
≤
a
r
+
1
a_r<s\le a_{r+1}
ar<s≤ar+1 的
r
r
r,对于
[
1
,
r
]
[1,r]
[1,r] 的计数器,如果选择,会使
s
−
1
s-1
s−1;对于
[
r
+
1
,
n
]
[r+1,n]
[r+1,n] 的计数器
x
x
x,如果选择,会使
s
=
a
x
s=a_x
s=ax
所以这个状态表示是好转移的:
d
p
s
=
r
n
d
p
s
−
1
+
1
n
∑
i
=
r
+
1
n
d
p
a
i
+
1
dp_s=\frac{r}{n}dp_{s-1}+\frac{1}{n}\sum_{i=r+1}^{n}dp_{a_i}+1
dps=nrdps−1+n1∑i=r+1ndpai+1
考虑把
d
p
s
−
1
dp_{s-1}
dps−1 提到前面,使
d
p
dp
dp 转移式可以递推
d
p
s
−
1
=
n
r
d
p
s
−
1
r
∑
i
=
r
+
1
n
d
p
a
i
−
n
r
dp_{s-1}=\frac{n}{r}dp_s-\frac{1}{r}\sum_{i=r+1}^{n}dp_{a_i}-\frac{n}{r}
dps−1=rndps−r1∑i=r+1ndpai−rn
考虑到边界是
d
p
0
=
0
dp_0=0
dp0=0,和递推方向是相反的,考虑如何变为一致
这里有一个很妙的
t
r
i
c
k
trick
trick,令
f
i
=
d
p
a
n
−
d
p
i
f_i=dp_{a_n}-dp_i
fi=dpan−dpi
化简上面的式子可得:
f
s
−
1
=
n
r
f
s
−
1
r
∑
i
=
r
+
1
n
f
a
i
+
n
r
f_{s-1}=\frac{n}{r}f_s-\frac{1}{r}\sum_{i=r+1}^{n}f_{a_i}+\frac{n}{r}
fs−1=rnfs−r1∑i=r+1nfai+rn
考虑上面的式子对于
a
i
a_i
ai 到
a
i
+
1
a_{i+1}
ai+1 之间的转移式都是相同的,所以考虑对于每一段矩阵乘法优化
时间复杂度
O
(
2
3
n
l
o
g
A
n
)
O(2^3nlogA_n)
O(23nlogAn)
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int N(200100),P(998244353);
struct Matrix{
int n,m,a[2][2];
};
Matrix operator *(const Matrix &A,const Matrix &B){
Matrix C;C.n=A.n,C.m=B.m;memset(C.a,0,sizeof(C.a));
for(int i=0;i<C.n;i++) for(int j=0;j<C.m;j++)
for(int k=0;k<A.m;k++) C.a[i][j]=(C.a[i][j]+A.a[i][k]*B.a[k][j])%P;
return C;
}
int n,a[N];
inline int read(){
int FF=0,RR=1;
char ch=getchar();
for(;!isdigit(ch);ch=getchar()) if(ch=='-') RR=-1;
for(;isdigit(ch);ch=getchar()) FF=(FF<<1)+(FF<<3)+ch-48;
return FF*RR;
}
int qmi(int a,int b){
int res=1;
for(;b;b>>=1){
if(b&1) res=res*a%P;
a=a*a%P;
}
return res;
}
signed main(){
n=read();
for(int i=1;i<=n;i++) a[i]=read();
int sum=0,f=0;
for(int i=n;i>1;i--){
int inv=qmi(i-1,P-2);
sum=(sum+f)%P;
int c=(-inv*sum%P+n*inv%P+P)%P;
Matrix B;B.n=2,B.m=2,B.a[0][0]=n*inv%P,B.a[0][1]=0,B.a[1][0]=B.a[1][1]=1;
Matrix A;A.n=1,A.m=2,A.a[0][0]=f,A.a[0][1]=c;
int times=a[i]-a[i-1];
for(;times;times>>=1ll){
if(times&1) A=A*B;
B=B*B;
}
f=A.a[0][0];
}
printf("%lld",f);
return 0;
}