[ARC120F]Wine Thief

题目

传送门 to AtCoder

题意概要
n n n 个物品,你要选 k k k 个。对于所有 ( n k ) {n\choose k} (kn) 个方案,如果任意两个被选择的物品,其下标之差都不小于 D D D,那么这是一个合法方案。一个方案的权值是被选择物品的权值和。求出所有合法方案的权值和。

数据范围与提示
n ≤ 1 0 6 n\le 10^6 n106 k D ≤ n kD\le n kDn 。物品权值不超过 1 0 9 10^9 109

思路

显然,我们只需要固定一瓶酒被偷走,然后计算方案数即可。为了方便,下文可能称呼 “一瓶酒” 为 “一个点” 。

首先我们有一个粗略的认知:如果要在长度为 n n n 的序列中选 k k k 个,使得任意两个的距离不小于 D D D,其方案数是 ( n − ( k − 1 ) ( D − 1 ) k ) {n-(k-1)(D-1)\choose k} (kn(k1)(D1)) 。这是容易发现的:可以提前把这 k k k 个点,和它们之间的 ( k − 1 ) ( D − 1 ) (k-1)(D-1) (k1)(D1) 个点放好,剩下的问题就是,在 k + 1 k+1 k+1 个空隙中分配剩余的点。这就是经典隔板法。

然而,如果我们确定了一个点必须被选择,序列就被分成了两个。这导致我们需要枚举个数,于是复杂度劣化到了 O ( n k ) \mathcal O(nk) O(nk) 。优化点在哪里呢?就在于 序列不完整。如果序列是完整的,那么可以 O ( 1 ) \mathcal O(1) O(1) 算组合数,岂不美哉?

我们钦定一个点被选,然后它左边、右边的 D − 1 D-1 D1 个点都不能选。可以看到,这个整体起到了隔开左右两边的作用。那么我们用 D − 1 D-1 D1 个点去代替它,并且强制这些点不能选,我们就 把序列拼接起来了!得到的就是长度为 n − D n-D nD 的序列,选 k − 1 k-1 k1 个点。

然而我们没法做到强制这些点不被选。我们只能做减法。形式化地,如果用 f n , k ( i ) f_{n,k}(i) fn,k(i) 表示方案数,那么我们有
f n , k ( i ) = ( n − ( k − 1 ) ( D − 1 ) − 1 k − 1 ) − ∑ j = i − D + 1 i − 1 f n − D , k − 1 ( j ) f_{n,k}(i)={n-(k-1)(D-1)-1\choose k-1}-\sum_{j=i-D+1}^{i-1}f_{n-D,k-1}(j) fn,k(i)=(k1n(k1)(D1)1)j=iD+1i1fnD,k1(j)

不过 i + ( D − 1 ) > n i+(D-1)>n i+(D1)>n 时,它就会出一点问题。根据对称性, f n , k ( i ) = f n , k ( n + 1 − i ) f_{n,k}(i)=f_{n,k}(n+1-i) fn,k(i)=fn,k(n+1i),干脆不计算 i > n − D + 1 i>n-D+1 i>nD+1 的值了。由于前半部分必须够多,所以我们需要 n ≥ 2 D n\ge 2D n2D 才行。

考场上我粗略一算:至少 O ( n 2 ) \mathcal O(n^2) O(n2) 的(只看 n n n i i i 两维),还是算了。可是它真的不行吗?

注意到, f n , k f_{n,k} fn,k f n − D , k − 1 f_{n-D,k-1} fnD,k1 类似于一个 序列对应位操作。我们应该猜想一下,是不是可以写成 多项式操作

