题目传送门
谈点其他的
这个题目翻译好像有亿点问题,我重新发一波。
给定一个长度为 n n n 的排列 p p p。
令其中第 i i i 个位置的权值为最长的包含 i i i 的单调区间的长度(不仅要权值单调,而且要连续)。例如, p = [ 4 , 1 , 2 , 3 , 7 , 6 , 5 ] p=[4,1,2,3,7,6,5] p=[4,1,2,3,7,6,5] 中,第 6 6 6 个位置的权值为 3 3 3( [ 5 , 7 ] [5,7] [5,7] 这个区间单调递减而且连续),第 2 2 2 个位置的权值为 3 3 3( [ 2 , 4 ] [2,4] [2,4] 这个区间单调递增而且连续)。
将这些权值依次拼在一起,就得到了 p p p 的 ⌈ \lceil ⌈ 阶梯序列 ⌋ \rfloor ⌋。
给定 a a a,你需要求出存在多少个 p p p,使得 a a a 为 p p p 的 ⌈ \lceil ⌈ 阶梯序列 ⌋ \rfloor ⌋。
思路
首先可以看出来,就是这个排列的这些单调、权值连续的区间一定是独立的(没有相交的情况),这一步应该很好理解。
然后就可以根据
a
a
a 推出这些区间,比如样例1
:
a
=
{
3
,
3
,
3
,
1
,
1
,
1
}
a=\{3,3,3,1,1,1\}
a={3,3,3,1,1,1},那么这些区间就是
{
[
1
,
3
]
,
[
4
,
4
]
,
[
5
,
5
]
,
[
6
,
6
]
}
\{[1,3],[4,4],[5,5],[6,6]\}
{[1,3],[4,4],[5,5],[6,6]}。
然后我们重新定义一下
a
a
a,就是每个区间的长度依次组成的序列。例如样例1
,
a
=
{
3
,
1
,
1
,
1
}
a=\{3,1,1,1\}
a={3,1,1,1}。
注意接下来所有说的 a a a 都是新定义的 a a a, l e n len len 就是新定义的 a a a 的长度
接下来就可以将
1
∼
n
1\sim n
1∼n 分配给这些区间,那么其实就是
l
e
n
!
len!
len! 种,就考虑让这些区间一个一个地取数,例如样例1
,如果取数的顺序为
3
,
1
,
2
,
4
{3,1,2,4}
3,1,2,4,那么区间
[
1
,
3
]
[1,3]
[1,3] 就分配到了
2
,
3
,
4
2,3,4
2,3,4 这三个数(因为
1
1
1 被排在第一个的区间
3
3
3 取走了),所以这个排列
p
p
p 就是
{
2
,
3
,
4
,
5
,
1
,
6
}
\{2,3,4,5,1,6\}
{2,3,4,5,1,6} 或
{
4
,
3
,
2
,
5
,
1
,
6
}
\{4,3,2,5,1,6\}
{4,3,2,5,1,6}。
这里看出了一个问题,就是所有长度大于等于 2 2 2 的区间都有递增和递减两种情况,所以还要乘上一个 2 x 2^x 2x,这个 x x x 就是长度大于等于 2 2 2 的区间的个数。
但是还会有一个问题,就是分配出来的东西,可能相邻的两个区间可以一起组成一个大的单调、连续的区间,那么要把这种情况排除掉,就可以容斥。
答案就是
至少有 0 0 0 对相邻的区间合并 − - − 至少有 1 1 1 对相邻的区间合并 + + + 至少有2对相邻的区间合并 ⋯ + ( − 1 ) l e n \dots+(-1)^{len} ⋯+(−1)len至少有 ( l e n − 1 ) (len-1) (len−1) 对相邻的区间合并
Solution1
复杂度: O ( n 2 ) O(n^2) O(n2)
考虑 dp。
设 f i , j f_{i,j} fi,j 表示区间编号为 1 ∼ i 1\sim i 1∼i,合并成 j j j 段的值是多少。
转移方程式:
f i , j = { j × 2 × ∑ k = 1 j − 1 f i − 1 , k 2 ≤ j ≤ i ≤ l e n , a j > 1 j × ( ∑ k = 1 j − 1 2 f i − 1 , k − f i − 1 , j − 1 ) 2 ≤ j ≤ i ≤ l e n , a j = 1 f_{i,j}=\left\{ \begin{array}{rcl} j\times2\times\sum\limits_{k=1}^{j-1}f_{i-1,k} & {2\le j\le i\le len,a_j>1}\\ j\times(\sum\limits_{k=1}^{j-1}2f_{i-1,k}-f_{i-1,j-1}) & {2\le j\le i\le len,a_j=1} \end{array} \right. fi,j=⎩⎪⎪⎨⎪⎪⎧j×2×k=1∑j−1fi−1,kj×(k=1∑j−12fi−1,k−fi−1,j−1)2≤j≤i≤len,aj>12≤j≤i≤len,aj=1
初始化:
f 1 , 1 = { 2 a 1 > 1 1 a 1 = 1 f_{1,1}=\left\{ \begin{array}{rcl} 2 & {a_1>1}\\ 1 & {a_1=1}\\ \end{array} \right. f1,1={21a1>1a1=1
f i , 1 = 2 , i ≥ 2 f_{i,1}=2,i\ge2 fi,1=2,i≥2
然后答案就是:
∑ i = 1 l e n ( − 1 ) l e n − i × f n , i × i ! \sum\limits_{i=1}^{len}(-1)^{len-i}\times f_{n,i}\times i! i=1∑len(−1)len−i×fn,i×i!
代码就不贴了,因为我比较勤快懒惰,很不想写代码 QWQ。
Solution
复杂度: O ( n log 2 n ) O(n\log^2n) O(nlog2n)
可以注意到在新的 a a a 序列中合并两个区间时,左边是 x x x 段,右边是 y y y 段,合并起来就是 x + y x+y x+y 段,合并起来的值就是乘积,想到什么???
卷积!!!
可以分治 + NTT,这里模数正好是 98244353 98244353 98244353,原根就是 3 3 3。
但是还有一些问题,就比如说 [ l , m i d ] [l,mid] [l,mid] 的最右边的那个数是大于 1 1 1 的, [ m i d + 1 , r ] [mid+1,r] [mid+1,r] 的最左边的那个数也是大于 1 1 1 的,合并起来就要除以 2 2 2。
两个都是 1 1 1,就要乘以 2 2 2。
然后,细节多得要死,常数大得要死。
最后,注意如果在 ( m i d , m i d + 1 ) (mid,mid+1) (mid,mid+1) 这里分一刀,那么直接卷就好了,如果不分,那么段数要 − 1 -1 −1。
代码
注释看看最后一份代码好了。
然后顺利地 TLE 了……
然后,@275307894a 大佬,给了我一些建议:
每一次乘法时都要 NTT 3 遍,其实可以先 NTT 一遍,乘的时候直接对位相乘就好了。
但是,你发现又 TLE 了。
然后继续优化,发现 NTT 里面有多次调用 qpow,所以可以预处理。
然后发现 div 里面每次都只是除以二,所以不用每次算逆元,每次都直接乘以 2 的逆元就好了。
最后的代码(有注释)
#pragma GCC optimize(3)
#pragma GCC target("avx")
#pragma GCC optimize("Ofast")
#pragma GCC optimize(2)
#include<cstdio>
#include<vector>
#include<algorithm>
#include<ctime>
#define int long long
#define ll long long
using namespace std;
const ll mod=998244353,g=3,ginv=332748118,N=4e5+10,inv2=499122177;
ll qpow(ll x,ll y){
ll ans=1;
while(y){
if(y&1)(ans*=x)%=mod;
(x*=x)%=mod;
y>>=1;
}
return ans;
}
int rr[N];
ll gg[N],gginv[N],lim[N];
ll cnt;
vector<ll> NTT(vector<ll> a,int limit,int type){//用 vector 好写一些吧
for(int i=0;i<limit;i++)if(rr[i]<i)swap(a[rr[i]],a[i]);
for(int mid=1;mid<limit;mid<<=1){
ll wn=type==1?gg[mid<<1]:gginv[mid<<1];//main 里面处理
for(int j=0;j<limit;j+=(mid<<1)){
ll w=1;
for(int k=0;k<mid;k++,(w*=wn)%=mod){
cnt++;ll x=a[j+k],y=w*a[j+k+mid]%mod;//NTT 基本操作,蝴蝶效应
a[j+k]=(x+y)%mod;a[j+k+mid]=(x-y+mod)%mod;
}
}
}if(~type)return a;for(int i=0;i<limit;i++) (a[i]*=lim[limit])%=mod;return a;//别忘了如果是逆回去,要除以 limit,这里也预处理了
}
vector<ll> add(vector<ll> a,vector<ll> b){//相加
int n=a.size(),m=b.size();
while(n<m)a.push_back(0),n++;//要把 0 补起来,不然要 RE
while(n>m)b.push_back(0),m++;
for(int i=0;i<n;i++)(a[i]+=b[i])%=mod;
return a;
}
vector<ll> add_(vector<ll> a,vector<ll> b){//错位加
int n=a.size(),m=b.size();
while(n<m)a.push_back(0),n++;
while(n>m)b.push_back(0),m++;
for(int i=0;i<n-1;i++)(a[i]+=b[i+1])%=mod;
return a;
}
vector<ll> div(vector<ll> a){//除以二
int n=a.size();
for(int i=0;i<n;i++)(a[i]*=inv2)%=mod;
return a;
}
vector<ll> mul(vector<ll> a,ll x){//乘以x
int n=a.size();
for(int i=0;i<n;i++)(a[i]*=x)%=mod;
return a;
}
struct zj{
vector<ll>a[2][2];//分别表示左右端点是否大于 1
zj(){a[0][0].push_back(0),a[0][1].push_back(0),a[1][0].push_back(0),a[1][1].push_back(0);}//分成 0 段的方案数就是 0
};
int n,a[N],b[N],m;
vector<ll> mul(vector<ll> a,vector<ll> b){//卷积
int limit=1,l=0,n=a.size(),m=b.size();
while(limit<=n+m)limit<<=1,l++;
for(int i=0;i<limit;i++)rr[i]=(rr[i>>1]>>1)|((i&1)<<(l-1));
while(n<limit)a.push_back(0),n++;
while(m<limit)b.push_back(0),m++;
a=NTT(a,limit,1);b=NTT(b,limit,1);
for(int i=0;i<limit;i++)(a[i]*=b[i])%=mod;
a=NTT(a,limit,-1);
while(a.size()>1&&a[a.size()-1]==0)a.pop_back();//注意要把无用的 0 去掉,不然会很耗内存和时间
return a;
}
vector<ll> mul_(vector<ll>a,vector<ll>b){//对位乘
int n=a.size(),m=b.size();
while(n<m)a.push_back(0),n++;
while(n>m)b.push_back(0),m++;
for(int i=0;i<n;i++)(a[i]*=b[i])%=mod;
return a;
}
zj merge(int l,int r){
if(l==r){//边界
zj x;
if(a[l]==1)x.a[0][0].push_back(1);
else x.a[1][1].push_back(2);
return x;
}
int mid=(l+r)>>1;
zj x=merge(l,mid),y=merge(mid+1,r);
zj tmp,tmpp;
if(l+1==r){
for(int i=0;i<2;i++)for(int j=0;j<2;j++)for(int ii=0;ii<2;ii++)for(int jj=0;jj<2;jj++)tmp.a[i][j]=add(tmp.a[i][j],mul(x.a[i][ii],y.a[jj][j]));
for(int i=0;i<2;i++)for(int j=0;j<2;j++)for(int ii=0;ii<2;ii++)for(int jj=0;jj<2;jj++)if(ii!=jj)tmp.a[1][1]=add_(tmp.a[1][1],mul(x.a[i][ii],y.a[jj][j]));else if(!ii)tmp.a[1][1]=add_(tmp.a[1][1],mul(mul(x.a[i][ii],y.a[jj][j]),2));else tmp.a[1][1]=add_(tmp.a[1][1],div(mul(x.a[i][ii],y.a[jj][j])));
return tmp;
}
else if(l+2==r){
for(int i=0;i<2;i++)for(int j=0;j<2;j++)for(int ii=0;ii<2;ii++)for(int jj=0;jj<2;jj++)tmp.a[i][j]=add(tmp.a[i][j],mul(x.a[i][ii],y.a[jj][j]));
for(int i=0;i<2;i++)for(int j=0;j<2;j++)for(int ii=0;ii<2;ii++)for(int jj=0;jj<2;jj++)if(ii!=jj)tmp.a[i][1]=add_(tmp.a[i][1],mul(x.a[i][ii],y.a[jj][j]));else if(!ii)tmp.a[i][1]=add_(tmp.a[i][1],mul(mul(x.a[i][ii],y.a[jj][j]),2));else tmp.a[i][1]=add_(tmp.a[i][1],div(mul(x.a[i][ii],y.a[jj][j])));
return tmp;
}
//这两个特盘是因为长度在 3 以内时,合并会改变左右端点的 0/1
int limit=1,llll=0,lena=0,lenb=0;//tmp 是 (mid,mid+1) 切一刀,tmpp 是不切,一定要分开,因为不切要错位加,点值是不能错位加的,要 NTT 回去再错位加起来
for(int i=0;i<2;i++)for(int j=0;j<2;j++)lena=max(lena,(int)x.a[i][j].size());
for(int i=0;i<2;i++)for(int j=0;j<2;j++)lenb=max(lenb,(int)y.a[i][j].size());
while(limit<=lena+lenb)limit<<=1,llll++;
for(int i=0;i<limit;i++)rr[i]=(rr[i>>1]>>1)|((i&1)<<(llll-1));//NTT 基本操作
for(int i=0;i<2;i++)for(int j=0;j<2;j++)while(x.a[i][j].size()<limit)x.a[i][j].push_back(0);//注意补 0
for(int i=0;i<2;i++)for(int j=0;j<2;j++)while(y.a[i][j].size()<limit)y.a[i][j].push_back(0);
for(int i=0;i<2;i++)for(int j=0;j<2;j++)while(tmp.a[i][j].size()<limit)tmp.a[i][j].push_back(0);
for(int i=0;i<2;i++)for(int j=0;j<2;j++)while(tmpp.a[i][j].size()<limit)tmpp.a[i][j].push_back(0);
for(int i=0;i<2;i++)for(int j=0;j<2;j++)x.a[i][j]=NTT(x.a[i][j],limit,1);
for(int i=0;i<2;i++)for(int j=0;j<2;j++)y.a[i][j]=NTT(y.a[i][j],limit,1);
for(int i=0;i<2;i++)for(int j=0;j<2;j++)for(int ii=0;ii<2;ii++)for(int jj=0;jj<2;jj++)tmp.a[i][j]=add(tmp.a[i][j],mul_(x.a[i][ii],y.a[jj][j]));//NTT 之后要对位乘
for(int i=0;i<2;i++)for(int j=0;j<2;j++)for(int ii=0;ii<2;ii++)for(int jj=0;jj<2;jj++)if(ii!=jj)tmpp.a[i][j]=add(tmpp.a[i][j],mul_(x.a[i][ii],y.a[jj][j]));else if(!ii)tmpp.a[i][j]=add(tmpp.a[i][j],mul(mul_(x.a[i][ii],y.a[jj][j]),2));else tmpp.a[i][j]=add(tmpp.a[i][j],div(mul_(x.a[i][ii],y.a[jj][j])));//不能再这里写错位加,我 TM 调了半天!!!一定要 NTT 回去再错位加
for(int i=0;i<2;i++)for(int j=0;j<2;j++)tmp.a[i][j]=NTT(tmp.a[i][j],limit,-1);
for(int i=0;i<2;i++)for(int j=0;j<2;j++)tmpp.a[i][j]=NTT(tmpp.a[i][j],limit,-1);
for(int i=0;i<2;i++)for(int j=0;j<2;j++)tmp.a[i][j]=add_(tmp.a[i][j],tmpp.a[i][j]);//错位加
for(int i=0;i<2;i++)for(int j=0;j<2;j++)while(tmp.a[i][j].size()>1&&tmp.a[i][j][tmp.a[i][j].size()-1]==0)tmp.a[i][j].pop_back();//删去无用 0 ,但是要留一个
return tmp;
}
signed main(){
scanf("%lld",&m);
//上文讲的预处理
for(int mid=1;mid<N;mid<<=1)gg[mid]=qpow(g,(mod-1)/mid),gginv[mid]=qpow(ginv,(mod-1)/mid),lim[mid]=qpow(mid,mod-2);
for(int i=1;i<=m;i++)scanf("%lld",&b[i]);
for(int i=1;i<=m;){
a[++n]=b[i];
for(int j=0;j<b[i];j++){
if(i+j>m)return printf("0\n"),0;
if(b[i+j]!=b[i])return printf("0\n"),0;
}
i=i+b[i];
}//变幻成新定义的 a 数组
zj ans=merge(1,n);
for(int i=0;i<2;i++)for(int j=0;j<2;j++){
while(ans.a[i][j].size()<=n+1)ans.a[i][j].push_back(0);
}
for(int i=1;i<=n;i++)(ans.a[0][0][i]+=ans.a[0][1][i]+ans.a[1][0][i]+ans.a[1][1][i])%=mod;
ll now=1,ANS=0;
for(int i=1;i<=n;i++){//容斥,和上面的公式一样的
(now*=i)%=mod;
(ANS+=((n-i+1)&1?1:-1)*now*ans.a[0][0][i]%mod)%=mod;
(ANS+=mod)%=mod;
}
return printf("%lld",ANS),0;
}
最后说点什么
第一次发黑题题解,第一次独立做出黑题。
需要卡常,不知道为什么我跑了 2.4 2.4 2.4 min,别人都是 10 10 10 s。