题目
n(n<=2e5)个点的树,q(q<=2e5)次询问,
每次询问会给出一个点x和k个要删掉的点,
在树上删掉这k个点和k个点相连的边后,
询问在剩下的若干个连通块中,x能到的最远的点的距离
各个询问是独立的,也就是说第i轮删掉的点会在第i+1轮加回来
sumk不超过2e5
思路来源
官方题解
树的直径的性质
1. 当合并两个区间(即合并两棵树)时,
新的树的直径的两个端点,一定是在原来两棵树直径的四个点里选两个点
2. x所在的连通块能到的最远点,一定是x这个连通块的直径的两个端点中的一个
dfs序的性质
a的dfs序对应[in[a],out[a]],b的区间对应[in[b],out[b]]
若in[a]<in[b]<out[a],说明b在a的子树里,一定有in[a]<in[b]<=out[b]<=out[a]
题解
懒得写代码了,直接抄的官方题解的代码,不过确实看懂了
首先根据dfs序或者欧拉序(欧拉游览树ETT),dfs序一开始是连续的
删掉k个点后,把dfs序切成若干个区间,区间数是大致是k-2k级别的,
剩下的区间都是x的可达区间,每个区间对应一个连续的dfs序
当删掉u时,根据u和x的关系,有两种情况:
1. lca(u,x)=u,即u是x的祖先,那么由于u不可达了,
记x到u的路径上u的直连儿子是v,那么相当于只保留下来v这棵子树内可以到达,
也就是ban掉[0,in[v])、[out[v],n)
2. u和x没有祖先关系,那么由于u不可达,
所以u的这棵子树不可达了,
ban掉[in[u],out[u]]
根据k个点,获取到k个ban掉的区间时,
根据上面提到的dfs序的性质,
dfs序只会存在区间嵌套(li<lj<rj<ri)的情况,
不会存在两个dfs区间相交一部分(li<lj<ri<rj)的情况
按左端点增序,左端点相同右端点降序排序,遍历,
手动去除掉被套在内层的区间,只保留外层的区间
这样得到的若干个区间,就是互不相交的若干个要ban掉的dfs序区间,
其补集,就是合法的区间,均与x连通,
利用上文提到的树的直径的性质,统一merge合法区间的直径,
具体来说,线段树上每一个区间维护这个dfs序区间的直径的两个点,
求合法区间[l,r]的直径时,先在线段树上做一个merge,
再对若干个合法区间做一个merge,
再和询问点x做一个merge,这样得到了x连通块的直径的两个端点
x能到的最远点一定是直径两个点中的一个,分别询问距离取max即可
代码
#include <bits/stdc++.h>
#define sz(x) ((int)(x.size()))
#define all(x) x.begin(), x.end()
#define pb push_back
#define eb emplace_back
const int MX = 2e5 +10, int_max = 0x3f3f3f3f;
using namespace std;
//lca template start
vector<int> dep, sz, par, head, tin, tout, tour;
vector<vector<int>> adj;
int n, ind, q;
void dfs(int x, int p){
sz[x] = 1;
dep[x] = dep[p] + 1;
par[x] = p;
for(auto &i : adj[x]){
if(i == p) continue;
dfs(i, x);
sz[x] += sz[i];
if(adj[x][0] == p || sz[i] > sz[adj[x][0]]) swap(adj[x][0], i);
}
if(p != 0) adj[x].erase(find(all(adj[x]), p));
}
void dfs2(int x, int p){
tour[ind] = x;
tin[x] = ind++;
for(auto &i : adj[x]){
if(i == p) continue;
head[i] = (i == adj[x][0] ? head[x] : i);
dfs2(i, x);
}
tout[x] = ind;
}
int k_up(int u, int k){
if(dep[u] <= k) return -1;
while(k > dep[u] - dep[head[u]]){
k -= dep[u] - dep[head[u]] + 1;
u = par[head[u]];
}
return tour[tin[u] - k];
}
int lca(int a, int b){
while(head[a] != head[b]){
if(dep[head[a]] > dep[head[b]]) swap(a, b);
b = par[head[b]];
}
if(dep[a] > dep[b]) swap(a, b);
return a;
}
int dist(int a, int b){
return dep[a] + dep[b] - 2*dep[lca(a, b)];
}
//lca template end
//segtree template start
#define ff first
#define ss second
int dist(pair<int, int> a){
return dist(a.ff, a.ss);
}
pair<int, int> merge(pair<int, int> a, pair<int, int> b){
auto p = max(pair(dist(a), a), pair(dist(b), b));
for(auto x : {a.ff, a.ss}){
for(auto y : {b.ff, b.ss}){
if(x == 0 || y == 0) continue;
p = max(p, pair(dist(pair(x, y)), pair(x, y)));
}
}
return p.ss;
}
pair<int, int> mx[MX*4];
#define LC(k) (2*k)
#define RC(k) (2*k +1)
void update(int p, int k, int L, int R){
if(L + 1 == R){
mx[k] = {tour[p], tour[p]};
return ;
}
int mid = (L + R)/2;
if(p < mid) update(p, LC(k), L, mid);
else update(p, RC(k), mid, R);
mx[k] = merge(mx[LC(k)], mx[RC(k)]);
}
void query(int qL, int qR, vector<pair<int, int>>& ret, int k, int L, int R){
if(qR <= L || R <= qL) return ;
if(qL <= L && R <= qR){
ret.push_back(mx[k]);
return ;
}
int mid = (L + R)/2;
query(qL, qR, ret, LC(k), L, mid);
query(qL, qR, ret, RC(k), mid, R);
}
//segtree template end
int query(vector<int> arr, int x){
vector<pair<int, int>> banned, ret;
for(int u : arr){
if(lca(u, x) == u){
u = k_up(x, dep[x] - dep[u] - 1);
banned.push_back({0, tin[u]});
banned.push_back({tout[u], n});
}else{
banned.push_back({tin[u], tout[u]});
}
}
sort(all(banned), [&](pair<int, int> a, pair<int, int> b){
return (a.ff < b.ff) || (a.ff == b.ff && a.ss > b.ss);
});
vector<pair<int, int>> tbanned; //remove nested intervals
int mx = 0;
for(auto [a, b] : banned){
if(b <= mx) continue;
else if(a != b){
tbanned.pb({a, b});
mx = b;
}
}
banned = tbanned;
int tim = 0;
for(auto [a, b] : banned){
if(tim < a)
query(tim, a, ret, 1, 0, n);
tim = b;
}
if(tim < n)
query(tim, n, ret, 1, 0, n);
pair<int, int> dia = pair(x, x);
for(auto p : ret) dia = merge(dia, p);
int ans = max(dist(x, dia.ff), dist(x, dia.ss));
return ans;
}
void solve(){
cin >> n >> q;
dep = sz = par = head = tin = tout = tour = vector<int>(n+1, 0);
adj = vector<vector<int>>(n+1);
for(int i = 1; i<n; i++){
int a, b;
cin >> a >> b;
adj[a].push_back(b);
adj[b].push_back(a);
}
dfs(1, 0);
head[1] = 1;
dfs2(1, 0);
for(int i = 1; i<=n; i++){
update(tin[i], 1, 0, n);
}
for(int i = 1; i<=q; i++){
int x, k;
cin >> x >> k;
vector<int> arr(k);
for(int& y : arr) cin >> y;
cout << query(arr, x) << "\n";
}
}
signed main(){
cin.tie(0) -> sync_with_stdio(0);
int T = 1;
//cin >> T;
for(int i = 1; i<=T; i++){
//cout << "Case #" << i << ": ";
solve();
}
return 0;
}