Colorful Sequences
题解
首先,我们可以根据给出的长度为
m
m
m的问题分为三类。
如果序列中的数中已经存在一个长度为
k
k
k的字符两两不同的子串(暂且称之为条件子串),那么其它位置的字符我们是可以随便填的。
由于我们的长度为
m
m
m的子串(不妨称其为目标子串)出现在一个位置就算一次,所以我们只需要枚举在某个位置上出现该目标子串时整个序列的方案数,答案即为
(
n
−
m
+
1
)
k
n
−
m
(n-m+1)k^{n-m}
(n−m+1)kn−m。
如果目标子串中不存在这样的条件子串,且序列中存在两个相同的字符,这就意味着我们的条件子串不可能完全覆盖我们的目标子串,它只能从一段伸出去。
该条件子串既可能覆盖我们的目标子串的一段,也可能完全不覆盖我们目标子串的一段。
但我们可以注意到,条件子串在目标子串上的一段必然是一个小的两两不同子串。
我们可以先将的最大长度预处理出来,然后通过
d
p
dp
dp求出伸长出去一定长度的方案数。
我们记
f
i
,
k
f_{i,k}
fi,k表示在我伸长出去
i
i
i长度的序列时,根尖伸长区的连续不同序列的长度最大为
k
k
k的方案数。
转移应该比较好想,
f
i
,
k
f_{i,k}
fi,k各有
1
1
1种方法转移到
f
i
+
1
,
k
′
(
k
′
∈
[
1
,
k
]
)
f_{i+1,k'}(k'\in[1,k])
fi+1,k′(k′∈[1,k]),剩下的方法都能转移到
f
i
+
1
,
k
+
1
f_{i+1,k+1}
fi+1,k+1,就是看加入的数与哪个数重复了罢了。
可以在转移时实时统计
d
p
i
dp_{i}
dpi,即伸长
i
i
i长度的序列合法的方案数。
将左右端都统计一遍,最后合起来即可得到整个序列的方案数了。
由于总长度固定,没有必要卷积。
如果目标子串是一个两两不同的子串,那么就我们的条件子串就有可能包含我们的目标子串了。
但我们可以发现,由于所有的字符都两两不同,所以我们将这个字符集任意变化一个一一映射,其答案都应该是等价的。
这也就是说,我们可以直接统计所有合法的长度为
n
n
n的序列中,包含长度为
k
k
k的两两不同的序列的个数,再将其除以
A
m
k
A_{m}^{k}
Amk即可。
同样可以用上面的
d
p
dp
dp方法进行统计。
时间复杂度 O ( n k ) O\left(nk\right) O(nk)。
源码
出乎意料的优秀代码,不仅速度快,空间竟然也达到了
K
K
K的级别。
#include<bits/stdc++.h>
using namespace std;
#define MAXN 25005
#define lowbit(x) (x&-x)
#define reg register
#define pb push_back
#define mkpr make_pair
#define fir first
#define sec second
typedef long long LL;
typedef unsigned long long uLL;
typedef long double ld;
typedef pair<int,int> pii;
const int INF=0x3f3f3f3f;
const int mo=1e9+7;
const int inv2=5e8+4;
const int jzm=2333;
const int zero=100000;
const int n1=1000;
const int lim=100000000;
const int orG=3,ivG=332748118;
const double Pi=acos(-1.0);
const double eps=1e-6;
template<typename _T>
_T Fabs(_T x){return x<0?-x:x;}
template<typename _T>
void read(_T &x){
_T f=1;x=0;char s=getchar();
while(s>'9'||s<'0'){if(s=='-')f=-1;s=getchar();}
while('0'<=s&&s<='9'){x=(x<<3)+(x<<1)+(s^48);s=getchar();}
x*=f;
}
template<typename _T>
void print(_T x){if(x<0){x=(~x)+1;putchar('-');}if(x>9)print(x/10);putchar(x%10+'0');}
int gcd(int a,int b){return !b?a:gcd(b,a%b);}
int add(int x,int y,int p){return x+y<p?x+y:x+y-p;}
void Add(int &x,int y,int p){x=add(x,y,p);}
int qkpow(int a,int s,int p){int t=1;while(s){if(s&1)t=1ll*t*a%p;a=1ll*a*a%p;s>>=1;}return t;}
int n,k,m,a[MAXN],fac[MAXN],inv[MAXN],ff[MAXN],ans;
int f[2][405],g[2][405],dp[MAXN],gp[MAXN],hp[MAXN];
int ft[2][405],gt[2][405];
int vis[MAXN],cnt;
void init(){
fac[0]=fac[1]=inv[0]=inv[1]=ff[1]=1;
for(int i=2;i<=n;i++)
fac[i]=1ll*i*fac[i-1]%mo,
ff[i]=1ll*(mo-mo/i)*ff[mo%i]%mo,
inv[i]=1ll*ff[i]*inv[i-1]%mo;
}
int C(int x,int y){
if(x<0||y<0||x<y)return 0;
return 1ll*fac[x]*inv[y]%mo*inv[x-y]%mo;
}
signed main(){
//freopen("sequence.in","r",stdin);
//freopen("sequence.out","w",stdout);
read(n);read(k);read(m);init();bool flag=0;
for(int i=1;i<=m;i++)read(a[i]);int ld=0,rd=0;
for(int i=1;i<=m;i++){
vis[a[i]]++;if(vis[a[i]]==1)cnt++;
if(i>k){vis[a[i-k]]--;if(!vis[a[i-k]])cnt--;}
if(cnt==k)flag=1;
}
if(flag){
printf("%d\n",1ll*qkpow(k,n-m,mo)*(n-m+1)%mo);
return 0;
}
for(int i=1;i<=k;i++)vis[i]=0;
for(int i=1;i<=m;i++)if(!vis[a[i]])vis[a[i]]=1,ld++;else break;
for(int i=1;i<=k;i++)vis[i]=0;
for(int i=m;i>0;i--)if(!vis[a[i]])vis[a[i]]=1,rd++;else break;
if(ld==m){
f[1][1]=k;ft[1][1]=k*(m==1);int now=1,las=0;
for(int i=2;i<=n;i++){
swap(now,las);int summ=0,sumt=0;
for(int j=k;j>0;j--)
Add(summ,f[las][j],mo),Add(sumt,ft[las][j],mo),
f[now][j]=add(1ll*(k-j+1)*f[las][j-1]%mo,summ,mo),
ft[now][j]=add(1ll*(k-j+1)*ft[las][j-1]%mo,sumt,mo),
Add(ft[now][j],f[now][j]*(j>=m),mo);
summ=0;sumt=0;
for(int j=k;j>0;j--)
Add(summ,g[las][j],mo),Add(sumt,gt[las][j],mo),
g[now][j]=add(1ll*(k-j+1)*g[las][j-1]%mo,summ,mo),
gt[now][j]=add(1ll*(k-j+1)*gt[las][j-1]%mo,sumt,mo),
Add(gt[now][j],g[now][j]*(j>=m),mo);
Add(g[now][k],f[now][k],mo);f[now][k]=0;
Add(gt[now][k],ft[now][k],mo);ft[now][k]=0;
}
for(int i=1;i<=k;i++)Add(ans,gt[now][i],mo);
printf("%d\n",1ll*ans*inv[m]%mo*qkpow(C(k,m),mo-2,mo)%mo);
return 0;
}
f[0][ld]=1;int now=0,las=1;
for(int i=1;i<=n-m;i++){
int summ=0;swap(now,las);
for(int j=k;j>0;j--)Add(summ,f[las][j],mo),
f[now][j]=add(1ll*(k-j+1)*f[las][j-1]%mo,summ,mo);
dp[i]=add(1ll*k*dp[i-1]%mo,f[now][k],mo);f[now][k]=0;
for(int j=1;j<=k;j++)f[las][j]=0;
}
g[0][rd]=1;now=0;las=1;
for(int i=1;i<=n-m;i++){
int summ=0;swap(now,las);
for(int j=k;j>0;j--)Add(summ,g[las][j],mo),
g[now][j]=add(1ll*(k-j+1)*g[las][j-1]%mo,summ,mo);
gp[i]=add(1ll*k*gp[i-1]%mo,g[now][k],mo);g[now][k]=0;
for(int j=1;j<=k;j++)g[las][j]=0;
}
for(int i=0;i<=n-m;i++){
int tmpl=add(qkpow(k,i,mo),mo-dp[i],mo);
int tmpr=add(qkpow(k,n-m-i,mo),mo-gp[n-m-i],mo);
Add(ans,qkpow(k,n-m,mo),mo);
Add(ans,mo-1ll*tmpl*tmpr%mo,mo);
}
printf("%d\n",ans);
return 0;
}