题意
解法
考场上想到的
O
(
n
2
)
O(n^2)
O(n2)暴力:
记
f
[
i
]
[
j
]
f[i][j]
f[i][j]表示前i个位置,长度为j的连击出现的期望次数
记
g
[
i
]
[
j
]
g[i][j]
g[i][j]表示第到i个位置为止,目前连击次数为j的概率
转移时有一些细节
#include<bits/stdc++.h>
using namespace std;
const int maxn=5e4+5;
inline int read(){
char c=getchar();int t=0,f=1;
while(!isdigit(c)){if(c=='-')f=-1;c=getchar();}
while(isdigit(c)){t=(t<<3)+(t<<1)+(c^48);c=getchar();}
return t*f;
}
int n;
double a[maxn],f[2][maxn],g[2][maxn];
int main(){
//freopen("a.in","r",stdin);
//freopen("a.out","w",stdout);
n=read();
for(int i=1;i<=n;i++)scanf("%lf",&a[i]);
f[0][0]=1;
g[0][0]=1;
for(int i=1;i<=n;i++){
for(int j=1;j<=i;j++){
g[i&1][j]=0;
f[i&1][j]=f[(i-1)&1][j];
}
g[i&1][0]=0;
f[i&1][0]=0;
for(int j=1;j<=i;j++){
g[i&1][j]=a[i]*g[(i-1)&1][j-1];
}
for(int j=0;j<i;j++){
g[i&1][0]+=g[(i-1)&1][j]*(1-a[i]);
f[i&1][j]+=g[(i-1)&1][j]*(1-a[i]);
}
}
for(int i=1;i<=n;i++){
f[n&1][i]+=g[n&1][i];
}
f[n&1][0]=0;
for(int i=0;i<=n;i++){
printf("%.12lf\n",f[n&1][i]);
}
return 0;
}
然后是正解:
记p[i]表示前i个位置都连击,在第i+1个位置断连的概率
记q[i]表示恰好从第i个位置开始连击的概率.
记ans[i]表示最终连击长度恰好为i的期望次数
然后发现如果将p或q翻转,ans可以直接由p和q多项式乘法得到
这里首先给出暴力,更便于理解
#include<cstdio>
const int maxn=50005;
double a[maxn],s[maxn],p[maxn],q[maxn],ans[maxn];
int n,i,j;
int main(){
scanf("%d",&n);
s[0]=1;
for(i=1;i<=n;i++){scanf("%lf",&a[i]);s[i]=s[i-1]*a[i];}
for(i=0;i<=n;i++){
if(i>0)p[i]=s[i]*(1-a[i+1]);
if(i<n)q[i]=(1-a[i])/s[i];
}
for(i=1;i<=n;i++)for(j=0;j<i;j++)
ans[i-j]+=p[i]*q[j];//这里就表示[j+1,i]这个区间全连,而j和i+1都断开的概率
for(i=0;i<=n;i++)printf("%.12lf\n",ans[i]);
}
正解
#include<bits/stdc++.h>
using namespace std;
const double pi=acos(-1);
const int maxn=200005;
double a[maxn],s[maxn],p[maxn],q[maxn],ans[maxn];
int n,i,j,lim=1,l,r[maxn];
struct node{
double x,y;
node(double xx=0,double yy=0){x=xx,y=yy;}
}ans1[maxn],p1[maxn],q1[maxn];
node operator +(node a,node b){return node(a.x+b.x,a.y+b.y);}
node operator *(node a,node b){return node(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x);}
node operator -(node a,node b){return node(a.x-b.x,a.y-b.y);}
void fft(node a[],int lim,int f){
for(int i=0;i<lim;i++)if(i<r[i])swap(a[i],a[r[i]]);
for(int i=1;i<lim;i<<=1){
node wn(cos(pi/i),f*sin(pi/i));
for(int j=0;j<lim;j+=(i<<1)){
node wt(1,0);
for(int k=0;k<i;k++,wt=wt*wn){
node x=a[j+k],y=a[j+k+i]*wt;
a[j+k]=x+y;a[j+k+i]=x-y;
}
}
}
}
int main(){
//freopen("a.in","r",stdin);
scanf("%d",&n);
s[0]=1;
for(i=1;i<=n;i++){scanf("%lf",&a[i]);s[i]=s[i-1]*a[i];}
for(i=0;i<=n;i++){
if(i>0)p[i]=s[i]*(1-a[i+1]);
if(i<n)q[i]=(1-a[i])/s[i];
}
reverse(p+1,p+1+n);
while(lim<=(n+n))lim<<=1,l++;
for(int i=1;i<lim;i++)r[i]=(r[i>>1]>>1)|(i&1)<<(l-1);
for(int i=0;i<lim;i++){p1[i].x=p[i];q1[i].x=q[i];}
fft(p1,lim,1);fft(q1,lim,1);
for(int i=0;i<lim;i++)ans1[i]=p1[i]*q1[i];
fft(ans1,lim,-1);
reverse(ans1+1,ans1+1+n);
ans1[0].x=0;
for(i=0;i<=n;i++)printf("%.12lf\n",ans1[i].x/lim);
}