F n , k ( x ) = ∑ i = 1 n f n , k ( i ) ⋅ x i F_{n,k}(x)=\sum_{i=1}^{n}f_{n,k}(i)\cdot x^i Fn,k(x)=i=1nfn,k(i)xi,那么很容易发现,其实
F n , k ( x ) = ( ⋯ ⋯ ) ∑ j = 1 n x j − F n − D , k − 1 ( x ) ∑ j = 1 D − 1 x j F_{n,k}(x)={\cdots\choose\cdots}\sum_{j=1}^{n}x^j-F_{n-D,k-1}(x)\sum_{j=1}^{D-1}x^j Fn,k(x)=()j=1nxjFnD,k1(x)j=1D1xj

当然,这里的操作是对 x n − D + 2 x^{n-D+2} xnD+2 取模的。显然 F n − D , k − 1 F_{n-D,k-1} FnD,k1 x n − 2 D + 2 x^{n-2D+2} xn2D+2 取模后,位数不够用;把末尾的这几位补上即可。补上的值是可以预处理的,因为只有末尾的 D − 1 D-1 D1 位要补齐,而这些值就是 f n − D , k − 1 ( i )    ( i < D ) f_{n-D,k-1}(i)\;(i<D) fnD,k1(i)(i<D),可以直接计算。

发现 n , k n,k n,k 其实是同时变化的。如果我们把递归的过程拿出来做,其实就是这样一个东西:将原来的多项式加一个补齐多项式 T i ( x ) T_i(x) Ti(x),再乘 G ( x ) = ∑ j = 1 D − 1 − x j G(x)=\sum_{j=1}^{D-1}-x^j G(x)=j=1D1xj,最后每一位加一个数 a i a_i ai 。这东西看上去还是 O ( n d ) \mathcal O({n\over d}) O(dn) F F T \tt FFT FFT,解决不了啊。

一个小妙招:避免取模。我们在计算 F n − D , k − 1 ( x ) G ( x ) F_{n-D,k-1}(x)G(x) FnD,k1(x)G(x) 时,次数是 n − D n-D nD 的,没问题,只是 T i ( x ) G ( x ) T_i(x)G(x) Ti(x)G(x) 超出范围了。由于 T i ( x ) T_i(x) Ti(x) 本就是预处理得到的,我们干脆把 T i ( x ) G ( x ) T_i(x)G(x) Ti(x)G(x) 也预处理一下。这样一来,我们可以直接把括号拆开。类似于
G ( x ) { G ( x ) [ G ( x ) F ( x ) + T 1 ( x ) ] + T 2 ( x ) } + T 3 ( x ) G(x)\{G(x)[G(x)F(x)+T_1(x)]+T_2(x)\}+T_3(x) G(x){G(x)[G(x)F(x)+T1(x)]+T2(x)}+T3(x)

不难发现,就是 F ( x ) , T 1 ( x ) , T 2 ( x ) , T 3 ( x ) F(x),T_1(x),T_2(x),T_3(x) F(x),T1(x),T2(x),T3(x) 作为 G ( x ) 3 , G ( x ) 2 , G ( x ) , 1 G(x)^3,G(x)^2,G(x),1 G(x)3,G(x)2,G(x),1 的系数。这玩意儿就有一个经典的算法了:分治。计算出左边部分的递归值 F ( x ) G ( x ) + T 1 ( x ) F(x)G(x)+T_1(x) F(x)G(x)+T1(x) 之后,一起乘 G ( x ) 2 G(x)^2 G(x)2,然后加上右边部分的递归值。右边的值 T i ( x ) T_i(x) Ti(x) 都是 x m 2 D x^{\frac{m}{2}D} x2mD 的倍数,先一起将其除掉,就可以转化为递归子问题。

常数 a i a_i ai 反而更麻烦。因为它实际上加的是 a i ⋅ I ( x ) a_i\cdot I(x) aiI(x),很烦人。运用同样的思路,倒是也可以做,就是暂时只看一部分,然后剩下的部分再补上。需要求 I ( x ) ∑ a i ⋅ G ( x ) i I(x)\sum a_i\cdot G(x)^i I(x)aiG(x)i,反正很烦人。

