前言
先是容斥
分治套NTT
题意简介
题目大意
现在有
n
n
n个猎人,每个猎人都有一个值
w
i
w_i
wi
进行
n
n
n次杀人,死掉的人不会再被杀
每次杀人过程中第
i
i
i个猎人被杀的概率为
w
i
∑
w
j
\frac{w_i}{\sum w_j}
∑wjwi
问第一个猎人最后一个死的概率
(答案对998244353取模)
数据范围
w
i
>
0
,
1
≤
∑
w
≤
100000
w_i>0,1\le\sum w\le100000
wi>0,1≤∑w≤100000
由于上面这两个限制,我们发现
1
≤
n
≤
100000
1\le n\le100000
1≤n≤100000
题解
我们要求的是没有人在
1
1
1号猎人后死的概率
我们发现直接求不好求,那么考虑容斥
设
f
(
S
)
f(S)
f(S)为
S
S
S集合内的人全都在
1
1
1号猎人后死的概率(不在
S
S
S集合里的人不确定)
显然
f
(
S
)
=
w
1
∑
i
∈
S
w
i
+
w
1
f(S)=\frac{w_1}{\sum_{i\in S}w_i+w_1}
f(S)=∑i∈Swi+w1w1
根据容斥的定义式
∣
A
1
∪
A
2
∪
.
.
.
∪
A
n
∣
=
∑
i
=
1
n
(
−
1
)
i
−
1
∑
∣
T
∣
=
i
,
T
=
{
x
1
.
.
.
x
i
}
∣
A
x
1
∩
A
x
2
∩
.
.
.
∩
A
x
i
∣
|{A_1}\cup{A_2}\cup...\cup{A_n}|=\sum_{i=1}^n(-1)^{i-1}\sum_{|T|=i,T=\{x_1...x_i\}}|{A_{x_1}}\cap{A_{x_2}}\cap...\cap{A_{x_i}}|
∣A1∪A2∪...∪An∣=i=1∑n(−1)i−1∣T∣=i,T={x1...xi}∑∣Ax1∩Ax2∩...∩Axi∣
我们发现我们要求的就是
∑
∣
T
∣
=
i
,
T
=
{
x
1
.
.
.
x
i
}
(
−
1
)
i
f
(
T
)
\sum_{|T|=i,T=\{x_1...x_i\}} (-1)^if(T)
∣T∣=i,T={x1...xi}∑(−1)if(T)
我们发现,直接枚举所有集合复杂度为
O
(
2
n
)
\mathcal O(2^n)
O(2n),显然不能接受
这题在数据范围上是对
∑
w
\sum w
∑w进行限制的,所以不能学傻
我们可以得到一个非常通俗易懂的
O
(
n
2
)
\mathcal O(n^2)
O(n2)DP算法
我们可以用
f
i
,
j
f_{i,j}
fi,j表示前
i
i
i个猎人,
∑
w
=
j
\sum w=j
∑w=j的方案数(差不多就是背包)
转移是枚举下一个元素是否选,是
O
(
1
)
\mathcal O(1)
O(1)的
由于奇偶性不同的情况下符号不一样,所以我们在转移的时候可以使用
−
1
-1
−1的系数,这样就可以统计答案了
这个
O
(
n
2
)
\mathcal O(n^2)
O(n2)算法期望得分50分
考虑生成函数,我们就会发现(生成函数基础应用)
本质上这个
d
p
dp
dp就是一个多项式,每次卷上一个在
w
i
w_i
wi位上有个
−
1
-1
−1,
0
0
0位上有个
1
1
1的多项式
知道结果多项式即可
成功将复杂度升至
O
(
n
2
l
o
g
n
)
\mathcal O(n^2logn)
O(n2logn)
我们发现,多项式乘法的复杂度拆开成两个多项式长度可以表示为
O
(
(
l
e
n
a
+
l
e
n
b
)
l
o
g
(
l
e
n
a
+
l
e
n
b
)
)
\mathcal O((lena+lenb)log(lena+lenb))
O((lena+lenb)log(lena+lenb))
直接分治即可
具体做法的代码(缩略版,其实是我不知道写伪代码的正确姿势)
polynomial calc(int l,int r)
{
if(l==r)return a[l];
int mid=(l+r)/2;
return calc(l,mid)*calc(mid+1,r);
}
将复杂度降到
O
(
n
l
o
g
2
n
)
\mathcal O(nlog^2n)
O(nlog2n)
证明:
l
o
g
(
l
e
n
a
+
l
e
n
b
)
⇔
l
o
g
n
log(lena+lenb)\Leftrightarrow logn
log(lena+lenb)⇔logn(这里的
n
n
n为100000)
每个多项式只会参与多项式乘法
l
o
g
n
logn
logn次,一次的复杂度消耗为长度乘以log
故总复杂度为
O
(
n
l
o
g
2
n
)
\mathcal O(nlog^2n)
O(nlog2n)
代码
#include<cstdio>
#include<cctype>
#include<cstring>
#include<vector>
namespace fast_IO
{
const int IN_LEN=10000000,OUT_LEN=10000000;
char ibuf[IN_LEN],obuf[OUT_LEN],*ih=ibuf+IN_LEN,*oh=obuf,*lastin=ibuf+IN_LEN,*lastout=obuf+OUT_LEN-1;
inline char getchar_(){return (ih==lastin)&&(lastin=(ih=ibuf)+fread(ibuf,1,IN_LEN,stdin),ih==lastin)?EOF:*ih++;}
inline void putchar_(const char x){if(oh==lastout)fwrite(obuf,1,oh-obuf,stdout),oh=obuf;*oh++=x;}
inline void flush(){fwrite(obuf,1,oh-obuf,stdout);}
}
using namespace fast_IO;
#define getchar() getchar_()
#define putchar(x) putchar_((x))
typedef long long ll;
#define rg register
template <typename T> inline T max(const T a,const T b){return a>b?a:b;}
template <typename T> inline T min(const T a,const T b){return a<b?a:b;}
template <typename T> inline T mind(T&a,const T b){a=a<b?a:b;}
template <typename T> inline T maxd(T&a,const T b){a=a>b?a:b;}
template <typename T> inline T abs(const T a){return a>0?a:-a;}
template <typename T> inline void swap(T&a,T&b){T c=a;a=b;b=c;}
template <typename T> inline void swap(T*a,T*b){T c=a;a=b;b=c;}
template <typename T> inline T gcd(const T a,const T b){if(!b)return a;return gcd(b,a%b);}
template <typename T> inline T square(const T x){return x*x;};
template <typename T> inline void read(T&x)
{
char cu=getchar();x=0;bool fla=0;
while(!isdigit(cu)){if(cu=='-')fla=1;cu=getchar();}
while(isdigit(cu))x=x*10+cu-'0',cu=getchar();
if(fla)x=-x;
}
template <typename T> void printe(const T x)
{
if(x>=10)printe(x/10);
putchar(x%10+'0');
}
template <typename T> inline void print(const T x)
{
if(x<0)putchar('-'),printe(-x);
else printe(x);
}
const int maxn=2097153,mod=998244353;
inline int Md(const int x){return x>=mod?x-mod:x;}
template<typename T>
inline int pow(int x,T y)
{
rg int res=1;x%=mod;
for(;y;y>>=1,x=(ll)x*x%mod)if(y&1)res=(ll)res*x%mod;
return res;
}
int W_[maxn],FW_[maxn],ha[maxn],hb[maxn];
inline void init(const int x)
{
rg int tim=0,lenth=1;
while(lenth<x)lenth<<=1,tim++;
for(rg int i=1;i<lenth;i++)
{
W_[i]=pow(3,(mod-1)/i/2);
FW_[i]=pow(W_[i],mod-2);
}
}
int L;
inline void NTT(int*A,const int fla)
{
for(rg int i=0,j=0;i<L;i++)
{
if(i>j)swap(A[i],A[j]);
for(rg int k=L>>1;(j^=k)<k;k>>=1);
}
for(rg int i=1;i<L;i<<=1)
{
const int w=fla==-1?FW_[i]:W_[i];
for(rg int j=0,J=i<<1;j<L;j+=J)
{
int K=1;
for(rg int k=0;k<i;k++,K=(ll)K*w%mod)
{
const int x=A[j+k],y=(ll)A[j+k+i]*K%mod;
A[j+k]=Md(x+y),A[j+k+i]=Md(mod+x-y);
}
}
}
}
struct poly
{
std::vector<int>A;
inline int&operator[](const int x){return A[x];}
inline void clear(){A.clear();}
inline unsigned int size(){return A.size();}
void RE(const int x)
{
A.resize(x);
for(rg int i=0;i<x;i++)A[i]=0;
}
void readin(const int MAX)
{
A.resize(MAX);
for(rg int i=0;i<MAX;i++)read(A[i]);
}
void putout()
{
for(rg int i=0;i<A.size();i++)print(A[i]),putchar(' ');
}
inline poly operator *(const poly b)const
{
L=1;const int RES=A.size()+b.A.size()-1;
while(L<RES)L<<=1;
poly c;c.A.resize(RES);
memset(ha,0,sizeof(int)*L);
memset(hb,0,sizeof(int)*L);
for(rg int i=0;i<A.size();i++)ha[i]=A[i];
for(rg int i=0;i<b.A.size();i++)hb[i]=b.A[i];
NTT(ha,1),NTT(hb,1);
for(rg int i=0;i<L;i++)ha[i]=(ll)ha[i]*hb[i]%mod;
NTT(ha,-1);
const int inv=pow(L,mod-2);
for(rg int i=0;i<RES;i++)c.A[i]=(ll)ha[i]*inv%mod;
return c;
}
}a[100001];
int n,w[100001];
void fz(const int l,const int r)
{
if(l==r)
{
a[l].RE(w[l]+1);
a[l][0]=1;
a[l][w[l]]=mod-1;
return;
}
const int mid=(l+r)>>1;
fz(l,mid),fz(mid+1,r);
a[l]=a[l]*a[mid+1];
a[mid+1].A.clear();
}
int ans;
int main()
{
init(maxn-2);
read(n);
for(rg int i=0;i<n;i++)read(w[i]);
fz(1,n-1);
for(rg int i=0;i<a[1].size();i++)ans=Md(ans+(ll)w[0]*pow(i+w[0],mod-2)%mod*a[1][i]%mod);
print(ans);
return flush(),0;
}
总结
写了一个多项式乘法的板子(还没填更多的功能),剩下的就比较清真了
终于有一篇博客写到分治+FFT了
想到容斥就有50分,去年的我50分都没