题目:https://vjudge.net/problem/Gym-102040F
题意:求树上k条路径经过的公共点的个数
思路:最后的公共点一定是一段连续区间,经过的次数一定是 k 次。树链剖分,在线段树上查询数为k的个数,记录区间最大值和最小值,当最大值和最小值相等且都等于k时,整个区间有贡献。由于k比较小,没必要每次清空整棵树,记录 k 条路径,之后减去就行。
代码:
#include <bits/stdc++.h>
#define ls rt<<1
#define rs rt<<1|1
using namespace std;
const int maxn = 1e4+5;
vector<int>E[maxn];
int Index, d[maxn], top[maxn], son[maxn], f[maxn], siz[maxn], id[maxn], rk[maxn];
void dfs1(int u, int fa){
siz[u] = 1; son[u] = 0;
for(auto it : E[u]){
int v = it;
if(v == fa) continue;
d[v] = d[u] + 1; f[v] = u;
dfs1(v, u);
siz[u] += siz[v];
if(siz[v] > siz[son[u]]) son[u] = v;
}
}
void dfs2(int u, int rt){
id[u] = ++Index; rk[Index] = u; top[u] = rt;
if(son[u]) dfs2(son[u], rt);
for(auto it : E[u]){
if(it == f[u] || it == son[u]) continue;
dfs2(it, it);
}
}
int n, q, k, Max[maxn<<2], Min[maxn<<2], tag[maxn<<2];
void pushup(int rt){
Max[rt] = max(Max[ls], Max[rs]);
Min[rt] = min(Min[ls], Min[rs]);
}
void pushdown(int rt){
if(tag[rt]){
tag[ls] += tag[rt]; tag[rs] += tag[rt];
Max[ls] += tag[rt]; Max[rs] += tag[rt];
Min[ls] += tag[rt]; Min[rs] += tag[rt];
tag[rt] = 0;
}
}
void update(int L, int R, int val, int f=1, int rt=1, int l=1, int r=n){
if(L <= l && R >= r){
Min[rt] += val; Max[rt] += val;
tag[rt] += val;
return ;
} int mid = l+r >> 1; pushdown(rt);
if(L <= mid) update(L, R, val, f, ls, l, mid);
if(R > mid) update(L, R, val, f, rs, mid+1, r);
pushup(rt);
}
int query(int L, int R, int rt=1, int l=1, int r=n){
if(L <= l && R >= r && Min[rt] == k && Max[rt] == k) return r-l+1;
if(R<l || L>r || Max[rt] != k) return 0;
pushdown(rt); int mid = l+r >> 1;
return query(L, R, ls, l, mid)+query(L, R, rs, mid+1, r);
}
void Upd(int u, int v, int val){
while(top[u] != top[v]){
if(d[top[u]] < d[top[v]]) swap(u, v);
update(id[top[u]], id[u], val);
u = f[top[u]];
}
if(d[u] < d[v]) swap(u, v);
update(id[v], id[u], val);
}
int Qry(int u, int v){
int res = 0;
while(top[u] != top[v]){
if(d[top[u]] < d[top[v]]) swap(u, v);
res += query(id[top[u]], id[u]);
u = f[top[u]];
}
if(d[u] < d[v]) swap(u, v);
res += query(id[v], id[u]);
return res;
}
int t, Case;
int main()
{
scanf("%d", &t);
while(t--){
scanf("%d", &n);
Index = 0;
for(int i=1; i<=n; i++) E[i].clear();
int u, v;
for(int i=1; i<n; i++) {
scanf("%d%d", &u, &v);
E[u].push_back(v); E[v].push_back(u);
}
d[1] = 1; f[1] = 1; dfs1(1, 0); dfs2(1, 1);
printf("Case %d:\n", ++Case);
scanf("%d", &q);
pair<int, int> p[55];
while(q--){
scanf("%d", &k);
update(1, n, 0);
for(int i=1; i<=k; i++){
scanf("%d%d", &u, &v);
p[i] = make_pair(u, v);
Upd(u, v, 1);
if(i==k) printf("%d\n", Qry(u, v));
}
for(int i=1; i<=k; i++) Upd(p[i].first, p[i].second, -1);
}
}
}