题面
题解
因为在写树剖的时候写崩了,所以这里提供一个
l
c
a
lca
lca 的做法。
虽然有人说
l
c
a
lca
lca 会被卡,但是我觉得用
t
a
r
j
a
n
tarjan
tarjan 写出来复杂度也不假。(复杂度分析在下文)
前言
读完题不难发现,其实我们要清空权值的边就是在清空它之后,能把最长的路径降到最小。最大的最小我们想到了什么?自然而然就能稳一波二分答案。所以就可以我们二分在清空后最长的路径的长度。
二分答案的具体步骤
于是现在我们就需要写一个 c h e c k check check 函数来判断答案是否可行。
- 如果答案合法的话,我们要清空的这条边一定满足在所有路径长度大于我们二分答案的当前值的路径上。
- 所以我们发现,其实对于路径长度小于我们二分答案的当前值的路径,对于当前答案的判定其实是不起作用的。
- 因此我们现在要解决的就是对于第一点的判定。
- 不难发现,在保证要清空的边被所有符合条件的路径经过的情况下,找出它的最大长度。如果在所有符合条件的路径中,最长的路径长度减去我们能清空的最长路径长度小于我们当前二分答案的值的话,那么这个答案就是合法的。
- 那么现在问题就只剩下怎么求能清空的边的最大长度了。
- 因此我们用到了树上差分。我们可以记录一下每个点到它的父亲的这条边被经过的次数,可以证明对于非根节点的的节点,这样的边有且只有一条。关于树上差分的细节这里就不展开了。步骤是对于每个符合条件的路径,在初始化时,将它的端点节点的次数 + 1 +1 +1,将它的 l c a lca lca 的次数 − 1 -1 −1。初始化完后,跑一遍深度优先遍历,将子节点的次数加到父节点的次数中就能得到每条边被经过的次数。
- 设符合条件的路径的总条数为 n u m num num ,那么对于次数等于 n u m num num 的边自然就是被所有满足条件的路径给经过的边。在这些边中找出最长的即可。
- 上述所有要用到的附加信息都能在初始化跑 t a r j a n tarjan tarjan 时跑出来。
关于时间复杂度
t
a
r
j
a
n
tarjan
tarjan 一次花
O
(
n
+
m
)
O(n+m)
O(n+m) 每次
c
h
e
c
k
check
check 先
f
o
r
for
for 一个
O
(
m
)
O(m)
O(m),再
d
f
s
dfs
dfs 一个
O
(
n
)
O(n)
O(n),合起来是
O
(
n
+
m
)
O(n + m)
O(n+m),
二分是一个
O
(
l
o
g
m
a
x
l
e
n
)
O(log\ maxlen)
O(log maxlen),最大是
3
e
8
3e8
3e8。
总时间复杂度是
O
(
(
n
+
m
)
l
o
g
m
a
x
l
e
n
)
O((n+m)log\ maxlen)
O((n+m)log maxlen),而且加上这道题时限是
2
s
2s
2s,即使输入量有点小大,但是应该还是能卡过去的。
代码
#include<cstdio>
#include<vector>
#include<iostream>
#include<cstring>
using std::min;
const int N = 3e5 + 5;
struct edge {
int next,to,w;
}a[N << 1];
struct lenth {
int x,y,len,lca;
}len[N];
int head[N],fa[N],vis[N],dis[N],val[N],n,m,maxn = 0,a_size = 1;
inline void add(int u,int v,int w) {
a[++a_size] = (edge){head[u],v,w};
head[u] = a_size;
a[++a_size] = (edge){head[v],u,w};
head[v] = a_size;
}
std::vector<int> query[N],query_id[N];
inline void add_q(int x,int y,int id) {
query[x].push_back(y),query_id[x].push_back(id);
query[y].push_back(x),query_id[y].push_back(id);
}
int Find(int x) {
return fa[x] == x ? x : fa[x] = Find(fa[x]);
}
void tarjan(int x) {
vis[x] = 1;
for(int i = head[x]; i; i = a[i].next) {
int y = a[i].to;
if(vis[y]) continue;
dis[y] = dis[x] + a[i].w;
tarjan(y);
val[y] = a[i].w;
fa[y] = x;
}
for(int i = 0; i < query[x].size(); i++) {
int y = query[x][i],id = query_id[x][i];
if(vis[y] == 2) {
len[id].lca = Find(y);
len[id].len = min(len[id].len,dis[x] + dis[y] - 2 * dis[len[id].lca]);
if(len[id].len > maxn) maxn = len[id].len;
}
}
vis[x] = 2;
}
int cnt[N],num,res;
void dfs(int x,int pre) {
for(int i = head[x]; i; i = a[i].next) {
int y = a[i].to;
if(y == pre) continue;
dfs(y,x);
cnt[x] += cnt[y];
}
if(cnt[x] == num && val[x] > res)
res = val[x];
}
bool check(int x) {
memset(cnt,0,sizeof(cnt));
num = res = 0;
for(int i = 1; i <= m; i++)
if(len[i].len > x) {
cnt[len[i].x]++;
cnt[len[i].y]++;
cnt[len[i].lca] -= 2;
num++;
}
dfs(1,0);
if(maxn - res > x) return false;
return true;
}
inline int read() {
int x = 0,flag = 1;
char ch = getchar();
while(ch < '0' || ch > '9'){if(ch == '-')flag = -1;ch = getchar();}
while(ch >='0' && ch <='9'){x = (x << 3) + (x << 1) + ch - 48;ch = getchar();}
return x * flag;
}
int main() {
n = read(),m = read();
for(int i = 1; i < n; i++) {
int u = read(),v = read(),w = read();
add(u,v,w);
}
for(int i = 1; i <= m; i++) {
len[i].x = read();
len[i].y = read();
if(len[i].x == len[i].y) len[i].len = 0;
else {
add_q(len[i].x,len[i].y,i);
len[i].len = 1 << 30;
}
}
for(int i = 1; i <= n; i++) fa[i] = i;
tarjan(1);
int l = 0,r = maxn,ans;
while(l <= r) {
int mid = (l + r) >> 1;
if(check(mid)) {
r = mid - 1;
ans = mid;
}
else l = mid + 1;
}
printf("%d\n",ans);
return 0;
}