假设一共有 m m m 项,整个式子的次数不超过 m D mD mD,时间复杂度是 K ( m ) = 2 K ( m 2 ) + O ( m D log ⁡ m D ) = O ( m D log ⁡ m D log ⁡ m ) K(m)=2K(\frac{m}{2})+\mathcal O(mD\log mD)=\mathcal O(mD\log mD\log m) K(m)=2K(2m)+O(mDlogmD)=O(mDlogmDlogm)

我们知道 n = m D n=mD n=mD,所以总复杂度就是 O ( n log ⁡ n log ⁡ n D ) \mathcal O(n\log n\log{n\over D}) O(nlognlogDn) 。预处理 T i ( x ) G ( x ) T_i(x)G(x) Ti(x)G(x) 也只是 O ( n log ⁡ D ) \mathcal O(n\log D) O(nlogD) 的。

代码

众所周知, A t C o d e r \rm AtCoder AtCoder 是允许 − O f a s t -Ofast Ofast 编译优化选项的。因为 vector 真的非常慢……

#include <cstdio>
#include <iostream>
#include <cstring>
#include <algorithm>
#include <vector>
# pragma GCC optimize("Ofast")
using namespace std;
# define rep(i,a,b) for(register int i=(a); i<=(b); ++i)
# define drep(i,a,b) for(register int i=(a); i>=(b); --i)
typedef long long int_;
inline int readint(){
	int a = 0; char c = getchar(), f = 1;
	for(; c<'0'||c>'9'; c=getchar())
		if(c == '-') f = -f;
	for(; '0'<=c&&c<='9'; c=getchar())
		a = (a<<3)+(a<<1)+(c^48);
	return a*f;
}
inline int ABS(const int &x){
	return x < 0 ? -x : x;
}

const int Mod = 998244353;
inline int qkpow(int_ b,int q){
	int_ a = 1;
	for(; q; q>>=1,b=b*b%Mod)
		if(q&1) a = a*b%Mod;
	return a;
}

const int LogMod = 32;
int g[2][LogMod], inv2[LogMod];
void prepareNTT(){
	int p = Mod-1, x = 0;
	inv2[0] = 1, inv2[1] = (Mod+1)>>1;
	for(; !(p&1); p>>=1,++x)
		inv2[x+1] = 1ll*inv2[x]*inv2[1]%Mod;
	g[1][x] = qkpow(3,p);
	g[0][x] = qkpow(g[1][x],Mod-2);
	rep(d,0,1) drep(j,x,1) // get DP
		g[d][j-1] = 1ll*g[d][j]*g[d][j]%Mod;
}

const int MaxN = 2000005;
int went[MaxN<<1];
void NTT(int a[],int n,int opt){
	rep(i,1,(1<<n)-1)
		if(i < went[i])
			swap(a[i],a[went[i]]);
	for(int w=1,x=1; x<=n; w<<=1,++x)
	for(int *p=a; p!=a+(1<<n); p+=(w<<1))
	for(int i=0,v=1; i<w; ++i){
		int t = 1ll*v*p[i+w]%Mod;
		p[i+w] = (p[i]+Mod-t)%Mod;
		p[i] = (p[i]+t)%Mod;
		v = 1ll*g[opt][x]*v%Mod;
	}
	if(!opt) rep(i,0,(1<<n)-1)
		a[i] = 1ll*a[i]*inv2[n]%Mod;
}

int tmpa[MaxN<<1], tmpb[MaxN<<1];
struct Poly{
	vector<int> a;
	int& operator[](const int &x){
		return a[x];
	}
	void clear(){ a.clear(); }
	int len() const { return a.size(); }
	void resize(int n){ a.resize(n); }
	void shift(int d){
		if(d > 0){
			resize(len()+d);
			drep(i,len()-1,d)
				a[i] = a[i-d];
			rep(i,0,d-1) a[i] = 0;
		}
		if(d < 0){
			if(-d > len())
				d = -len();
			rep(i,0,len()+d-1)
				a[i] = a[i-d];
			resize(len()+d);
		}
	}
	void adjust(){
		while(!a.empty() && !a.back())
			a.pop_back();
	}

