最大值(组合,拉格朗日插值)

题意

给定n和m,求满足以下条件的数组的价值总和(模998244353):

  • 长为n, 1 ≤ a i ≤ m 1\leq a_i\leq m 1aim

价值定义 f ( a ) = ∑ i = 1 n [ a i = = m a x ( a ) ] f(a)=\sum_{i=1}^{n}[a_i==max(a)] f(a)=i=1n[ai==max(a)]

constrain:
1 ≤ n × m ≤ 1 0 12 1\leq n\times m\leq 10^{12} 1n×m1012

思路

暴力公式很显然
∑ i = 1 n i C ( n , i ) ∑ k = 1 m − 1 k n − i \sum_{i=1}^{n}iC(n,i)\sum_{k=1}^{m-1}k^{n-i} i=1niC(n,i)k=1m1kni
我先考虑的递推 f ( i , j ) f(i,j) f(i,j)可以拆分为多种不重复贡献的组合:

  • 最大值不超过j-1的组合的贡献, f ( i , j − 1 ) f(i,j-1) f(i,j1)
  • 对比 n = i − 1 n=i-1 n=i1,记附加位为 x x x,当 x = = m a x ( a ) x==max(a) x==max(a)时, x x x为答案贡献了(i位的方案数)= j i − 1 − ( j − 1 ) i − 1 j^{i-1}-(j-1)^{i-1} ji1(j1)i1,其他位置的贡献为 f ( i − 1 , j ) − f ( i − 1 , j − 1 ) f(i-1,j)-f(i-1,j-1) f(i1,j)f(i1,j1),注意到漏了仅x为最大值的贡献 ( j − 1 ) i − 1 (j-1)^{i-1} (j1)i1,当 x ≠ m a x ( a ) x\neq max(a) x=max(a)时,固定最大值为 j j j的贡献为 ( j − 1 ) ∗ [ f ( i − 1 , j ) − f ( i − 1 , j − 1 ) ] (j-1)*[f(i-1,j)-f(i-1,j-1)] (j1)[f(i1,j)f(i1,j1)],右边是固定最大值为 j j j的贡献数,对于每个方案, x x x ( j − 1 ) (j-1) (j1)种取值。

所以有 f ( i , j ) = f ( i , j − 1 ) + j ∗ [ f ( i − 1 , j ) − f ( i − 1 , j − 1 ) ] + j i − 1 f(i,j)=f(i,j-1)+j*[f(i-1,j)-f(i-1,j-1)]+j^{i-1} f(i,j)=f(i,j1)+j[f(i1,j)f(i1,j1)]+ji1
可以拆一下中括号里的第一项,可以得到一个非递推解析式(还没想到组合逻辑上的解释
f ( n , m ) = n ∗ ∑ i = 1 m i n − 1 f(n,m)=n*\sum_{i=1}^{m}i^{n-1} f(n,m)=ni=1min1
由题意 m i n ( n , m ) ≤ 1 0 6 min(n,m)\leq 10^6 min(n,m)106,n比较大时,直接暴力,m比较大时拉格朗日插值求自然数幂和,被板子坑了,太悲伤了。
插值公式:
f ( x ) = ∑ i = 1 k + 1 y ( i ) ∏ i ≠ j x − x j x i − x j f(x)=\sum_{i=1}^{k+1}y(i)\prod_{i\neq j}\frac{x-x_j}{x_i-x_j} f(x)=i=1k+1y(i)i=jxixjxxj
这里 y ( x ) = ∑ i = 1 x i k y(x)=\sum_{i=1}^xi^k y(x)=i=1xik,证明 y ( n ) y(n) y(n)是k+1次多项式可以考虑差分。
求自然数幂和,可以证明他是k+1次,x连续:
f ( n ) = ∑ i = 1 k + 2 ( − 1 ) k − i + 2 f ( i ) ∑ j = 1 k + 2 ( n − j ) ( n − i ) ( i − 1 ) ! ( k + 2 − i ) ! f(n)=\sum_{i=1}^{k+2}(-1)^{k-i+2}f(i)\frac{\sum_{j=1}^{k+2}(n-j)}{(n-i)(i-1)!(k+2-i)!} f(n)=i=1k+2(1)ki+2f(i)(ni)(i1)!(k+2i)!j=1k+2(nj)

代码

#include<bits/stdc++.h>
using namespace std;
#define pow2(X) (1ll<<(X))
#define SIZE(A) ((int)A.size())
#define LENGTH(A) ((int)A.length())
#define ALL(A) A.begin(),A.end()
#define F(i,a,b) for(ll i=a;i<=(b);++i)
#define dF(i,a,b) for(ll i=a;i>=(b);--i)
#define GETPOS(c,x) (lower_bound(ALL(c),x)-c.begin())
#define inf 0x3f3f3f3f
#define infll 0x3f3f3f3f3f3f3f3f
#define pb push_back
#define pr pair<int,int>
#define mkp make_pair
#define fi first
#define se second
#define eps 1e-6
#define PI acos(-1.0)
#define lb lower_bound
#define ub upper_bound
#define bs binary_search
#define FO(x) {freopen(#x".in","r",stdin);freopen(#x".out","w",stdout);}
#define Edg int M=0,fst[SZ],vb[SZ],nxt[SZ];void ad_de(int a,int b){++M;nxt[M]=fst[a];fst[a]=M;vb[M]=b;}void adde(int a,int b){ad_de(a,b);ad_de(b,a);}
#define Edgc int M=0,fst[SZ],vb[SZ],nxt[SZ],vc[SZ];void ad_de(int a,int b,int c){++M;nxt[M]=fst[a];fst[a]=M;vb[M]=b;vc[M]=c;}void adde(int a,int b,int c){ad_de(a,b,c);ad_de(b,a,c);}
#define es(x,e) (int e=fst[x];e;e=nxt[e])
#define esb(x,e,b) (int e=fst[x],b=vb[e];e;e=nxt[e],b=vb[e])
#define SZ 666666
typedef unsigned int uint;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> ipair;
typedef vector<int> VI;
typedef vector<long long> VLL;
typedef vector<vector<long long > > VVLL;
typedef vector<vector<int> > VVI;
typedef vector<double> VD;
typedef vector<string> VS;
const int mods = 998244353;
const int maxn = 1e6+10;
const int N = 1e6+10;
const int E = 1e4+10;
const int lim = 1e9;
ll qpow(ll a,ll b) {ll res=1;a%=mods; assert(b>=0); for(;b;b>>=1){if(b&1)res=res*a%mods;a=a*a%mods;}return res;}
ll lcm(ll a, ll b) {return a / __gcd(a, b) * b;}
int read(){ll x=0,f=1;char ch=getchar();while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}return x*f;}

