B - Red Black Tree
二分加 lca
关键问题在于判断过程 , 当我们选中一个答案时,我们先看看集合内有哪些数比它大,并看看能不能把这些数进行修改。
下列情况无法取该答案:
1.取到公共祖先的时候更新后的距离 仍比 选中的答案来的要大。(前提:节点与公共祖先之间没有其他红色节点)
2, 节点与祖先之间已经有红色节点了(或者这个祖先本身就是红色节点),这个时候无法更新距离,因为题目要求每个点的距离是最近红色祖先的距离,这个时候便无法更新。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int,int> PII;
const int N = 2e6+10, mod = 1e9+7 , F = mod+1>>1, M = 1e8+10;
int h[N], ne[N<<1], e[N<<1], edge[N<<1], idx;
int fa[N][20], deep[N], a[N];
ll mi[N], de[N];
bool c[N];
int n, m, qq,t;
void add (int a,int b,int c)
{
e[idx]=b;
ne[idx]=h[a];
edge[idx] = c;
h[a]=idx++;
}
void dfs(int u, int f)
{
deep[u] = deep[f] + 1;
for(int i = 1;i < 20;i ++) fa[u][i] = fa[fa[u][i-1]][i-1];
for(int i = h[u];~i ;i = ne[i])
{
if(e[i] == f)continue;
fa[e[i]][0] = u;
mi[e[i]] = c[e[i]] ? 0 : mi[u] + edge [i];
de[e[i]] = de[u] + edge[i];
dfs(e[i], u);
}
}
int lca(int u, int v)
{
if(deep[u]<deep[v])swap(u, v);
int x = deep[u] - deep[v];
for(int i = 0;i < 20;i ++)if(x>>i&1)u = fa[u][i];
if(u == v)return v;
for(int i = 19;i >= 0;i --)
if(fa[u][i] != fa[v][i])
u = fa[u][i], v = fa[v][i];
return fa[u][0];
}
bool cmp(int x,int y){return mi[x] < mi[y];}
bool in(int u, int v){return !c[u] && mi[v]-mi[u] == de[v] - de[u];}
bool check(ll mid,int num)
{
int now = 0;
for(int i = 1;i <= num;i ++)
if(mi[a[i]] > mid) now = now ? lca(now, a[i]) : a[i];
for(int i = 1;i <= num;i ++)
if(mi[a[i]] > mid)
if(!in(now, a[i]) || mi[a[i]] - mi[now] > mid)
return 0;
return 1;
}
int main()
{
std::ios::sync_with_stdio(false);
cin >> t;
while(t--)
{
cin >> n >> m >> qq;
for(int i = 1;i <= n;i ++) h[i] = -1; idx = 0;
for(int i = 1;i <= m;i ++)
{
int x;
cin >> x;
c[x] = 1;
}
for(int i = 2;i <= n;i ++)
{
int u, v, w;
cin >> u >> v >> w;
add(u, v, w); add(v, u, w);
}
dfs(1, 0);
// cout <<"azazaz" << mi[4] << endl;
for(int i = 1 ; i <= qq ; i++)
{
int k;
cin >> k;
for(int j = 1 ; j <= k ; j++) cin >> a[j];
sort(a + 1,a + k + 1,cmp);
ll l = 0 , r = 1e17;
while(l < r)
{
ll mid = (l + r) >> 1;
if(check(mid,k)) r = mid;
else l = mid + 1;
}
printf("%lld\n", l);
}
for(int i = 1; i <= n ; i++) c[i] = 0;
}
return 0;
}