测试地址:Team Work
题目大意:给定
n,k
n
,
k
,求
∑ni=1Cin⋅ik
∑
i
=
1
n
C
n
i
⋅
i
k
。
做法:本题需要用到二项式反演+第二类斯特林数。
二项式反演的实质是容斥原理,有两种表示形式:
f(n)=∑ni=0(−1)i⋅Cin⋅g(i)
f
(
n
)
=
∑
i
=
0
n
(
−
1
)
i
⋅
C
n
i
⋅
g
(
i
)
与
g(n)=∑ni=0(−1)i⋅Cin⋅f(i)
g
(
n
)
=
∑
i
=
0
n
(
−
1
)
i
⋅
C
n
i
⋅
f
(
i
)
等价,或:
f(n)=∑ni=0Cin⋅g(i)
f
(
n
)
=
∑
i
=
0
n
C
n
i
⋅
g
(
i
)
与
g(n)=∑ni=0(−1)n−i⋅Cin⋅f(i)
g
(
n
)
=
∑
i
=
0
n
(
−
1
)
n
−
i
⋅
C
n
i
⋅
f
(
i
)
等价。
证明待填坑,这里先不写了。
我们发现第二类斯特林数:
S(n,m)=1m!∑mi=0(−1)i⋅Cim⋅(m−i)n
S
(
n
,
m
)
=
1
m
!
∑
i
=
0
m
(
−
1
)
i
⋅
C
m
i
⋅
(
m
−
i
)
n
后面的式子和二项式反演的第二种形式的右边非常相似,而且后面是一个幂函数,因此我们尝试用斯特林数凑出幂函数。
首先有:
m!S(n,m)=∑mi=0(−1)i⋅Cim⋅(m−i)n
m
!
S
(
n
,
m
)
=
∑
i
=
0
m
(
−
1
)
i
⋅
C
m
i
⋅
(
m
−
i
)
n
用
m−i
m
−
i
替换
i
i
,得:
这样这个式子就跟上面的形式完全一样了,所以我们有:
mn=∑ni=0Cim⋅i!⋅S(n,i)
m
n
=
∑
i
=
0
n
C
m
i
⋅
i
!
⋅
S
(
n
,
i
)
细心的同学可能发现我换了和式的上限,这是没有任何问题的,请大家自己证证看。(提示:从组合数和斯特林数有意义的数值区间考虑)
那么我们把这个结论带进要求的式子中去,得到:
ans=∑ni=1Cin⋅ik
a
n
s
=
∑
i
=
1
n
C
n
i
⋅
i
k
=∑ni=0Cin∑kj=0Cji⋅j!⋅S(k,j)
=
∑
i
=
0
n
C
n
i
∑
j
=
0
k
C
i
j
⋅
j
!
⋅
S
(
k
,
j
)
(这一步在
k=0
k
=
0
时会多出一个
1
1
,最后特判减去即可)
显然应该对换的位置,得到:
ans=∑kj=0S(k,j)⋅j!∑ni=0CinCji
a
n
s
=
∑
j
=
0
k
S
(
k
,
j
)
⋅
j
!
∑
i
=
0
n
C
n
i
C
i
j
从后面和式的组合意义考虑,这个式子表达的是,先从
n
n
个里取个,再从
i
i
个里取个的方案数。那么我们不如考虑每一个
j
j
个元素构成的集合产生的贡献,因为是任意取的,那么取其他
i−j
i
−
j
个元素的方案数,就等于
2n−j
2
n
−
j
。所以有:
ans=∑kj=0S(k,j)⋅j!⋅Cjn⋅2n−j
a
n
s
=
∑
j
=
0
k
S
(
k
,
j
)
⋅
j
!
⋅
C
n
j
⋅
2
n
−
j
=∑kj=0S(k,j)⋅n!(n−j)!⋅2n−j
=
∑
j
=
0
k
S
(
k
,
j
)
⋅
n
!
(
n
−
j
)
!
⋅
2
n
−
j
于是我们
O(k2)
O
(
k
2
)
预处理出第二类斯特林数,再
O(k)
O
(
k
)
算出上面的式子即可。
事实上,这道题目是一道更难题目的一小部分,那道题目我看不懂,所以暂时先做这道题。那道题目需要用到
k
k
达到的情况,实际上仍是可做的,因为:
S(k,j)=1j!∑ji=0(−1)i⋅Cij⋅(j−i)k
S
(
k
,
j
)
=
1
j
!
∑
i
=
0
j
(
−
1
)
i
⋅
C
j
i
⋅
(
j
−
i
)
k
=∑ji=0(−1)ii!⋅(j−i)k(j−i)!
=
∑
i
=
0
j
(
−
1
)
i
i
!
⋅
(
j
−
i
)
k
(
j
−
i
)
!
这是一个卷积的形式,可以用FFT/NTT做到
O(klogk)
O
(
k
log
k
)
预处理斯特林数,但这道题
k
k
<script type="math/tex" id="MathJax-Element-8577">k</script>比较小,所以我就不这样写了。
以下是本人代码:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll mod=1000000007;
ll n,k,S[5010][5010]={0};
ll power(ll a,ll b)
{
ll s=1,ss=a;
while(b)
{
if (b&1) s=s*ss%mod;
b>>=1;ss=ss*ss%mod;
}
return s;
}
int main()
{
scanf("%lld%lld",&n,&k);
S[0][0]=1;
for(ll i=1;i<=k;i++)
for(ll j=1;j<=i;j++)
S[i][j]=(S[i-1][j-1]+j*S[i-1][j])%mod;
ll s1=1,s2=power(2,n),inv=500000004,ans=0;
for(ll i=0;i<=k;i++)
{
ans=(ans+S[k][i]*s1%mod*s2)%mod;
s1=s1*(n-i)%mod;
s2=s2*inv%mod;
}
if (k) printf("%lld",ans);
else printf("%lld",(ans-1+mod)%mod);
return 0;
}