[CF960G] Bandit Blues

Solution

  • O ( n 2 ) O(n^2) O(n2) 做法不会的先去看这个
  • 这里只讲如何快速求第一类斯特林数 s ( n , m ) s(n,m) s(n,m)
  • 首先有递推式: s ( i , j ) = s ( i − 1 , j − 1 ) + ( i − 1 ) ∗ s ( i − 1 , j ) s(i,j)=s(i-1,j-1)+(i-1)*s(i-1,j) s(i,j)=s(i1,j1)+(i1)s(i1,j)
  • 为方便卷积写成这样(第二维和为 j j j): s ( i , j ) = s ( i − 1 , j − 1 ) ∗ b ( i , 1 ) + b ( i , 0 ) ∗ s ( i − 1 , j ) s(i,j)=s(i-1,j-1)*b(i,1)+b(i,0)*s(i-1,j) s(i,j)=s(i1,j1)b(i,1)+b(i,0)s(i1,j)
  • 其中 b ( i , 1 ) = 1 , b ( i , 0 ) = i − 1 b(i,1)=1,b(i,0)=i-1 b(i,1)=1b(i,0)=i1
  • 那么把 s ( i ) s(i) s(i) 看成一个多项式, s ( i , j ) s(i,j) s(i,j) 为这个多项式 x j x^j xj 项的系数,初值: s ( 0 , 0 ) = 1 s(0,0)=1 s(0,0)=1
  • b ( i ) b(i) b(i) 同理
  • 那么 s ( i ) = s ( i − 1 ) ∗ b ( i ) s(i)=s(i-1)*b(i) s(i)=s(i1)b(i)
  • 于是把 s ( 0 ) s(0) s(0) ~ s ( n ) s(n) s(n) 都乘起来,得到的多项式就是 s ( n ) s(n) s(n)
  • 这个多项式的 x i x^i xi 项的系数就是 s ( n , i ) s(n,i) s(n,i)
  • 分治 n t t ntt ntt 即可,时间复杂度 O ( n log ⁡ 2 n ) O(n \log^2n) O(nlog2n)

code

#include <bits/stdc++.h>

using namespace std;

#define ll long long

const int e = 1e6 + 5, mod = 998244353;
int n, a1, b1, fac[e], inv[e], rev[e], lim;
vector<int>g[e];

inline int ksm(int x, int y)
{
	int res = 1;
	while (y)
	{
		if (y & 1) res = (ll)res * x % mod;
		y >>= 1;
		x = (ll)x * x % mod;
	}
	return res;
}

inline void upt(int &x, int y)
{
	x = y;
	if (x >= mod) x -= mod;
}

inline void fft(int n, int *a, int opt)
{
	int i, j, k, r = (opt == 1 ? 3 : (mod + 1) / 3);
	for (i = 0; i < n; i++)
	if (i < rev[i]) swap(a[i], a[rev[i]]);
	for (k = 1; k < n; k <<= 1)
	{
		int w0 = ksm(r, (mod - 1) / (k << 1));
		for (i = 0; i < n; i += (k << 1))
		{
			int w = 1;
			for (j = 0; j < k; j++)
			{
				int b = a[i + j], c = (ll)w * a[i + j + k] % mod;
				upt(a[i + j], b + c);
				upt(a[i + j + k], b + mod - c);
				w = (ll)w * w0 % mod;
			}
		}
	}
}

inline void solve(int l, int r)
{
	if (l >= r) return;
	int i, mid = l + r >> 1;
	solve(l, mid);
	solve(mid + 1, r);
	static int a[266666], b[266666], c[266666];
	int k = 0, la = g[l].size(), lb = g[mid + 1].size();
	lim = 1;
	while (lim < la + lb - 1)
	{
		lim <<= 1;
		k++;
	}
	for (i = 0; i < lim; i++) 
	{
		a[i] = b[i] = 0;
		rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << k - 1);
	}
	for (i = 0; i < la; i++) a[i] = g[l][i];
	for (i = 0; i < lb; i++) b[i] = g[mid + 1][i];
	fft(lim, a, 1);
	fft(lim, b, 1);
	for (i = 0; i < lim; i++) a[i] = (ll)a[i] * b[i] % mod;
	fft(lim, a, -1);
	int tot = ksm(lim, mod - 2);
	for (i = 0; i < lim; i++) a[i] = (ll)a[i] * tot % mod;
	g[l].clear(); 
	for (i = 0; i < la + lb - 1; i++) g[l].push_back(a[i]); 
}

inline int c(int x, int y)
{
	if (x < y) return 0;
	return (ll)fac[x] * inv[y] % mod * inv[x - y] % mod;
}

int main()
{
	int i;
	cin >> n >> a1 >> b1;
	fac[0] = 1;
	for (i = 1; i <= n; i++) fac[i] = (ll)fac[i - 1] * i % mod;
	inv[n] = ksm(fac[n], mod - 2);
	for (i = n - 1; i >= 0; i--) inv[i] = (ll)inv[i + 1] * (i + 1) % mod;
	int res = c(a1 + b1 - 2, a1 - 1);
	g[0].push_back(1);
	for (i = 1; i <= n; i++)
	{
		g[i].push_back(i - 1);
		g[i].push_back(1);
	}
	solve(0, n - 1);
	if (a1 + b1 - 2 < g[0].size()) res = (ll)res * g[0][a1 + b1 - 2] % mod;
	else res = 0;
	cout << res << endl;
	return 0;
}
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值