答案必然是在直径上选取一段路径
如果不懂请自行考虑证明,需要注意的是证明过程中主要用到树的直径是最长链的性质
先考虑两个点都不在直径上,那距离最大值一定是直径的某一段
如果选择一个点在直径上,会使答案更优
两个点都在直径上的话,考虑直径是最长链的性质
最直接的做法就是 O(n^3) ,绝对的模拟一遍操作
略微有点思考后的做法就是 O(n^2) 贪心,为了使到其他点距离尽可能的小,选择恰好距离小于等于 s 的点对
这都是可以水过弱化版的
就下来考虑更优的做法
不难发现答案具有单调性,可以二分答案, O(n*log(Σedge_i.val))
做法是从直径两端 l 和 r 分别向直径中间跳,跳至距直径距离小于等于 mid 的最远点 x 和 y,这对点作为此次二分的路径
因为直径是最长链,所以 l 到 x 之间的子树到 x 的距离不会大于 dist(l, x) ,对于 r 和 y 同理
只需要再保证 dist(x, y) <= s 即可
二分的范围是 [直径上任意一点到其他点的最短距离,直径长度]
可能不太准但复杂度是可过的
继续深究性质,时间复杂度可优化到 O(n) (反正我是不会
设直径由点 u_1, u_2, u_3, ..., u_t 组成
f[u_i] 表示从点 u_i 出发,不经过直径上的其他节点, 所到的最远点的距离
那么以 u_i, u_j 为端点的路径的答案就是 max(max{f[u_k]}, dist(u_1, u_i), dist(u_j, u_t)) (i <= k <= j)
这样就可以单调队列来求了
还可以注意到的是,利用上边二分答案做法中证明的性质,刚才的式子就可以替换为 max(max{f[u_k]}, dist(u_1, u_i), dist(u_j, u_t)) (1 <= k <= t)
//因为多的那些是不比后边的两项大的
那么就有一项是固定的了
那么我们就可以枚举直径上的每个点,贪心的选到与它距离不超过 s 的最远的点,之后按上述式子判断即可
成功 O(n) 解决
upd:有一些小细节,在直径上取一段的时候,这并不是很好搞,需要把更新当前点最远距离的 from 记录下来,然后就可以愉快的从直径中选出想要的一段了
我什么时候能想到这么神仙的做法就好了啊
代码(二分版):
#include<algorithm>
#include<iostream>
#include<cstdlib>
#include<cstring>
#include<cctype>
#include<cstdio>
#include<queue>
using namespace std;
typedef long long ll;
const int MAXN = 300005;
struct EDGE{
int nxt, to, val;
EDGE(int NXT = 0, int TO = 0, int VAL = 0) {nxt = NXT; to = TO; val = VAL;}
}edge[MAXN << 1];
int n, s, p, q, totedge, maxpt, top ,maxdis;
int frm[MAXN], head[MAXN], stk[MAXN], dst[MAXN];
bool vis[MAXN], ind[MAXN];
inline int rd() {
register int x = 0;
register char c = getchar();
while(!isdigit(c)) c = getchar();
while(isdigit(c)) {
x = x * 10 + (c ^ 48);
c = getchar();
}
return x;
}
inline void add(int x, int y, int v) {
edge[++totedge] = EDGE(head[x], y, v);
head[x] = totedge;
return;
}
inline void bfs(int bgn) {
queue<int> q;
dst[bgn] = 0;
q.push(bgn);
while(!q.empty()) {
int x = q.front(); q.pop();
if(dst[x] > dst[maxpt]) maxpt = x;
for(int i = head[x]; i; i = edge[i].nxt) if(dst[edge[i].to] == 0x3f3f3f3f) {
int y = edge[i].to;
frm[y] = x;
if(ind[y]) dst[y] = dst[x];
else dst[y] = dst[x] + edge[i].val;
q.push(y);
}
}
return;
}
inline bool chk(int mid) {
int l = 1, r = top;
while(stk[1] - stk[l + 1] <= mid && l + 1 <= top) ++l;
while(stk[r - 1] <= mid && r - 1 >= 1 && r - 1 >= l) --r;
return (stk[l] - stk[r] <= s);
}
inline void hfs(int l, int r) {
int ans = 0, mid;
while(l <= r) {
mid = ((l + r) >> 1);
if(chk(mid)) r = mid - 1;
else l = mid + 1;
}
printf("%d\n", l);
return;
}
int main() {
n = rd(); s = rd();
for(int i = 1; i <= n; ++i) dst[i] = 0x3f3f3f3f;
register int xx, yy, vv;
for(int i = 1; i < n; ++i) {
xx = rd(); yy = rd(); vv = rd();
add(xx, yy, vv); add(yy, xx, vv);
}
bfs(1); p = maxpt; maxpt = 0;
for(int i = 1; i <= n; ++i) dst[i] = 0x3f3f3f3f;
bfs(p); q = maxpt;
maxdis = stk[++top] = dst[q];
int tmp = q; ind[tmp] = true;
while(tmp != p) {
stk[++top] = dst[frm[tmp]];
tmp = frm[tmp]; ind[tmp] = true;
}
maxpt = 0;
for(int i = 1; i <= n; ++i) dst[i] = 0x3f3f3f3f;
bfs(tmp);
hfs(dst[maxpt], maxdis);
return 0;
}
代码(O(n)版):
#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdlib>
#include<cstdio>
#include<cctype>
#include<queue>
using namespace std;
const int MAXN = 500005;
struct EDGE{
int nxt, to, val;
EDGE(int NXT = 0, int TO = 0, int VAL = 0) {nxt = NXT; to = TO; val = VAL;}
}edge[MAXN << 1];
int n, s, totedge, mxpt, p, q, top, lmt, ans = 0x7fffffff;
int head[MAXN], dst[MAXN], frm[MAXN];
pair<int,int> stk[MAXN];
bool ind[MAXN];
queue<int> que;
inline int rd() {
register int x = 0;
register char c = getchar();
while(!isdigit(c)) c = getchar();
while(isdigit(c)) {
x = x * 10 + (c ^ 48);
c = getchar();
}
return x;
}
inline void add(int x, int y, int v) {
edge[++totedge] = EDGE(head[x], y, v);
head[x] = totedge;
return;
}
inline void bfs(int bgn) {
for(int i = 1; i <= n; ++i) dst[i] = 0x7fffffff;
dst[bgn] = 0;
que.push(bgn);
while(!que.empty()) {
int x = que.front(); que.pop(); if(dst[x] > dst[mxpt]) mxpt = x;
for(int i = head[x]; i; i = edge[i].nxt) if(dst[edge[i].to] == 0x7fffffff) {
int y = edge[i].to;
frm[y] = x;
if(ind[y]) dst[y] = dst[x];
else dst[y] = dst[x] + edge[i].val;
que.push(y);
}
}
return;
}
void dfs(int x, int fa) {
dst[x] = 0;
for(int i = head[x]; i; i = edge[i].nxt) if(edge[i].to != fa) {
int y = edge[i].to;
dfs(y, x);
if(ind[y]) dst[x] = max(dst[x], dst[y]);
else dst[x] = max(dst[x], dst[y] + edge[i].val);
}
return;
}
int main() {
n = rd(); s = rd();
register int xx, yy, vv;
for(int i = 1; i < n; ++i) {
xx = rd(); yy = rd(); vv = rd();
add(xx, yy, vv); add(yy, xx, vv);
}
bfs(1); p = mxpt; mxpt = 0;
bfs(p); q = mxpt;
stk[++top] = make_pair(mxpt, dst[mxpt]); ind[mxpt] = true;
while(mxpt != p) {
mxpt = frm[mxpt];
stk[++top] = make_pair(mxpt, dst[mxpt]);
ind[mxpt] = true;
}
dfs(p, 0);
for(int i = 1; i <= top; ++i) lmt = max(lmt, dst[stk[i].first]);
int lptr, rptr = 0, cur;
for(lptr = 1; lptr <= top; ++lptr) {
cur = lmt;
while(stk[lptr].second - stk[rptr + 1].second <= s && rptr + 1 <= top) ++rptr;
cur = max(cur, max(stk[1].second - stk[lptr].second, stk[rptr].second));
ans = min(ans, cur);
}
printf("%d\n", ans);
return 0;
}