题意
给定n和m,求满足以下条件的数组的价值总和(模998244353):
- 长为n, 1 ≤ a i ≤ m 1\leq a_i\leq m 1≤ai≤m
价值定义 f ( a ) = ∑ i = 1 n [ a i = = m a x ( a ) ] f(a)=\sum_{i=1}^{n}[a_i==max(a)] f(a)=∑i=1n[ai==max(a)]
constrain:
1
≤
n
×
m
≤
1
0
12
1\leq n\times m\leq 10^{12}
1≤n×m≤1012
思路
暴力公式很显然
∑
i
=
1
n
i
C
(
n
,
i
)
∑
k
=
1
m
−
1
k
n
−
i
\sum_{i=1}^{n}iC(n,i)\sum_{k=1}^{m-1}k^{n-i}
i=1∑niC(n,i)k=1∑m−1kn−i
我先考虑的递推
f
(
i
,
j
)
f(i,j)
f(i,j)可以拆分为多种不重复贡献的组合:
- 最大值不超过j-1的组合的贡献, f ( i , j − 1 ) f(i,j-1) f(i,j−1)
- 对比 n = i − 1 n=i-1 n=i−1,记附加位为 x x x,当 x = = m a x ( a ) x==max(a) x==max(a)时, x x x为答案贡献了(i位的方案数)= j i − 1 − ( j − 1 ) i − 1 j^{i-1}-(j-1)^{i-1} ji−1−(j−1)i−1,其他位置的贡献为 f ( i − 1 , j ) − f ( i − 1 , j − 1 ) f(i-1,j)-f(i-1,j-1) f(i−1,j)−f(i−1,j−1),注意到漏了仅x为最大值的贡献 ( j − 1 ) i − 1 (j-1)^{i-1} (j−1)i−1,当 x ≠ m a x ( a ) x\neq max(a) x=max(a)时,固定最大值为 j j j的贡献为 ( j − 1 ) ∗ [ f ( i − 1 , j ) − f ( i − 1 , j − 1 ) ] (j-1)*[f(i-1,j)-f(i-1,j-1)] (j−1)∗[f(i−1,j)−f(i−1,j−1)],右边是固定最大值为 j j j的贡献数,对于每个方案, x x x有 ( j − 1 ) (j-1) (j−1)种取值。
所以有
f
(
i
,
j
)
=
f
(
i
,
j
−
1
)
+
j
∗
[
f
(
i
−
1
,
j
)
−
f
(
i
−
1
,
j
−
1
)
]
+
j
i
−
1
f(i,j)=f(i,j-1)+j*[f(i-1,j)-f(i-1,j-1)]+j^{i-1}
f(i,j)=f(i,j−1)+j∗[f(i−1,j)−f(i−1,j−1)]+ji−1
可以拆一下中括号里的第一项,可以得到一个非递推解析式(还没想到组合逻辑上的解释
f
(
n
,
m
)
=
n
∗
∑
i
=
1
m
i
n
−
1
f(n,m)=n*\sum_{i=1}^{m}i^{n-1}
f(n,m)=n∗i=1∑min−1
由题意
m
i
n
(
n
,
m
)
≤
1
0
6
min(n,m)\leq 10^6
min(n,m)≤106,n比较大时,直接暴力,m比较大时拉格朗日插值求自然数幂和,被板子坑了,太悲伤了。
插值公式:
f
(
x
)
=
∑
i
=
1
k
+
1
y
(
i
)
∏
i
≠
j
x
−
x
j
x
i
−
x
j
f(x)=\sum_{i=1}^{k+1}y(i)\prod_{i\neq j}\frac{x-x_j}{x_i-x_j}
f(x)=i=1∑k+1y(i)i=j∏xi−xjx−xj
这里
y
(
x
)
=
∑
i
=
1
x
i
k
y(x)=\sum_{i=1}^xi^k
y(x)=∑i=1xik,证明
y
(
n
)
y(n)
y(n)是k+1次多项式可以考虑差分。
求自然数幂和,可以证明他是k+1次,x连续:
f
(
n
)
=
∑
i
=
1
k
+
2
(
−
1
)
k
−
i
+
2
f
(
i
)
∑
j
=
1
k
+
2
(
n
−
j
)
(
n
−
i
)
(
i
−
1
)
!
(
k
+
2
−
i
)
!
f(n)=\sum_{i=1}^{k+2}(-1)^{k-i+2}f(i)\frac{\sum_{j=1}^{k+2}(n-j)}{(n-i)(i-1)!(k+2-i)!}
f(n)=i=1∑k+2(−1)k−i+2f(i)(n−i)(i−1)!(k+2−i)!∑j=1k+2(n−j)
代码
#include<bits/stdc++.h>
using namespace std;
#define pow2(X) (1ll<<(X))
#define SIZE(A) ((int)A.size())
#define LENGTH(A) ((int)A.length())
#define ALL(A) A.begin(),A.end()
#define F(i,a,b) for(ll i=a;i<=(b);++i)
#define dF(i,a,b) for(ll i=a;i>=(b);--i)
#define GETPOS(c,x) (lower_bound(ALL(c),x)-c.begin())
#define inf 0x3f3f3f3f
#define infll 0x3f3f3f3f3f3f3f3f
#define pb push_back
#define pr pair<int,int>
#define mkp make_pair
#define fi first
#define se second
#define eps 1e-6
#define PI acos(-1.0)
#define lb lower_bound
#define ub upper_bound
#define bs binary_search
#define FO(x) {freopen(#x".in","r",stdin);freopen(#x".out","w",stdout);}
#define Edg int M=0,fst[SZ],vb[SZ],nxt[SZ];void ad_de(int a,int b){++M;nxt[M]=fst[a];fst[a]=M;vb[M]=b;}void adde(int a,int b){ad_de(a,b);ad_de(b,a);}
#define Edgc int M=0,fst[SZ],vb[SZ],nxt[SZ],vc[SZ];void ad_de(int a,int b,int c){++M;nxt[M]=fst[a];fst[a]=M;vb[M]=b;vc[M]=c;}void adde(int a,int b,int c){ad_de(a,b,c);ad_de(b,a,c);}
#define es(x,e) (int e=fst[x];e;e=nxt[e])
#define esb(x,e,b) (int e=fst[x],b=vb[e];e;e=nxt[e],b=vb[e])
#define SZ 666666
typedef unsigned int uint;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> ipair;
typedef vector<int> VI;
typedef vector<long long> VLL;
typedef vector<vector<long long > > VVLL;
typedef vector<vector<int> > VVI;
typedef vector<double> VD;
typedef vector<string> VS;
const int mods = 998244353;
const int maxn = 1e6+10;
const int N = 1e6+10;
const int E = 1e4+10;
const int lim = 1e9;
ll qpow(ll a,ll b) {ll res=1;a%=mods; assert(b>=0); for(;b;b>>=1){if(b&1)res=res*a%mods;a=a*a%mods;}return res;}
ll lcm(ll a, ll b) {return a / __gcd(a, b) * b;}
int read(){ll x=0,f=1;char ch=getchar();while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}return x*f;}
ll n,m,k;
ll y[maxn], z[maxn], jc[maxn], suf[maxn], pre[maxn];
bool bz[maxn];
void Init() {
memset(z,0,sizeof(z));
memset(bz,0,sizeof(bz));
memset(y,0,sizeof(y));
memset(jc,0,sizeof(jc));
memset(suf,0,sizeof(suf));
memset(pre,0,sizeof(pre));
y[1] = 1, m = k + 2;
F(i, 2, m) {
if (!bz[i])
z[++ z[0]] = i, y[i] = qpow(i, k);
F(j, 1, z[0]) {
if (z[j] * i > m) break;
bz[z[j] * i] = 1;
y[z[j] * i] = (1ll * y[z[j]] * y[i]) % mods;
if (i % z[j] == 0) break;
}
}
F(i, 2, m)
y[i] = (y[i - 1] + y[i]) % mods;
jc[0] = 1;
F(i, 1, m)
jc[i] = 1ll * jc[i - 1] * i % mods;
jc[m] = qpow(jc[m], mods - 2);
dF(i, m - 1, 1)
jc[i] = 1ll * jc[i + 1] * (i + 1) % mods;
}
ll Solve() {
pre[0] = suf[m + 1] = 1ll;
F(i, 1, m)
pre[i] = 1ll * pre[i - 1] * ((n - i+mods)%mods) % mods;
dF(i, m, 1)
suf[i] = 1ll * suf[i + 1] * ((n - i+mods)%mods) % mods;
ll Ans = 0;
F(i, 1, m)
Ans = (Ans + 1ll * y[i] * pre[i - 1] % mods * suf[i + 1] % mods * (((k-i+2)&1) ? (-1ll) : 1ll) * jc[i - 1] % mods * jc[k + 2 - i] % mods) % mods;
return Ans;
}
//12354 1000000000000
int main(){
//freopen("C:\\Users\\Gao\\Desktop\\validation_input\\second_flight_input.txt","r",stdin);
//freopen("C:\\Users\\Gao\\Desktop\\validation_input\\output.txt","w",stdout);
ios_base::sync_with_stdio(0);
int T;
//cin>>T;
T = 100;
F(turn,1,T){
cin>>n>>m;
if(m<=n){
ll ans = 0;
F(i,1,m){
ans = (ans+qpow(i,n-1))%mods;
}
ans = (ans*(n%mods))%mods;
cout<<ans<<endl;
}
else{//n<=m
k = n-1;
n = m;
Init();
ll ans = Solve();
ans = ((ans+mods)*((k+1)%mods))%mods;
cout<<ans<<endl;
}
}
}
/*
*/