题意:
现在有n个线段,每个线段有1/2的可能会被选中,问你被选中的这些线段的交集的长度的平方的期望是多少。
题解:
对于求这种期望我是一窍不通,理解别人的代码也理解了好久才恍恍惚惚好像知道了的样子,难受
首先我们可以将所有的线段分成一个一个小段,然后去做每个小段的贡献:
比如说这三个黑色线段我们就可以将他们分成一个一个红色的小段。
然后对于每一个小段p的贡献:
假设有x个线段包含这个小段
=
∣
p
∣
∗
∑
i
∈
x
s
e
g
m
e
n
t
i
∗
2
x
−
1
2
n
∗
∣
s
e
g
m
e
n
t
i
∣
=|p|*\sum\limits_{i∈x}segment_i*\frac{2^x-1}{2^n}*|segment_i|
=∣p∣∗i∈x∑segmenti∗2n2x−1∗∣segmenti∣
-1是因为不能每个都不选。
然后可以这样做是因为一个长度的贡献可以被分成多个长度贡献的和:
l
e
n
∗
l
e
n
=
l
e
n
∗
l
e
n
1
∗
p
1
+
l
e
n
∗
l
e
n
2
∗
p
2
+
.
.
.
(
l
e
n
1
+
l
e
n
2
+
.
.
.
=
l
e
n
)
len*len=len*len_1*p_1+len*len_2*p_2+...(len1+len2+...=len)
len∗len=len∗len1∗p1+len∗len2∗p2+...(len1+len2+...=len)
那么对于每线段左右端点从小到大排序,线段树维护包含当前枚举到的小段的那些线段的期望。
-1最后再做,就相当于减掉最长长度的平方,因为我们在做每个小段的时候,是算了最长长度的…因为这样比较方便
tql/QAQ.jpg
#include<bits/stdc++.h>
using namespace std;
#define ll long long
const int N=1e6+5;
const ll mod=998244353;
struct Segment{
int l,r,id;
}e[N],me[N];
bool cmpl(Segment a,Segment b){return a.l<b.l;}
bool cmpr(Segment a,Segment b){return a.r<b.r;}
int b[N];
ll mul[N*4],sum[N*4];
void push_up(int root){
sum[root]=(sum[root<<1]+sum[root<<1|1])*mul[root]%mod;
}
void build(int l,int r,int root){
mul[root]=1;
if(l==r){
sum[root]=b[r+1]-b[l];
return ;
}
int mid=l+r>>1;
build(l,mid,root<<1);
build(mid+1,r,root<<1|1);
push_up(root);
}
void update(int l,int r,int root,int ql,int qr,ll v){
if(l>=ql&&r<=qr){
sum[root]=sum[root]*v%mod;
mul[root]=mul[root]*v%mod;
return ;
}
int mid=l+r>>1;
if(mid>=ql)
update(l,mid,root<<1,ql,qr,v);
if(mid<qr)
update(mid+1,r,root<<1|1,ql,qr,v);
push_up(root);
}
ll qpow(ll a,ll b){ll ans=1;for(;b;b>>=1,a=a*a%mod)if(b&1)ans=ans*a%mod;return ans;}
int main()
{
int n;
scanf("%d",&n);
for(int i=1;i<=n;i++){
scanf("%d%d",&e[i].l,&e[i].r),e[i].r++,e[i].id=i;
//me[i]=e[i];
b[i*2-1]=e[i].l,b[i*2]=e[i].r;
}
sort(b+1,b+1+n*2);
int all=unique(b+1,b+1+n*2)-b-1;
for(int i=1;i<=n;i++){
e[i].l=lower_bound(b+1,b+1+all,e[i].l)-b;
e[i].r=lower_bound(b+1,b+1+all,e[i].r)-b;
me[i]=e[i];
}
sort(e+1,e+1+n,cmpl),sort(me+1,me+1+n,cmpr);
build(1,all-1,1);
int l=1,r=1;
ll inv2=qpow(2,mod-2),ans=0;
for(int i=1;i<all;i++){
while(l<=n&&e[l].l<=i)
update(1,all-1,1,e[l].l,e[l].r-1,2),l++;
while(r<=n&&me[r].r<=i)
update(1,all-1,1,me[r].l,me[r].r-1,inv2),r++;
ans=(ans+sum[1]*(b[i+1]-b[i]))%mod;
}
ans=(ans+mod-1ll*(b[all]-b[1])*(b[all]-b[1])%mod)%mod;
ans=ans*qpow(qpow(2,n),mod-2)%mod;
printf("%lld\n",ans);
return 0;
}