题目描述
Description
Input
一行两个整数n,k
Output
一行一个整数ans,代表答案,模1e9+
Sample Input
Sample Input 1
5 2
Sample Input 2
100 50
Sample Output
Sample Output 1
12
Sample Output 2
400502129
Data Constraint
题解
设f[i]表示至少染色i次后能全部染黑的方案数
那么答案
a
n
s
=
∑
i
⩾
0
f
[
i
]
−
(
n
−
k
+
1
)
!
ans=\sum_{i\geqslant 0}{f[i]}-(n-k+1)!
ans=∑i⩾0f[i]−(n−k+1)!
因为每种排列一共被算了(手贱次数+1)次,所以最后要减去排列数
k=1 or k=n
答案为0
k*2≥n
头尾放完后中间随便放,因为都能染黑
k=2
枚举i,组合数随便算
时间:O(n)
k≥1000
f[i]不好直接算
容斥,对于每个i求出有至少j个间隔≥m的方案,随便算算
这样剩下的只有间隔<k的方案
证明:
设有k个间隔≥k
那么被算的次数=
∑
i
=
0
k
(
−
1
)
i
C
k
i
\sum_{i=0}^{k}{(-1)^{i}C_{k}^{i}}
∑i=0k(−1)iCki
杨辉三角中,只有第一行奇数和-偶数和=1,其余都为0(当前行的奇/偶和=上一行奇+偶,初始奇=1欧=0)
时间:O(n2/k),过不了30%所以要特殊处理
code
#include <algorithm>
#include <iostream>
#include <cstdlib>
#include <cstring>
#include <cstdio>
#define fo(a,b,c) for (a=b; a<=c; a++)
#define fd(a,b,c) for (a=b; a>=c; a--)
#define max(a,b) (a>b?a:b)
#define min(a,b) (a<b?a:b)
#define mod 1000000007
#define Mod 1000000005
using namespace std;
long long jc[1000001];
long long Jc[1000001];
long long w[1000001];
long long f[1000001];
int n,m,i,j,k,l;
long long ans;
long long qpower(long long a,int b)
{
long long ans=1;
while (b)
{
if (b&1)
ans=ans*a%mod;
a=a*a%mod;
b>>=1;
}
return ans;
}
long long C(int n,int m)
{
return jc[n]*Jc[m]%mod*Jc[n-m]%mod;
}
int main()
{
freopen("jian.in","r",stdin);
freopen("jian.out","w",stdout);
scanf("%d%d",&n,&m);
if (n==m)
{
printf("0\n");
return 0;
}
jc[0]=1;jc[1]=1;
Jc[0]=1;Jc[1]=1;
w[1]=1;
fo(i,2,n)
{
w[i]=(long long)mod-(mod/i)*w[mod%i]%mod;
jc[i]=jc[i-1]*i%mod;
Jc[i]=Jc[i-1]*w[i]%mod;
}
if (m==1)
{
printf("0\n");
return 0;
}
if (m+m>=n)
{
f[0]=jc[2]*jc[(n-m+1)-2]%mod;
fo(i,1,(n-m+1)-2)
f[i]=C((n-m+1)-2,i)*jc[i+2]%mod*jc[(n-m+1)-2-i]%mod;
}
else
if (m==2)
{
fo(i,(n-m-m-1)/m+1,n-m-1)
f[i]=C(i+1,n-m-1-i)*jc[i+2]%mod*jc[(n-m+1)-(i+2)]%mod;
}
else
{
fo(i,(n-m-m-1)/m+1,n-m-1)
{
fd(j,min(i+1,(n-m-1-i)/m),0)
if (!(j&1))
f[i]=(f[i]+C(i+1,j)*C(n-m-1-i-j*m+(i+1)-1,(i+1)-1)%mod)%mod;
else
f[i]=(f[i]-C(i+1,j)*C(n-m-1-i-j*m+(i+1)-1,(i+1)-1))%mod;
f[i]=f[i]*jc[i+2]%mod*jc[(n-m+1)-(i+2)]%mod;
}
}
ans=f[0];
fd(i,n-m-1,1)
ans=(ans+f[i])%mod;
ans-=jc[n-m+1];
printf("%lld\n",(ans%mod+mod)%mod);
}