ll n,m,k;
ll y[maxn], z[maxn], jc[maxn], suf[maxn], pre[maxn];
bool bz[maxn];

void Init() {
	memset(z,0,sizeof(z));
	memset(bz,0,sizeof(bz));
	memset(y,0,sizeof(y));
	memset(jc,0,sizeof(jc));
	memset(suf,0,sizeof(suf));
	memset(pre,0,sizeof(pre));
	y[1] = 1, m = k + 2;
	F(i, 2, m) {
		if (!bz[i])
			z[++ z[0]] = i, y[i] = qpow(i, k);
		F(j, 1, z[0]) {
			if (z[j] * i > m) break;
			bz[z[j] * i] = 1;
			y[z[j] * i] = (1ll * y[z[j]] * y[i]) % mods;
			if (i % z[j] == 0) break;
		}
	}
	F(i, 2, m)
		y[i] = (y[i - 1] + y[i]) % mods;
	jc[0] = 1;
	F(i, 1, m)
		jc[i] = 1ll * jc[i - 1] * i % mods;
	jc[m] = qpow(jc[m], mods - 2);
	dF(i, m - 1, 1)
		jc[i] = 1ll * jc[i + 1] * (i + 1) % mods;
}

ll Solve() {
	pre[0] = suf[m + 1] = 1ll;
	F(i, 1, m)
		pre[i] = 1ll * pre[i - 1] * ((n - i+mods)%mods) % mods;
	dF(i, m, 1)
		suf[i] = 1ll * suf[i + 1] * ((n - i+mods)%mods) % mods;

	ll Ans = 0;
	F(i, 1, m)
		Ans = (Ans + 1ll * y[i] * pre[i - 1] % mods * suf[i + 1] % mods * (((k-i+2)&1) ? (-1ll) : 1ll) * jc[i - 1] % mods * jc[k + 2 - i] % mods) % mods;
	return Ans;
}
//12354 1000000000000

int main(){
	//freopen("C:\\Users\\Gao\\Desktop\\validation_input\\second_flight_input.txt","r",stdin);
	//freopen("C:\\Users\\Gao\\Desktop\\validation_input\\output.txt","w",stdout);
    ios_base::sync_with_stdio(0);
    int T;
    //cin>>T; 
	T = 100;
	F(turn,1,T){
		cin>>n>>m;
		if(m<=n){
			ll ans = 0;
			F(i,1,m){
				ans = (ans+qpow(i,n-1))%mods;
			}
			ans = (ans*(n%mods))%mods;
			cout<<ans<<endl;
		}
		else{//n<=m
			k = n-1;
			n = m;
			Init();
			ll ans = Solve();
			ans = ((ans+mods)*((k+1)%mods))%mods;
			cout<<ans<<endl;
		}
	}
}
 
/*
*/
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值