	Poly operator * (Poly b) const {
		int N = 0, _ = len()+b.len()-1;
		for(; (1<<N)<_; ++N);
		memset(tmpa,0,(1<<N)<<2);
		memset(tmpb,0,(1<<N)<<2);
		rep(i,1,(1<<N)-1) // pre-compute went
			went[i] = (went[i>>1]>>1)|((i&1)<<N>>1);
		rep(i,0,len()-1) tmpa[i] = a[i];
		rep(i,0,b.len()-1) tmpb[i] = b[i];
		NTT(tmpa,N,1), NTT(tmpb,N,1);
		rep(i,0,(1<<N)-1) tmpa[i] =
			1ll*tmpa[i]*tmpb[i]%Mod;
		NTT(tmpa,N,0); Poly c;
		c.a.resize(_); // final length
		rep(i,0,_-1) // c.len()-1
			c.a[i] = tmpa[i];
		return c; // product
	}
	Poly& operator += (Poly b){
		if(len() < b.len())
			rep(i,len(),b.len()-1)
				a.push_back(0);
		rep(i,0,b.len()-1)
			a[i] = (a[i]+b[i])%Mod;
		return *this;
	}

	Poly& mul_unit(int k);
};
Poly operator * (const int &x,Poly &p){
	Poly c; c.resize(p.len());
	rep(i,0,c.len()-1)
		c[i] = 1ll*p[i]*x%Mod;
	return c;
}
Poly polypow(Poly &p,int k){
	int N, _ = (p.len()-1)*k+1;
	for(N=1; (1<<N)<_; ++N);
	memset(tmpa,0,(1<<N)<<2);
	rep(i,1,(1<<N)-1) // important
		went[i] = (went[i>>1]>>1)|((i&1)<<N>>1);
	rep(i,0,p.len()-1) tmpa[i] = p[i];
	NTT(tmpa,N,1); Poly c;
	rep(i,0,(1<<N)-1)
		tmpa[i] = qkpow(tmpa[i],k);
	NTT(tmpa,N,0); c.clear();
	rep(i,0,_-1) // final product
		c.a.push_back(tmpa[i]);
	return c;
}
Poly& Poly::mul_unit(int k){
	rep(i,1,k) a.push_back(0);
	rep(i,1,len()-1)
		a[i] = (a[i]+a[i-1])%Mod;
	drep(i,len()-1,k) // too much
		a[i] = (a[i]+Mod-a[i-k])%Mod;
	return *this;
}

int_ jc[MaxN], inv[MaxN];
void prepareC(){
	jc[1] = inv[1] = 1;
	rep(i,2,MaxN-1){
		inv[i] = (Mod-Mod/i)*inv[Mod%i]%Mod;
		jc[i] = jc[i-1]*i%Mod;
	}
	rep(i,2,MaxN-1)
		inv[i] = inv[i-1]*inv[i]%Mod;
	jc[0] = inv[0] = 1;
}
inline int_ getC(int n,int m){
	if(m < 0 || n < m) return 0;
	return jc[n]*inv[m]%Mod*inv[n-m]%Mod;
}

int D; // constant
int_ calc(int n,int k){
	if(n <= 0) return k == 0;
	return getC(n-(k-1)*(D-1),k);
}

int jb[MaxN]; // shift of T / aI
Poly T[MaxN], G;
Poly solve(int l,int r){
	if(l == r) return T[l];
	int mid = l; // split point
	for(int i=l; i<r; ++i)
		if(jb[i] <= (jb[l]+jb[r])/2)
			mid = i; // most right one
	auto L = solve(l,mid);
	auto R = solve(mid+1,r);
	R.shift(jb[mid+1]-jb[l]);
	L = L*polypow(G,r-mid);
	return L += R;
}

