树上的差分,顾名思义就是将原本数列的差分操作改成了树上节点和边的差分操作,利用到“差分序列的前缀和就是原序列”的重要性质将树上的操作简化。即区间操作转化为路径操作,前缀和转换为子树和。
以本题为例,节点x和y之间所有的节点都增加1可以转换为节点x和y加1,x和y的公共祖先减2。相当于把x到y当成一个区间来操作,统计时只需算子树和即可。此题比较麻烦的是公共祖先节点要单独分类,在单独开一个数组记每个点当了几次公共祖先,统计时再加上这个值就可以了。
代码如下:
#include <bits/stdc++.h>
using namespace std;
const int N=50001;
queue<int> q;
int n,k,x,y,t,d[N],f[N][20],a[N],b[N],ans[N],anss,sum[N];
int ver[2*N],Next[N*2],head[N],tot;
void add(int x,int y)
{
ver[++tot]=y;
Next[tot]=head[x],head[x]=tot;
}
void bfs()
{
q.push(1);d[1]=1;
while(q.size())
{
int x=q.front();q.pop();
for(int i=head[x];i;i=Next[i])
{
int y=ver[i];
if(d[y]) continue;
d[y]=d[x]+1;
f[y][0]=x;
for(int j=1;j<=t;j++)
f[y][j]=f[f[y][j-1]][j-1];
q.push(y);
}
}
}
int lca(int x,int y)
{
if(d[x]>d[y]) swap(x,y);
for(int i=t;i>=0;i--)
if(d[f[y][i]]>=d[x]) y=f[y][i];
if(x==y) return x;
for(int i=t;i>=0;i--)
if(f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];
return f[x][0];
}
void dfs(int x,int fa)
{
ans[x]=a[x]+b[x];
for(int i=head[x];i;i=Next[i])
{
int y=ver[i];
if(y==fa) continue;
dfs(y,x);
sum[x]+=sum[y]+a[y];
}
ans[x]+=sum[x];
}
int main()
{
scanf("%d%d",&n,&k);
for(int i=1;i<n;i++)
{
scanf("%d%d",&x,&y);
add(x,y),add(y,x);
}
t=log(n)/log(2)+1;
bfs();
for(int i=1;i<=k;i++)
{
scanf("%d%d",&x,&y);
int k=lca(x,y);
a[x]++;a[y]++;a[k]-=2;
b[k]++;
}
dfs(1,-1);
for(int i=1;i<=n;i++)
anss=max(anss,ans[i]);
printf("%d",anss);
}