Description
给出一棵 n+1 n + 1 个节点的树,要求破坏尽可能少的点使得所给 m m 对点对均不可互达
Input
第一行一整数,之后 n n 行每行两个整数表示一条树边,然后输入一整数 m m ,最后行每行两个整数 u,v u , v 表示需要使得 u,v u , v 不可互达 (3≤n≤104,1≤p≤5⋅104) ( 3 ≤ n ≤ 10 4 , 1 ≤ p ≤ 5 ⋅ 10 4 )
Output
输出需要删去的最少点数
Sample Input
4
1 0
4 2
2 0
3 2
2
1 3
2 1
Sample Output
1
Solution
为使被删掉的点尽可能起作用,对于一个点对要删去影响最大的点,即其 LCA L C A ,对每个点对求出其 LCA L C A ,把查询按点对 LCA L C A 深度降序排,先处理 LCA L C A 深度最深的点对,因为先处理其他点对不能解决该点对的问题,但是先解决该点对的问题可以顺带就解决了其他点对的问题,删去 LCA L C A 后,为了保留下删除该点的影响,把以该点为根的子树全部标记加一,这样以来,对于后面的点对 u,v u , v ,如果 u u 或的标记非零,说明 u u 或的某个祖先被删掉了,且这个被删掉的祖先深度比 u,v u , v 的 LCA L C A 深度深,也即当前点对不需要删点已经被解决掉了,对子树的更新操作求出 dfs d f s 序后用树状数组维护即可
Code
#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<vector>
#include<queue>
#include<map>
#include<set>
#include<ctime>
using namespace std;
#define maxn 10005
#define maxm 50005
int n,m,p[maxn][15],dep[maxn],index,L[maxn],R[maxn];
vector<int>g[maxn];
void dfs(int u,int fa)
{
p[u][0]=fa;
for(int i=1;i<15;i++)p[u][i]=p[p[u][i-1]][i-1];
L[u]=++index;
for(int i=0;i<g[u].size();i++)
{
int v=g[u][i];
if(v==fa)continue;
dep[v]=dep[u]+1;
dfs(v,u);
}
R[u]=index;
}
int lca(int a,int b)
{
int i,j;
if(dep[a]<dep[b])swap(a,b);
for(i=0;(1<<i)<=dep[a];i++);
i--;
for(j=i;j>=0;j--)
if(dep[a]-(1<<j)>=dep[b])
a=p[a][j];
if(a==b) return a;
for(j=i;j>=0;j--)
if(p[a][j]&&p[a][j]!=p[b][j])
a=p[a][j],b=p[b][j];
return p[a][0];
}
struct BIT
{
#define lowbit(x) (x&(-x))
int b[maxn],n;
void init(int _n)
{
n=_n;
for(int i=1;i<=n;i++)b[i]=0;
}
void update(int x,int v)
{
while(x<=n)
{
b[x]+=v;
x+=lowbit(x);
}
}
int query(int x)
{
int ans=0;
while(x)
{
ans+=b[x];
x-=lowbit(x);
}
return ans;
}
}bit;
struct node
{
int u,v,t;
bool operator<(const node&b)const
{
return dep[t]>dep[b.t];
}
}a[maxm];
int main()
{
while(~scanf("%d",&n))
{
n++;
bit.init(n);
for(int i=1;i<=n;i++)g[i].clear();
memset(p,0,sizeof(p));
for(int i=1;i<n;i++)
{
int u,v;
scanf("%d%d",&u,&v);
u++,v++;
g[u].push_back(v),g[v].push_back(u);
}
index=0;
dep[1]=0;
dfs(1,0);
scanf("%d",&m);
for(int i=0;i<m;i++)
{
scanf("%d%d",&a[i].u,&a[i].v);
a[i].u++,a[i].v++;
a[i].t=lca(a[i].u,a[i].v);
}
sort(a,a+m);
int ans=0;
for(int i=0;i<m;i++)
{
int u=a[i].u,v=a[i].v,t=a[i].t;
int temp=bit.query(L[u])+bit.query(L[v]);
if(!temp)
{
ans++;
bit.update(L[t],1),bit.update(R[t]+1,-1);
}
}
printf("%d\n",ans);
}
return 0;
}