题面
题意
有一棵有n个点的树,上面有m条链,两个点可以互达当且仅当存在一条两点都在上面的链.
问有几对点可以互达.
做法
对每个点考虑它对答案的贡献,可以发现,若点x在链
a
1
,
b
1
;
a
2
,
b
2
.
.
.
.
.
a_1,b_1;a_2,b_2.....
a1,b1;a2,b2.....上,则与x可以互达的点恰好都在点
a
1
,
b
1
,
a
2
,
b
2
.
.
.
.
.
a_1,b_1,a_2,b_2.....
a1,b1,a2,b2.....构成的虚树上(包括被压缩在边上的点),因此我们可以直接将所有点根据dfs序排序,然后用线段树来储存,线段树中的每个叶子节点表示:该点是否是包含点x的链上的某个端点,然后在up时即可统计答案:左右两边的答案之和,加上两个联通块之间的点数再,减去左边dfs序最大点与右边dfs序最小的点的lca与x的深度差(去重).
然后用线段树合并即可处理,再用st表来实现
O
(
1
)
O(1)
O(1)求lca,这样总的时间复杂度即为
O
(
n
∗
l
o
g
2
n
)
O(n*log_2n)
O(n∗log2n)
代码
#include<bits/stdc++.h>
#define ll long long
#define LG 17
#define N 100100
using namespace std;
int n,m,tt,deep[N],dfn[N],in[N];
vector<int>to[N];
namespace LCA
{
int tmp,pos[N],lg[N<<1],ou[N<<1],mn[N<<1][20];
void dfs(int now,int last)
{
int i,t;
ou[++tmp]=now;
dfn[++tt]=now;
pos[now]=tmp;
in[now]=tt;
for(i=0;i<to[now].size();i++)
{
t=to[now][i];
if(t==last) continue;
deep[t]=deep[now]+1;
dfs(t,now);
ou[++tmp]=now;
}
}
inline void pre()
{
int i,j;
deep[1]=1;
dfs(1,-1);
for(i=1;i<=tmp;i++) mn[i][0]=ou[i];
for(j=1;j<=LG;j++)
{
for(i=1;i+(1 << (j-1))<=tmp;i++)
{
if(deep[mn[i][j-1]]<deep[mn[i+(1 << (j-1))][j-1]]) mn[i][j]=mn[i][j-1];
else mn[i][j]=mn[i+(1 << (j-1))][j-1];
}
}
for(i=2;i<=tmp;i++) lg[i]=lg[i>>1]+1;
}
inline int ask(int u,int v)
{
if(!u || !v) return u+v;
u=pos[u],v=pos[v];
if(u>v) swap(u,v);
int l=lg[v-u+1];
return deep[mn[u][l]]<deep[mn[v-(1 << l)+1][l]]?mn[u][l]:mn[v-(1 << l)+1][l];
}
}
int rt[N];
ll ans;
struct Node
{
int ls,rs,dn,sum,ld,rd,cnt;
}node[N*40];
vector<int>ad[N],del[N];
inline void up(int now)
{
int L=node[now].ls,R=node[now].rs;
if(!node[L].cnt)
{
node[now].dn=node[R].dn;
node[now].sum=node[R].sum;
node[now].ld=node[R].ld;
node[now].rd=node[R].rd;
node[now].cnt=node[R].cnt;
return;
}
if(!node[R].cnt)
{
node[now].dn=node[L].dn;
node[now].sum=node[L].sum;
node[now].ld=node[L].ld;
node[now].rd=node[L].rd;
node[now].cnt=node[L].cnt;
return;
}
node[now].dn=LCA::ask(node[L].dn,node[R].dn);
node[now].sum=node[L].sum+node[R].sum+deep[node[L].dn]+deep[node[R].dn]-2*deep[node[now].dn]-1-(deep[LCA::ask(node[L].rd,node[R].ld)]-deep[node[now].dn]);
node[now].ld=node[L].ld;
node[now].rd=node[R].rd;
node[now].cnt=node[L].cnt+node[R].cnt;
}
int mg(int u,int v,int l,int r)
{
if(!u || !v) return u+v;
if(l==r)
{
node[u].cnt+=node[v].cnt;
return u;
}
int mid=((l+r)>>1);
node[u].ls=mg(node[u].ls,node[v].ls,l,mid);
node[u].rs=mg(node[u].rs,node[v].rs,mid+1,r);
up(u);
return u;
}
void add(int now,int l,int r,int u,int v)
{
if(l==r)
{
node[now].cnt+=v;
if(!node[now].cnt) return;
node[now].dn=node[now].ld=node[now].rd=dfn[l];
node[now].sum=1;
return;
}
int mid=((l+r)>>1);
if(u<=mid)
{
if(!node[now].ls) node[now].ls=++tt;
add(node[now].ls,l,mid,u,v);
}
else
{
if(!node[now].rs) node[now].rs=++tt;
add(node[now].rs,mid+1,r,u,v);
}
up(now);
}
void dfs(ll now,ll last)
{
ll i,t;
for(i=0;i<to[now].size();i++)
{
t=to[now][i];
if(t==last) continue;
dfs(t,now);
rt[now]=mg(rt[now],rt[t],1,n);
}
for(i=0;i<ad[now].size();i++)
{
t=ad[now][i];
add(rt[now],1,n,t,1);
}
if(node[rt[now]].cnt) ans+=node[rt[now]].sum-1;
for(i=0;i<del[now].size();i++)
{
t=del[now][i];
add(rt[now],1,n,t,-2);
}
}
int main()
{
int i,j,p,q,t;
cin>>n>>m;
for(i=1;i<n;i++)
{
scanf("%d%d",&p,&q);
to[p].push_back(q);
to[q].push_back(p);
}
LCA::pre();
for(i=1;i<=m;i++)
{
scanf("%d%d",&p,&q);
t=LCA::ask(p,q);
ad[p].push_back(in[p]),ad[p].push_back(in[q]);
ad[q].push_back(in[p]),ad[q].push_back(in[q]);
del[t].push_back(in[p]),del[t].push_back(in[q]);
}
for(i=1;i<=n;i++) rt[i]=i;
tt=n;
dfs(1,-1);
cout<<ans/2;
}