题意
https://www.luogu.org/problemnew/show/U40581
思路
这种带两个
∑
\sum
∑ 的题,显然是要单独算贡献了。在树上常见的就几种情况,点的贡献、边的贡献、
LCA
\text{LCA}
LCA的贡献。
考虑链的情况,对于一条边而言,能让它产生贡献的权值区间满足其中即含有在边左边的点,又含有在边右边的点。
把在边左边的点标为
1
1
1,右边的点标为
0
0
0,那么即包含
0
0
0 、又包含
1
1
1 的区间就会产生
1
1
1 的贡献,算这个不好算,我们正难则反,算只有
0
0
0、或只有
1
1
1的区间贡献,再用所有区间去减即可。
我们从左到右枚举选择链上的哪条边,用线段树维护上述的权值区间,从而计算不合法区间的情况,这个线段树应该支持单点修改,全局查询。
到了树上,我们可以类似的,把子树内的点标为
1
1
1 ,把子树外的点标为
0
0
0 ,逐个添加为
1
1
1 的点,但是直接添加又会
T
\text{T}
T,子树内的点??启发式一下就好了嘛。
复杂度两个
log
\log
log ,可以过,但是有少一个
log
\log
log 的写法,线段树合并。
有了这种写法的基础,过渡下来也不难。我们可以干脆对每个点开一个线段树(动态开点),合并的代码如下:
int merge(int k,int p)
{
if(!k||!p)return k|p;
nd[k].lson=merge(nd[k].lson,nd[p].lson);
nd[k].rson=merge(nd[k].rson,nd[p].rson);
add_up(k);
return k;
}
就是把线段树 k k k 向线段树 p p p 合并的过程,现在也暂时用不到这种高级的算法,复杂度不会分析。将来专门练线段树合并时再证复杂度。
代码
启发式合并
#include<bits/stdc++.h>
#define FOR(i,x,y) for(int i=(x),i##END=(y);i<=i##END;++i)
#define DOR(i,x,y) for(int i=(x),i##END=(y);i>=i##END;--i)
typedef long long LL;
using namespace std;
const int N=1e5+5;
template<const int maxn,const int maxm>struct Linked_list
{
int head[maxn],to[maxm],nxt[maxm],tot;
Linked_list(){clear();}
void clear(){memset(head,-1,sizeof(head));}
void add(int u,int v){to[++tot]=v,nxt[tot]=head[u],head[u]=tot;}
#define EOR(i,G,u) for(int i=G.head[u];~i;i=G.nxt[i])
};
Linked_list<N,N<<1>G;
struct SegmentTree
{
LL calc(int n){return 1ll*n*(n+1)/2;}
struct node
{
int L,R,ls,rs;
bool lc,rc;
LL ans;
void reset(bool val)
{
ls=rs=1,lc=rc=val;
ans=1;
}
int lth(){return R-L+1;}
}nd[N<<2];
void build(int k,int L,int R)
{
nd[k].L=L,nd[k].R=R;
if(L==R)
{
nd[k].reset(0);
return;
}
build(k<<1,L,(L+R)>>1);
build(k<<1|1,((L+R)>>1)+1,R);
add_up(k);
}
void add_up(int k)
{
nd[k].lc=nd[k<<1].lc;
nd[k].rc=nd[k<<1|1].rc;
nd[k].ls=nd[k<<1].ls;
nd[k].rs=nd[k<<1|1].rs;
if(nd[k<<1].ls==nd[k<<1].lth()&&nd[k<<1].lc==nd[k<<1|1].lc)nd[k].ls+=nd[k<<1|1].ls;
if(nd[k<<1|1].rs==nd[k<<1|1].lth()&&nd[k<<1|1].rc==nd[k<<1].rc)nd[k].rs+=nd[k<<1].rs;
nd[k].ans=nd[k<<1].ans+nd[k<<1|1].ans;
if(nd[k<<1].rc==nd[k<<1|1].lc)
{
nd[k].ans-=calc(nd[k<<1].rs);
nd[k].ans-=calc(nd[k<<1|1].ls);
nd[k].ans+=calc(nd[k<<1].rs+nd[k<<1|1].ls);
}
}
void update(int k,int x,bool val)
{
if(nd[k].L==nd[k].R)
{
nd[k].reset(val);
return;
}
int mid=(nd[k].L+nd[k].R)>>1;
if(x<=mid)update(k<<1,x,val);
else update(k<<1|1,x,val);
add_up(k);
}
LL query(){return nd[1].ans;}
}ST;
int sz[N],son[N],L[N],R[N],ori[N],ord,n;
LL ans;
void dfs(int u,int f)
{
L[u]=++ord,ori[ord]=u,sz[u]=1,son[u]=0;
EOR(i,G,u)
{
int v=G.to[i];
if(v==f)continue;
dfs(v,u);
sz[u]+=sz[v];
if(sz[v]>sz[son[u]])son[u]=v;
}
R[u]=ord;
}
void update(int L,int R,int val){FOR(i,L,R)ST.update(1,ori[i],val);}
void dsu(int u,int f)
{
EOR(i,G,u)
{
int v=G.to[i];
if(v==f||v==son[u])continue;
dsu(v,u);update(L[v],R[v],0);
}
if(son[u])dsu(son[u],u);
EOR(i,G,u)
{
int v=G.to[i];
if(v==f||v==son[u])continue;
update(L[v],R[v],1);
}
update(L[u],L[u],1);
ans+=1ll*n*(n+1)/2-ST.query();
}
int main()
{
scanf("%d",&n);
FOR(i,1,n-1)
{
int u,v;
scanf("%d%d",&u,&v);
G.add(u,v),G.add(v,u);
}
dfs(1,0);
ST.build(1,1,n);
dsu(1,0);
printf("%lld\n",ans);
return 0;
}
线段树合并
#include<bits/stdc++.h>
#define FOR(i,x,y) for(int i=(x),i##END=(y);i<=i##END;++i)
#define DOR(i,x,y) for(int i=(x),i##END=(y);i>=i##END;--i)
typedef long long LL;
using namespace std;
const int N=1e5+5;
const int NlogN=2e6+5;
inline LL calc(int n){return 1ll*n*(n+1)/2;}
template<const int maxn,const int maxm>struct Linked_list
{
int head[maxn],to[maxm],nxt[maxm],tot;
Linked_list(){clear();}
void clear(){memset(head,-1,sizeof(head));}
void add(int u,int v){to[++tot]=v,nxt[tot]=head[u],head[u]=tot;}
#define EOR(i,G,u) for(int i=G.head[u];~i;i=G.nxt[i])
};
Linked_list<N,N<<1>G;
struct node
{
int lson,rson;
int ls,rs;bool lc,rc;
LL ans;
void reset(int val,int l,int r)
{
ls=rs=ans=r-l+1;
lc=rc=val;
ans=calc(r-l+1);
}
};
struct SegmentTree
{
node nd[NlogN];
int rt[N],tot;
void build(){memset(rt,0,sizeof(rt));}
void create(int &k){nd[k=++tot]=(node){0,0,0,0,0,0,0};}
void add_up(int k,int l,int r)
{
node A,B;
int mid=(l+r)>>1;
if(!nd[k].lson)A.reset(0,l,mid);
else A=nd[nd[k].lson];
if(!nd[k].rson)B.reset(0,mid+1,r);
else B=nd[nd[k].rson];
nd[k].lc=A.lc;
nd[k].rc=B.rc;
nd[k].ls=A.ls;
nd[k].rs=B.rs;
if(A.ls==mid-l+1&&A.lc==B.lc)nd[k].ls+=B.ls;
if(B.rs==r-(mid+1)+1&&B.rc==A.rc)nd[k].rs+=A.rs;
nd[k].ans=A.ans+B.ans;
if(A.rc==B.lc)
{
nd[k].ans-=calc(A.rs);
nd[k].ans-=calc(B.ls);
nd[k].ans+=calc(A.rs+B.ls);
}
}
void update(int &k,int x,int val,int l,int r)
{
if(!k)create(k);
if(l==r)
{
nd[k].reset(val,l,r);
return;
}
int mid=(l+r)>>1;
if(x<=mid)update(nd[k].lson,x,val,l,mid);
else update(nd[k].rson,x,val,mid+1,r);
add_up(k,l,r);
}
int merge(int k,int p,int l,int r)
{
if(!k||!p)return k|p;
int mid=(l+r)>>1;
nd[k].lson=merge(nd[k].lson,nd[p].lson,l,mid);
nd[k].rson=merge(nd[k].rson,nd[p].rson,mid+1,r);
add_up(k,l,r);
return k;
}
LL query(int x){return nd[x].ans;}
}ST;
int n;
LL ans;
void dfs(int u,int f)
{
ST.update(ST.rt[u],u,1,1,n);
EOR(i,G,u)
{
int v=G.to[i];
if(v==f)continue;
dfs(v,u);
ST.rt[u]=ST.merge(ST.rt[u],ST.rt[v],1,n);
}
ans+=calc(n)-ST.query(ST.rt[u]);
}
int main()
{
scanf("%d",&n);
FOR(i,1,n-1)
{
int u,v;
scanf("%d%d",&u,&v);
G.add(u,v),G.add(v,u);
}
ST.build();
dfs(1,0);
printf("%lld\n",ans);
return 0;
}