int a[MaxN]; // add a_i*I(x)
pair<Poly,Poly> solve_a(int l,int r){
	if(l == r){
		Poly zero; zero.clear();
		Poly unit; unit.resize(1);
		unit.a[0] = a[l];
		return make_pair(zero,unit);
	}
	int mid = l; // spliting point
	for(int i=l; i<r; ++i)
		if(jb[i] <= (jb[l]+jb[r])/2)
			mid = i; // most right one
	auto L = solve_a(l,mid);
	auto R = solve_a(mid+1,r);
	R.first.shift(jb[mid+1]-jb[l]);
	Poly it = R.second;
	it.mul_unit(jb[mid+1]-jb[l]);
	R.first += it;
	Poly xjx = polypow(G,r-mid);
	L.first = L.first*xjx;
	L.second = L.second*xjx;
	L.second += R.second;
	L.first += R.first;
	return L; // that's what we want
}

// # define LOCAL_TEST
# ifdef LOCAL_TEST
	#include <cstdlib>
	#include <ctime>
# endif
int main(){
	# ifdef LOCAL_TEST
		freopen("data.in","r",stdin);
	# endif
	prepareC(), prepareNTT();
	int n = readint(), k = readint();
	D = readint(); int tot = 1;
	while(n-D >= 2*D && k-1 >= 1)
		++ tot, n -= D, -- k;
	T[0].a.push_back(0); // 0*x^0
	for(int i=1,now; i<=n-D+1; ++i){
		rep(j,now=0,k) now = (now+calc(i-D,j)
			*calc(n-(i+D-1),k-1-j))%Mod;
		T[0].a.push_back(now); // F(x)
	}
	G.a.push_back(0); // 0*x^0
	rep(i,1,D-1) G.a.push_back(Mod-1);
	for(int i=1; i<tot; ++i){
		int n0 = n+(i-1)*D, k0 = k+(i-1);
		T[i].resize(D-1);
		rep(j,1,D-1) T[i][D-1-j] =
			calc(n0-(j+D-1),k0-1);
		jb[i] = n0-D+2; // beginning
		a[i] = calc(n0,k0); // all means
		T[i] = T[i]*G; // pre-compute
		T[i].resize(D); // important
	}
	# ifdef LOCAL_TEST
		clock_t start = clock();
	# endif
	Poly res = solve(0,tot-1);
	# ifdef LOCAL_TEST
		printf("solve1 = %ld\n",clock()-start);
	# endif
	for(int i=1; i<tot; ++i)
		jb[i] = (n+i*D)-D+1;
	# ifdef LOCAL_TEST
		start = clock();
	# endif
	Poly dym = solve_a(0,tot-1).first;
	# ifdef LOCAL_TEST
		printf("solve2 = %ld\n",clock()-start);
	# endif
	dym.shift(1); // they've dived x before
	res += dym; // combine them
	n += (tot-1)*D, k += (tot-1);
	if(n < 2*D){ // born deformed
		res.clear(); res.a.push_back(0);
		rep(i,1,n) // calc it manually
			if(i <= D && i+(D-1) > n)
				res.a.push_back(k == 1);
			else if(i <= D) // at the front
				res.a.push_back(calc(n-(i+D-1),k-1));
			else if(i+(D-1) > n) // at the back
				res.a.push_back(calc(i-(D-1),k-1));
			else puts("Um-huh, Math is wrong!");
	}
	int zxy = 0;
	for(int i=1,x; i<=n; ++i){
		x = readint();
		if(i > n-D+1)
			zxy = (zxy+1ll*res[n+1-i]*x)%Mod;
		else zxy = (zxy+1ll*res[i]*x)%Mod;
	}
	printf("%d\n",zxy);
	return 0;
}
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值