题意
定义一个长度为n,字符集大小为k的序列是好的,当且仅当其中存在一个长度为k的子串满足1到k每个数在这里面恰好出现一次。现在给一个长度为m的序列a,问在所有好的序列里面,a作为子串的出现次数的和。
m≤n≤25000,k≤400
m
≤
n
≤
25000
,
k
≤
400
分析
我们考虑用在所有序列中出现的次数和减去在所有不好的序列里面的出现次数和。
现在问题在于如何求出所有不好的序列里面a的出现次数和。
分三种情况,如果a本身就是好的,那么答案就是0。
如果a中每个元素两两不同,因为元素之间是没有区别的,那么我们可以先求出在所有不好的序列中,有多少个长度为m的子串满足其中元素两两不同,然后再除以
k!(k−m)!
k
!
(
k
−
m
)
!
.
然后可以dp,设
fi,j
f
i
,
j
表示所有长度为
i
i
的序列中,满足末尾长度为的子串中元素两两不同,但长度为
j+1
j
+
1
的子串不满足的方案。
可以通过前缀和优化到
O(nk)
O
(
n
k
)
。
如果a中有相同元素的话,显然a前面的部分和后面的部分是互不干扰的,那么就分别dp,然后枚举a所在的位置进行计数即可。
时间复杂度
O(nk)
O
(
n
k
)
。
代码
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
typedef long long LL;
const int N=25005;
const int MOD=1000000007;
int n,k,m,a[N],f[N][405],g[N][405],ls[405];
int ksm(int x,int y)
{
int ans=1;
while (y)
{
if (y&1) ans=(LL)ans*x%MOD;
x=(LL)x*x%MOD;y>>=1;
}
return ans;
}
int jc(int n)
{
int ans=1;
for (int i=1;i<=n;i++) ans=(LL)ans*i%MOD;
return ans;
}
int main()
{
scanf("%d%d%d",&n,&k,&m);
for (int i=1;i<=m;i++) scanf("%d",&a[i]);
int mx=0,now=0,tot=(LL)(n-m+1)*ksm(k,n-m)%MOD;
for (int i=1;i<=m;i++)
{
now=std::min(now+1,i-ls[a[i]]);
ls[a[i]]=i;mx=std::max(mx,now);
}
if (mx==k) {printf("%d",tot);return 0;}
if (mx==m)
{
f[0][0]=1;
for (int i=1;i<=n;i++)
{
int s=0,t=0;
for (int j=k-1;j>=1;j--)
{
(s+=f[i-1][j])%=MOD;
(t+=g[i-1][j])%=MOD;
(f[i][j]+=(LL)f[i-1][j-1]*(k-j+1)%MOD)%=MOD;
(g[i][j]+=(LL)g[i-1][j-1]*(k-j+1)%MOD)%=MOD;
(f[i][j]+=s)%=MOD;
(g[i][j]+=t)%=MOD;
if (j>=m) (g[i][j]+=f[i][j])%=MOD;
}
}
int w=0;
for (int i=1;i<=k;i++) (w+=g[n][i])%=MOD;
w=(LL)w*jc(k-m)%MOD*ksm(jc(k),MOD-2)%MOD;
printf("%d",(tot+MOD-w)%MOD);
}
else
{
int u=0,v=0;
memset(ls,0,sizeof(ls));
for (int i=1;i<=m;i++)
if (!ls[a[i]]) u++,ls[a[i]]=1;
else break;
memset(ls,0,sizeof(ls));
for (int i=m;i>=1;i--)
if (!ls[a[i]]) v++,ls[a[i]]=1;
else break;
f[0][u]=g[0][v]=1;
for (int i=1;i<=n;i++)
{
int s=0,t=0;
for (int j=k-1;j>=1;j--)
{
(s+=f[i-1][j])%=MOD;
(t+=g[i-1][j])%=MOD;
(f[i][j]+=(LL)f[i-1][j-1]*(k-j+1)%MOD)%=MOD;
(g[i][j]+=(LL)g[i-1][j-1]*(k-j+1)%MOD)%=MOD;
(f[i][j]+=s)%=MOD;
(g[i][j]+=t)%=MOD;
}
}
for (int i=0;i+m<=n;i++)
{
int s=0,t=0;
for (int j=1;j<k;j++) (s+=f[i][j])%=MOD,(t+=g[n-m-i][j])%=MOD;
(tot+=MOD-(LL)s*t%MOD)%=MOD;
}
printf("%d",tot);
}
return 0;
}