Time Limit:1s Memory Limit:128MByte
Submissions:528Solved:105
DESCRIPTION
你有n个球,需要把他们放到m个盒子里。要求拥有最多球的盒子唯一,问方案数。
INPUT
一行两个数n、m(n、m≤500)
OUTPUT
一行一个数,表示方案数。答案对998244353取模。
SAMPLE INPUT
5 2
SAMPLE OUTPUT
6
思路:
测试数据很水,很暴力的dp也能过,O(n^3)
dp[i][j]表示i个盒子里面一共放了j个球的情况。
假设球数最多的盒子里面放了k个球,那么剩下的m-1个盒子里面只能放n-k个球,每个盒子最多[0,k-1]个球。
dp[i][j] = ∑dp[i-1][j-x], x∈[0, k]。
用pre[i][j]来求dp[i][j]的前缀和来优化一下。
#include <iostream>
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <string>
using namespace std;
#define ll long long
const ll mod = 998244353;
ll dp[550][550];
ll pre[550][550];
int main(){
int n, m;
scanf("%d%d", &n, &m);
dp[0][0] = 1;
for(int i=0; i<=n; i++) pre[0][i]=1;
ll ans=0;
for(int k=0; k<=n; k++){
for(int i=1; i<m; i++){
for(int j=0; j<=n; j++){
dp[i][j]=0;
/*for(int x=0; x<k; x++){
ll sum = j-x>=0 ? dp[i-1][j-x]:0;
dp[i][j] = (dp[i][j]+sum)%mod;
}*/
dp[i][j] = (pre[i-1][j] - (j-k>=0?pre[i-1][j-k]:0) + mod)%mod;
pre[i][j] = ((j==0?0:pre[i][j-1]) + dp[i][j])%mod;
}
}
ans = (ans+m*dp[m-1][n-k])%mod;
}
printf("%lld\n", ans);
return 0;
}
看了玲珑oj上面的题解,知道了一种 组合数学+容斥的解法
【ps:玲珑oj上面的题解有点小错误,,这个式子里减去的那部分忘记 *(-1)^k,整个式子要乘以盒子的数量,也就是m,非球的数量n】
利用容斥原理,最终的方案数=总方案数-不合法的方案数
find(i, j)表示将i个球放到j个盒子里的方案数,盒子里面允许为空
find(i, j) = C[i+j-1][j-1] 【这个公式不懂的可以参照这个博客:codeforces397C】
先枚举第一个盒子的数量,假设这是球数量最多的一个盒子,记为x,然后枚举有k个盒子球的数量>=第一个盒子
因此 ans = m*( find(n, m) - ∑((-1)^k) * (C[m-1][k] * find(n-x*(k+1),m-1) ) ), x∈[0,n], k∈[1,m-1]
#include <iostream>
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <string>
using namespace std;
#define ll long long
const ll mod = 998244353;
ll C[1050][1050];
ll find(int x, int y){
if(x+y-1<0 || y-1<0) return 0;
return C[x+y-1][y-1];
}
int main(){
int n, m;
scanf("%d%d", &n, &m);
C[0][0]=1;
C[1][0] = C[1][1] = 1;
for (int i = 2; i <= 1000; i++){
C[i][0] = 1;
for (int j = 1; j <= 1000; j++)
C[i][j] = (C[i - 1][j] + C[i - 1][j - 1]) % mod;
}
ll ans = 0;
for(int x=0; x<=n; x++){
for(int k=1; k<m; k++){
ll sum = find(n-x*(k+1), m-1);
ans = (ans + (k%2==1?1:-1)*C[m-1][k]*sum%mod)%mod;
ans = (ans+mod)%mod;
}
}
ans = (find(n, m)-ans+mod)%mod;
ans = m*ans%mod;
printf("%lld\n", ans);
return 0;
}