思路:
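考虑用线段树按点的编号维护区间内所有被选中点的最远点对。关键性质(树上直径合并):若点集 S、T 的最远点对分别为 (a, b)、(c, d),则 S∪T 的最远点对的两个端点一定在 {a, b, c, d} 中取得。因此合并时只需对两个子区间的最远点对端点两两求距离(借助 LCA),取最大者作为当前区间的最远点对;单点修改后沿路径向上逐层合并即可。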
代码:
#include<iostream>
#include<cstdio>
#include<vector>
// `register` was removed from the language in C++17 (clang rejects it outright)
// and was only ever an ignored optimisation hint. The macro name is kept so every
// existing `reint` declaration still compiles, but it now expands to plain int.
#define reint int
using namespace std;
// Array bound for vertex ids (1-based). Assumes n <= 1e5 — TODO confirm against
// the problem limits.
const int MAXN = 1e5 + 10;
int n, q;               // number of vertices / number of operations
int c[10];              // scratch buffer used by mix_: c[0] = count, c[1..] = candidate endpoints
int dep[MAXN];          // depth of each vertex; the root (vertex 1) has depth 1, dep[0] = 0
int f[MAXN][21];        // binary lifting: f[v][j] = 2^j-th ancestor of v (0 above the root)
vector<int> b[MAXN];    // adjacency lists of the tree
// Segment-tree node over a range of vertex ids [l, r].
struct node {
int l, r;           // covered id range (written by build, not otherwise read)
int flag1, flag2;   // 1 when v1 / v2 hold a valid selected endpoint
int v1, v2;         // endpoints of the farthest selected pair inside [l, r]
int ans;            // distance between v1 and v2; -1 when no vertex in [l, r] is selected
}tr[MAXN << 2];
inline void dfs(reint x, reint fa) {
dep[x] = dep[fa] + 1;
f[x][0] = fa;
for(reint i = 0; i < b[x].size(); i ++) {
reint y = b[x][i];
if(y != fa) dfs(y, x);
}
}
// Fill the binary-lifting table for vertices 1..n from the parent column f[v][0]:
// the 2^lvl-th ancestor is the 2^(lvl-1)-th ancestor of the 2^(lvl-1)-th ancestor.
// Row 0 is never written, so ancestors "above the root" stay 0.
inline void rmq() {
    for(int lvl = 1; lvl <= 20; ++lvl) {
        for(int v = 1; v <= n; ++v) {
            const int halfway = f[v][lvl - 1];
            f[v][lvl] = f[halfway][lvl - 1];
        }
    }
}
inline int lca(reint x, reint y) {
if(dep[x] > dep[y]) swap(x, y);
reint j = 20, k = dep[y] - dep[x], t = 1 << 20;
while(k) {
if(k >= t) k -= t, y = f[y][j];
j --;
t = 1 << j;
}
if(x == y) return x;
for(reint i = 20; i >= 0; i --)
if(f[x][i] != f[y][i]) x = f[x][i], y = f[y][i];
return f[x][0];
}
inline void mix_(reint k, reint k1, reint k2) {
c[0] = 0;
if(tr[k1].v1 != 0 && tr[k1].flag1) c[++c[0]] = tr[k1].v1;
if(tr[k1].v2 != 0 && tr[k1].flag2) c[++c[0]] = tr[k1].v2;
if(tr[k2].v1 != 0 && tr[k2].flag1) c[++c[0]] = tr[k2].v1;
if(tr[k2].v2 != 0 && tr[k2].flag2) c[++c[0]] = tr[k2].v2;
tr[k].ans = -1;
for(reint i = 1; i <= c[0]; i ++)
for(reint j = 1; j <= i; j ++)
{
reint g = lca(c[i], c[j]);
reint len = dep[c[i]] + dep[c[j]] - dep[g] * 2;
if(len > tr[k].ans) {
tr[k].ans = len, tr[k].v1 = c[i], tr[k].v2 = c[j];
tr[k].flag1 = tr[k].flag2 = 1;
}
}
if(tr[k].ans == -1) tr[k].flag1 = tr[k].flag2 = 0;
}
inline void build(reint k, reint l, reint r) {
tr[k].l = l, tr[k].r = r;
if(l == r) {
tr[k].flag1 = tr[k].flag2 = 1;
tr[k].v1 = tr[k].v2 = l;
tr[k].ans = 0;
return ;
}
reint mid = l + r >> 1;
build(k << 1, l, mid);
build(k << 1 | 1, mid + 1, r);
mix_(k, k << 1, k << 1 | 1);
}
inline void change_(reint k, reint l, reint r, reint x) {
if(l == x && r == x) {
tr[k].flag1 ^= 1;
tr[k].flag2 ^= 1;
tr[k].ans = -1;
return ;
}
reint mid = l + r >> 1;
if(x <= mid) change_(k << 1, l, mid, x);
else change_(k << 1 | 1, mid + 1, r, x);
mix_(k, k << 1, k << 1 | 1);
}
int main() {
scanf("%d%d", &n, &q);
for(reint i = 1; i < n; i ++) {
reint x, y;
scanf("%d%d", &x, &y);
b[x].push_back(y);
b[y].push_back(x);
}
dfs(1, 0);
rmq();
build(1, 1, n);
while(q --) {
string c;
cin>>c;
if(c == "C") {
reint x;
scanf("%d", &x);
change_(1, 1, n, x);
}
else
printf("%d\n", tr[1].ans);
}
return 0;
}