题目
思路
第 i i i 个数字可以是前缀的第 j ( 1 ≤ j ≤ i ) j\;(1\le j\le i) j(1≤j≤i) 大,并且 j j j 序列唯一决定原序列。所以问题转化为,在 0 ≤ x i < i 0\le x_i<i 0≤xi<i 的条件下,求 ∑ x i = k \sum x_i=k ∑xi=k 的方案数。
利用容斥,即规定 t t t 个数字 “爆表”,显然只需要知道它们的和即可。于是问题转化为了,从 { 1 , 2 , 3 , … , n } \{1,2,3,\dots,n\} {1,2,3,…,n} 中选出若干数字,其和为 p p p,求 ( − 1 ) c n t (-1)^{cnt} (−1)cnt 之和。
首先我们可以联想到此题,而且这里每个数字只有一个,所以无需分类,总个数就是 O ( k ) \mathcal O(\sqrt{k}) O(k) 的。但是那道题有一个特点:任意选择数字,都是存在的。而本题中,原封不动地套用公式,就会选出 n + 1 n+1 n+1 等数字去凑出 k k k 。
可是啊,我真傻,真的,我单知道直接套用不行;我不知道稍微改改就行了。我以为会选出
n
+
1
,
n
+
2
n+1,n+2
n+1,n+2 等等,可是它 只会选出
n
+
1
n+1
n+1 。因为我们是递归啊!重新考虑这个过程:
- 当前数字是 1 1 1:先找出 j − 1 j-1 j−1 个数字、和为 i − j i-j i−j 的方案。由于是递归,此处数字不超过 n n n 。然后将它的所有数字 + 1 +1 +1,显然不超过 n + 1 n+1 n+1 。然后第一个位置放 1 1 1 。
- 当前数字大于 1 1 1:先找出 j j j 个数字、和为 i − j i-j i−j 的方案。然后将所有数字 + 1 +1 +1,同理,也是不超过 n + 1 n+1 n+1 的。
所以唯一可能的错误只有一个:最后一个数字恰好为 n + 1 n+1 n+1 。那就减去 f ( j − 1 , i − n − 1 ) f(j-1,i-n-1) f(j−1,i−n−1) 即可。
时间复杂度 O ( k k ) \mathcal O(k\sqrt{k}) O(kk) 。
代码
组合数其实是可以现场算的,具体看代码吧。用阶乘法也不是不行。
#include <cstdio>
#include <iostream>
#include <cstring>
#include <vector>
#include <algorithm>
using namespace std;
typedef long long int_;
# define rep(i,a,b) for(int i=(a); i<=(b); ++i)
# define drep(i,a,b) for(int i=(a); i>=(b); --i)
inline int readint(){
int a = 0; char c = getchar(), f = 1;
for(; c<'0'||c>'9'; c=getchar())
if(c == '-') f = -f;
for(; '0'<=c&&c<='9'; c=getchar())
a = (a<<3)+(a<<1)+(c^48);
return a*f;
}
inline void writeint(int x){
if(x > 9) writeint(x/10);
putchar((x-x/10*10)^48);
}
const int Mod = 1e9+7;
# define add(a,b) (((a)+(b))%Mod)
const int MaxN = 200005;
const int SqrtN = 448;
int dp[SqrtN][MaxN], inv[MaxN];
int main(){
inv[1] = 1; rep(i,2,MaxN-1)
inv[i] = (0ll+Mod-Mod/i)*inv[Mod%i]%Mod;
int n = readint(), k = readint();
rep(j,dp[0][0]=1,SqrtN-1) rep(i,j,k){
dp[j][i] = add(dp[j][i-j],dp[j-1][i-j]);
if(i > n) dp[j][i] = add(dp[j][i],
Mod-dp[j-1][i-n-1]); // invalid
}
int ans = 0; int_ c = 1;
for(int p=1; p<=n-1; ++p)
c = c*(k+n-p)%Mod*inv[p]%Mod;
for(int i=0,v,sgn=1; i<=k; ++i,sgn=1){
for(int j=v=0; j<SqrtN; ++j,sgn=-sgn)
v = (v+sgn*dp[j][i])%Mod;
ans = (ans+c*v)%Mod;
c = c*inv[k+n-1-i]%Mod*(k-i)%Mod;
}
printf("%d\n",(ans+Mod)%Mod);
return 0;
}