【题目】
【分析】
一道不错的题
首先,要知道一个东西,即若两条路径相交,则一条路径的 l c a lca lca 必然在另一条路径上
我们每加入一条边,都计算一下之前的边加上它对答案的贡献
现在假设加 ( a , b ) (a,b) (a,b) 条边,具体有一下几种情况:
统计 ( a , b ) (a,b) (a,b) 这条路径上的 l c a lca lca 个数:
大致长这样:
我们给每个点一个点权,代表从这个点到根路径上的
l
c
a
lca
lca 数量(记为
n
u
m
i
num_i
numi)
那么最后的答案显然是 n u m a + n u m b − 2 ∗ n u m l c a num_a+num_b-2*num_{lca} numa+numb−2∗numlca
每次做完更新的时候,要在 l c a lca lca 的子树内的 n u m + 1 num+1 num+1,代表它们到根的 l c a lca lca 数量多了一个(就是 l c a lca lca)
怎么维护这个 n u m num num 呢?可以用一个区间修改,单点查询的树状数组来维护(毕竟代码量小,常数也小)
统计穿过 ( a , b ) (a,b) (a,b) 的 l c a lca lca 的路径数量:
大概长这样:
利用类似于树上差分的思想,在两个端点加一,在
l
c
a
lca
lca 处减二
这样做之后,把一个子树内的值统计出来,就是穿过 l c a lca lca 的路径的数量
因为,如果一条路径在 ( a , b ) (a,b) (a,b) 内部,或者在外面(就是不相交的情况),那么它的值就会被抵消掉( 1 + 1 − 2 1+1-2 1+1−2)
而只有这种一个点在内部, l c a lca lca 和另一个点在外部的情况(也就是合法情况)才会被统计到
那么这个又怎么维护呢?用一个单点修改,区间查询的树状数组就行了
一些要注意的细节:
- 两个树状数组维护的是不一样的东西,不要弄混了
- 注意 l c a lca lca 重复的情况要单独计算,不然会重复(以上两种方法都会算一次)
【代码】
#include<cstdio>
#include<cstring>
#include<algorithm>
#define N 1000005
#define M 2000005
#define lowbit(x) x&-x
using namespace std;
int n,m,t,tot;
int first[N],v[M],nxt[M];
int num[N],bit1[N],bit2[N];
int dep[N],size[N],pos[N],fa[N][25];
void add(int x,int y)
{
t++;
nxt[t]=first[x];
first[x]=t;
v[t]=y;
}
void dfs(int x)
{
int i,j;
size[x]=1,pos[x]=++tot;
for(i=1;i<=20;++i)
fa[x][i]=fa[fa[x][i-1]][i-1];
for(i=first[x];i;i=nxt[i])
{
j=v[i];
if(j!=fa[x][0])
{
fa[j][0]=x;
dep[j]=dep[x]+1;
dfs(j);
size[x]+=size[j];
}
}
}
int Lca(int x,int y)
{
int i;
if(dep[x]<dep[y]) swap(x,y);
for(i=20;~i;--i)
if(dep[fa[x][i]]>=dep[y])
x=fa[x][i];
if(x==y) return x;
for(i=20;~i;--i)
if(fa[x][i]!=fa[y][i])
x=fa[x][i],y=fa[y][i];
return fa[x][0];
}
void modify(int *bit,int i,int x)
{
while(i<=n)
{
bit[i]+=x;
i+=lowbit(i);
}
}
int query(int *bit,int i)
{
int ans=0;
while(i)
{
ans+=bit[i];
i-=lowbit(i);
}
return ans;
}
int main()
{
// freopen("access.in","r",stdin);
// freopen("access.out","w",stdout);
int x,y,i;
scanf("%d%d",&n,&m);
for(i=1;i<n;++i)
{
scanf("%d%d",&x,&y);
add(x,y),add(y,x);
}
dep[1]=1,dfs(1);
long long ans=0;
for(i=1;i<=m;++i)
{
scanf("%d%d",&x,&y);
int lca=Lca(x,y);
ans+=num[lca],num[lca]++;
ans+=query(bit1,pos[x])+query(bit1,pos[y])-2*query(bit1,pos[lca]);
ans+=query(bit2,pos[lca]+size[lca]-1)-query(bit2,pos[lca]-1);
modify(bit1,pos[lca],1);
modify(bit1,pos[lca]+size[lca],-1);
modify(bit2,pos[x],1);
modify(bit2,pos[y],1);
modify(bit2,pos[lca],-2);
}
printf("%lld",ans);
// fclose(stdin);
// fclose(stdout);
return 0;
}