正难则反,考虑计算所有不colorful的序列
S
S
S的贡献和。
我们称一个序列是diff的当且仅当它不存在两个值相同的元素,那么colorful的定义就是包含一个长度为
K
K
K的diff序列。
先预处理出
F
[
i
]
[
j
]
F[i][j]
F[i][j]表示长度为
i
i
i且不colorful的序列中有多少个包含结尾的极长diff子序列长度恰为
j
j
j的。这个可以用简单的DP实现。
我们可以注意到一个事实,所有不colorful的序列中,对于某个长度
j
≤
k
j\leq k
j≤k,结尾是每个长度为
j
j
j的diff序列的方案数是一样的。
也即,对于某个特定的长度为
j
j
j的diff序列,所有长度为
i
i
i的colorful序列中以它结尾的个数为
G
[
i
]
[
j
]
=
∑
l
=
j
k
−
1
F
[
i
]
[
l
]
∏
l
=
1
j
(
k
−
l
+
1
)
G[i][j]=\frac {\sum_{l=j}^{k-1}F[i][l]}{\prod_{l=1}^{j}(k-l+1)}
G[i][j]=∏l=1j(k−l+1)∑l=jk−1F[i][l]。
特判掉
M
M
M是colorful的情况。我们先考虑
M
M
M不是diff的情况,这样不会存在一个长度为
K
K
K的diff序列跨过它。这样只需要枚举
M
M
M所在的开头位置
l
l
l,用预处理的
G
G
G就可以快速计算了。
再考虑
M
M
M是diff的情况,我们仍然枚举开头位置
l
l
l。我们再枚举一下
S
1...
l
+
∣
M
∣
−
1
S_{1...l+|M|-1}
S1...l+∣M∣−1的结尾极长diff序列长度
j
j
j(
j
≥
∣
M
∣
j\geq |M|
j≥∣M∣),那么只需要
S
l
+
∣
M
∣
−
j
.
.
.
N
S_{l+|M|-j...N}
Sl+∣M∣−j...N的开头极长diff序列长度
≥
j
\geq j
≥j即可,利用预处理的信息可以快速计算,最后再除掉
G
[
l
+
∣
M
∣
−
1
]
[
∣
M
∣
]
⋅
G
[
N
−
(
l
+
∣
M
∣
−
j
)
+
1
]
[
j
]
G[l+|M|-1][|M|]\cdot G[N-(l+|M|-j)+1][j]
G[l+∣M∣−1][∣M∣]⋅G[N−(l+∣M∣−j)+1][j]即可。
时间复杂度
O
(
N
K
)
\mathcal O(NK)
O(NK)。
#include <bits/stdc++.h>
#define MOD 1000000007
#define last last2
using namespace std;
typedef long long ll;
ll pow_mod(ll x,int k) {
ll ans=1;
while (k) {
if (k&1) ans=ans*x%MOD;
x=x*x%MOD;
k>>=1;
}
return ans;
}
ll inv[405];
void pre(int k) {
inv[0]=1;
for(int i=1;i<=k;i++) inv[i]=inv[i-1]*(k-i+1)%MOD;
for(int i=1;i<=k;i++) inv[i]=pow_mod(inv[i],MOD-2);
}
int f[25005][405],sum[25005][405];
void dp(int n,int k) {
f[1][1]=sum[1][1]=k;
for(int i=2;i<=n;i++) {
for(int j=1;j<k;j++) f[i][j]=((ll)f[i-1][j-1]*(k-j+1)+sum[i-1][j])%MOD;
for(int j=k-1;j>0;j--) sum[i][j]=(sum[i][j+1]+f[i][j])%MOD;
}
}
int solve1(int n,int m,int l,int r) {
int ans=0;
for(int i=1;i<=n-m+1;i++)
ans=(ans+(ll)sum[i+l-1][l]*sum[n-(i-1)-m+r][r])%MOD;
return ans*inv[l]%MOD*inv[r]%MOD;
}
int solve2(int n,int m,int k) {
int ans=0;
for(int i=1;i<=n-m+1;i++)
for(int j=m;j<min(i+m,k);j++)
ans=(ans+(ll)f[i+m-1][j]*sum[n-(i-1)-m+j][j]%MOD*inv[j])%MOD;
return ans*inv[m]%MOD;
}
int num[25005],last[25005];
int main() {
int n,m,k;
scanf("%d%d%d",&n,&k,&m);
pre(k);
dp(n,k);
for(int i=1;i<=m;i++) scanf("%d",&num[i]);
memset(last,0,sizeof(last));
int l=1,s=pow_mod(k,n-m)*(n-m+1)%MOD;
for(int i=1;i<=m;i++) {
int x=num[i];
if (last[x]>=l) l=last[x]+1;
last[x]=i;
if (i-l+1>=k) {
printf("%d\n",s);
return 0;
}
}
if (l!=1) {
memset(last,0x3f,sizeof(last));
int r=m;
for(int i=m;i>0;i--) {
int x=num[i];
if (last[x]<=r) r=last[x]-1;
last[x]=i;
}
printf("%d\n",(s-solve1(n,m,r,m-l+1)+MOD)%MOD);
}
else printf("%d\n",(s-solve2(n,m,k)+MOD)%MOD);
return 0;
}