分类讨论
一、,答案是
二、,答案是
三、其他情况,容易发现每行都至少有一个,但是至少有一列没有棋子的方案数 = 每列都至少有一个,但是至少有一行没有棋子的方案数。所以我们可以假设每行都有棋子,算出一个方案数,再乘2就好了。画图容易发现是在每行都有棋子的基础上,把个棋子放进
列。设
表示把
个棋子放进
列且这
列都至少有一个棋子的方案数,这个问题就是把
个不同的小球放进
个不同的盒子且不允许空盒子的方案数。
。
是第二类斯特林数。这类情况的答案是
。其中
表示从
行中选择
列用来放棋子。
主要问题是怎么快速求第二类斯特林数,因为,不能直接递推,考虑通项。主要有两种方法。
(1)
#include<bits/stdc++.h>
#define pb push_back
#define fi first
#define se second
#define sz(x) x.size()
#define cl(x) x.clear()
#define all(x) x.begin() , x.end()
#define rep(i , x , n) for(int i = x ; i <= n ; i ++)
#define per(i , n , x) for(int i = n ; i >= x ; i --)
#define mem0(x) memset(x , 0 , sizeof(x))
#define debug(x) cout << '*' << x << '\n'
#define ddebug(x , y) cout << '*' << x << ' ' << y << '\n'
#define ios std::ios::sync_with_stdio(false) , cin.tie(0)
using namespace std ;
typedef long long ll ;
typedef pair<int , int> pii ;
const ll mod = 998244353 ;
const int maxn = 2e5 + 5 ;
ll qpow(int a , int b)
{
if(b < 0) return 0 ;
int ans = 1 ;
a %= mod ;
while(b)
{
if(b & 1) ans = 1ll * ans * a % mod ;
b >>= 1 , a = 1ll * a * a % mod ;
}
return ans % mod ;
}
ll fac[maxn] ;
ll inv[maxn] ;
void init()
{
fac[0] = fac[1] = inv[0] = inv[1] = 1 ;
for(int i = 2 ; i <= 2e5 ; i ++)
{
fac[i] = fac[i - 1] * i % mod ;
inv[i] = -inv[mod % i] * (mod / i) % mod ;
while(inv[i] < 0) inv[i] += mod ;
}
for(int i = 2 ; i <= 2e5 ; i ++)
inv[i] = inv[i] * inv[i - 1] % mod ;
}
ll C(int n , int k)
{
return fac[n] * inv[n - k] % mod * inv[k] % mod ;
}
ll S(int n , int k)
{
ll res = 0 ;
for(int i = 0 ; i <= k ; i ++)
{
res += qpow(-1 , i) * C(k , i) % mod * qpow(k - i , n) % mod ;
res %= mod ;
if(res < 0) res += mod ;
}
return res * inv[k] % mod ;
}
int main()
{
ios ;
int n , k ;
cin >> n >> k ;
if(k <= n - 1)
{
init() ;
if(k == 0) cout << fac[n] << '\n' ;
else cout << C(n , n - k) * fac[n - k] % mod * S(n , n - k) % mod * 2ll % mod << '\n' ;
}
else cout << "0\n" ;
return 0 ;
}
(2)花费的代价求一行第二类斯特林数。
#include<bits/stdc++.h>
#define pb push_back
#define fi first
#define se second
#define sz(x) x.size()
#define cl(x) x.clear()
#define all(x) x.begin() , x.end()
#define rep(i , x , n) for(int i = x ; i <= n ; i ++)
#define per(i , n , x) for(int i = n ; i >= x ; i --)
#define mem0(x) memset(x , 0 , sizeof(x))
#define debug(x) cout << '*' << x << '\n'
#define ddebug(x , y) cout << '*' << x << ' ' << y << '\n'
#define ios std::ios::sync_with_stdio(false) , cin.tie(0)
using namespace std ;
typedef long long ll ;
typedef pair<int , int> pii ;
const ll mod = 998244353 ;
const int maxn = 2e5 + 5 ;
ll qpow(int a , int b)
{
if(b < 0) return 0 ;
int ans = 1 ;
a %= mod ;
while(b)
{
if(b & 1) ans = 1ll * ans * a % mod ;
b >>= 1 , a = 1ll * a * a % mod ;
}
return ans % mod ;
}
ll fac[maxn] ;
ll inv[maxn] ;
void init()
{
fac[0] = fac[1] = inv[0] = inv[1] = 1 ;
for(int i = 2 ; i <= 2e5 ; i ++)
{
fac[i] = fac[i - 1] * i % mod ;
inv[i] = -inv[mod % i] * (mod / i) % mod ;
while(inv[i] < 0) inv[i] += mod ;
}
for(int i = 2 ; i <= 2e5 ; i ++)
inv[i] = inv[i] * inv[i - 1] % mod ;
}
ll C(int n , int k)
{
return fac[n] * inv[n - k] % mod * inv[k] % mod ;
}
struct NTT
{
int n , m ;
ll a[maxn << 2] , b[maxn << 2] ;
ll up , l ;
ll pos[maxn << 2] ;
ll powmod(ll a , ll b)
{
ll ans = 1 ;
while(b)
{
if(b & 1) ans = ans * a % mod ;
a = a * a % mod ;
b >>= 1 ;
}
return ans ;
}
void init(int n , int m)
{
up = 1 , l = 0 ;
while(up < (n + m)) up <<= 1 , l ++ ;
rep(i , 0 , up - 1) pos[i] = (pos[i >> 1] >> 1) | ((i & 1) << (l - 1)) , a[i] = b[i] = 0 ;
}
void solve(ll *a , int mode)
{
rep(i , 0 , up - 1) if(i < pos[i]) swap(a[i] , a[pos[i]]) ;
for(int i = 1 ; i < up ; i <<= 1)
{
ll gn = powmod(3 , (mod - 1) / (i << 1)) ;
if(mode == -1) gn = powmod(gn , mod - 2) ;
for(int j = 0 ; j < up ; j += (i << 1))
{
ll g = 1 ;
for(int k = 0 ; k < i ; k ++ , g = g * gn % mod)
{
ll x = a[j + k] , y = g * a[j + k + i] % mod ;
a[j + k] = (x + y) % mod , a[j + k + i] = (x - y + mod) % mod ;
}
}
}
if(mode == -1)
{
ll invup = powmod(up , mod - 2) ;
rep(i , 0 , up - 1) a[i] = a[i] * invup % mod ;
}
}
} ntt ;
ll qpow(ll a , ll b) //快速幂
{
if(b < 0) return 0 ;
ll ans = 1 ;
a %= mod ;
while(b)
{
if(b & 1) ans = (ans * a) % mod ;
b >>= 1 , a = (a * a) % mod ;
}
return ans % mod ;
}
ll S(int n , int k)
{
ntt.init(n + 1 , n + 1) ;
rep(i , 0 , n)
{
if(i % 2 == 0) ntt.a[i] = inv[i] ;
else ntt.a[i] = -inv[i] ;
}
rep(i , 0 , n) ntt.b[i] = qpow(i , n) * inv[i] % mod ;
ntt.solve(ntt.a , 1) ;
ntt.solve(ntt.b , 1) ;
rep(i , 0 , ntt.up - 1) ntt.a[i] *= ntt.b[i] , ntt.a[i] %= mod ;
ntt.solve(ntt.a , -1) ;
return ntt.a[k] ;
}
int main()
{
ios ;
int n , k ;
cin >> n >> k ;
if(k <= n - 1)
{
init() ;
if(k == 0) cout << fac[n] << '\n' ;
else cout << C(n , n - k) * fac[n - k] % mod * S(n , n - k) % mod * 2ll % mod << '\n' ;
}
else cout << "0\n" ;
return 0 ;
}