讲解:https://www.luogu.com.cn/blog/P-Wang/query-on-a-tree-explaination
代码:https://blog.csdn.net/Cymbals/article/details/83212059
题意:给一颗树,输入abcd,求a与b的路径连线和c与d的路径连线公共点个数。
因为初始值都是0,省去了建树操作。
线段树维护的依据:
一个节点到其链首端(tid)的dfs序一定是连续的,因此可以转化成区间操作
区间更新时,如果两点的链首端不同,设f1,f2为u,v的链首端标号,并且深度f1>f2,如果f1!=f2,说明在两条链上,又因为上述性质,可以更新区间f1~ u的信息。然后让f1继续上跳即可。
#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cmath>
#include <vector>
#include <queue>
#include <stack>
using namespace std;
const int maxn=1e5+5;
int n,q;
vector <int> g[maxn];
int sz[maxn],dep[maxn];//节点个数 节点深度
int ch[maxn],fa[maxn];//重儿子编号 父节点编号
int top[maxn],tid[maxn],tid2[maxn];//链首端编号 存储dfs序 dfs序的反函数
int tot;
//树链剖分
void dfs1(int u,int f,int d)
{
sz[u]=1;
fa[u]=f;
dep[u]=d;
for(int i=0; i<g[u].size(); i++)
{
int v=g[u][i];
if(v==f)
continue;
dfs1(v,u,d+1);
sz[u]+=sz[v];
if(ch[u]==-1||sz[v]>sz[ch[u]])
ch[u]=v;
}
}
void dfs2(int u,int tp)
{
top[u]=tp;
tid[u]=++tot;
tid2[tot]=u;
if(ch[u]==-1)
return ;
dfs2(ch[u],tp);
for(int i=0; i<g[u].size(); i++)
{
int v=g[u][i];
if(v!=fa[u]&&v!=ch[u])
dfs2(v,v);
}
}
//线段树维护
struct segmenttree
{
int sum[maxn<<2],lazy[maxn<<2];
void pushdown(int i,int l,int r)
{
if(!lazy[i])
return ;
int mid=(l+r)/2;
sum[i*2]+=(mid-l+1)*lazy[i];
sum[i*2+1]+=(r-mid)*lazy[i];
lazy[i*2]+=lazy[i];
lazy[i*2+1]+=lazy[i];
lazy[i]=0;
}
void pushup(int i)
{
sum[i]=sum[i*2]+sum[i*2+1];
}
void update(int i,int l,int r,int L,int R,int val)
{
if(l>=L&&r<=R)
{
sum[i]+=(r-l+1)*val;
lazy[i]+=val;
return ;
}
int mid=(l+r)/2;
pushdown(i,l,r);
if(L<=mid)
update(i*2,l,mid,L,R,val);
if(R>mid)
update(i*2+1,mid+1,r,L,R,val);
pushup(i);
}
void update(int u,int v,int val)
{
int f1=top[u],f2=top[v];
while(f1!=f2)
{
if(dep[f1]<dep[f2])
{
swap(f1,f2);
swap(u,v);
}
update(1,1,n,tid[f1],tid[u],val);//f1到u的路径dfs序一定是连续的 可区间维护
u=fa[f1];
f1=top[u];
}
if(dep[u]<dep[v])
swap(u,v);
update(1,1,n,tid[v],tid[u],val);
}
int query(int i,int l,int r,int L,int R)
{
if(l>=L&&r<=R)
{
return sum[i];
}
int mid=(l+r)/2;
pushdown(i,l,r);
int ans=0;
if(L<=mid)
ans+=query(i*2,l,mid,L,R);
if(R>mid)
ans+=query(i*2+1,mid+1,r,L,R);
return ans;
}
int query(int u,int v)
{
int f1=top[u],f2=top[v];
int ans=0;
while(f1!=f2)
{
if(dep[f1]<dep[f2])
{
swap(u,v);
swap(f1,f2);
}
ans+=query(1,1,n,tid[f1],tid[u]);
u=fa[f1];
f1=top[u];
}
if(dep[u]<dep[v])
swap(u,v);
ans+=query(1,1,n,tid[v],tid[u]);
return ans;
}
} st;
int main()
{
scanf("%d%d",&n,&q);
for(int i=0; i<n-1; i++)
{
int u,v;
scanf("%d%d",&u,&v);
g[u].push_back(v);
g[v].push_back(u);
}
memset(ch,-1,sizeof(ch));
dfs1(1,0,0);
dfs2(1,1);
int a,b,c,d;
while(q--)
{
scanf("%d%d%d%d",&a,&b,&c,&d);
st.update(a,b,1);
printf("%d\n",st.query(c,d));
st.update(a,b,-1);
}
return 0;
}