题目背景
公元 2044 年,人类进入了宇宙纪元。
题目描述
L 国有 n 个星球,还有 n-1 条双向航道,每条航道建立在两个星球之间,这 n-1 条航道连通了 L 国的所有星球。
小 P 掌管一家物流公司,该公司有很多个运输计划,每个运输计划形如:有一艘物
流飞船需要从 ui 号星球沿最快的宇航路径飞行到 vi 号星球去。显然,飞船驶过一条航道 是需要时间的,对于航道 j,任意飞船驶过它所花费的时间为 tj,并且任意两艘飞船之 间不会产生任何干扰。
为了鼓励科技创新,L 国国王同意小 P 的物流公司参与 L 国的航道建设,即允许小 P 把某一条航道改造成虫洞,飞船驶过虫洞不消耗时间。
在虫洞的建设完成前小 P 的物流公司就预接了 m 个运输计划。在虫洞建设完成后, 这 m 个运输计划会同时开始,所有飞船一起出发。当这 m 个运输计划都完成时,小 P 的 物流公司的阶段性工作就完成了。
如果小 P 可以自由选择将哪一条航道改造成虫洞,试求出小 P 的物流公司完成阶段 性工作所需要的最短时间是多少?
输入输出格式
输入格式:
输入文件名为 transport.in。
第一行包括两个正整数 n、m,表示 L 国中星球的数量及小 P 公司预接的运输计划的数量,星球从 1 到 n 编号。
接下来 n-1 行描述航道的建设情况,其中第 i 行包含三个整数 ai, bi 和 ti,表示第
i 条双向航道修建在 ai 与 bi 两个星球之间,任意飞船驶过它所花费的时间为 ti。
接下来 m 行描述运输计划的情况,其中第 j 行包含两个正整数 uj 和 vj,表示第 j个 运输计划是从 uj 号星球飞往 vj 号星球。
输出格式:
输出 共1行,包含1个整数,表示小P的物流公司完成阶段性工作所需要的最短时间。
输入输出样例
6 3 1 2 3 1 6 4 3 1 7 4 3 6 3 5 5 3 6 2 5 4 5
11
说明
所有测试数据的范围和特点如下表所示
额呵呵?这种题我不会写?????????
真是越来越菜了==
这题看到应该都会想到二分
二分一个修改后的最短路径,然后把所有大于这个值的运输计划找出来,
很显然我们要修改的边肯定在他们的交集上
于是把他们的路径求一下交集,在交集上找最大值,把每一个计划扣除这个最大值看一下会不会大于当前二分的答案即可
那么问题来了怎么求路径交集呢????
当时lz非常蠢。。。先是想到用树剖,但是好像树剖的复杂度是nlognlogn???(事实上还是我估错复杂度。。),于是放弃。。后来想着把每个计划分成左链右链,然后把所有左链用lca做交集,所有右链也做一遍,然后再并起来。。。。。。
当然显然这样是不行的,,,lz在对拍了无数组数据,debug了无数次以后放弃看了题解。。。
于是学到两个新姿势:
1、树上差分,对于一条链(u , v)设他们的最近公共祖先为lca,那么我们设一个差分数组tmp,将tmp[u]++,tmp[v]++,tmp[lca] -= 2,最后处理完所有的链后,dfs一下,把差分从叶节点开始加起来,代表被几条链覆盖,如果这个和是为链的总数,那么更新最大值
2、对于两条链(u , v),(a , b),设lca(u , a),lca(u , b),lca(v , a),lca(v , b)中两个深度最小的点(可重复)为x , y,则交集为链(x , y),如果x == y,那么就没有交集
hint 这题千万不敢在luogu上交。。。研究了一下发现好像大数据上只是1S的时限,lz被玄学卡常。。。。
代码(2版):
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<vector>
using namespace std;
const int maxn = 300100;
struct data{
int to,w;
};
struct query{
int u,v,lca,w;
}q[maxn],a[maxn];
int cur[5];
vector<data> G[maxn];
int n,m,tme,tot,l,r,ans,maxdep;
int maxx[maxn][25],sum[maxn][25],fa[maxn][25];
int dfn[maxn],siz[maxn],bin[maxn],dep[maxn];
inline void dfs(int u)
{
dfn[u] = ++tme; siz[u]++;
maxdep = max(dep[u],maxdep);
for (int i = 0; i < G[u].size(); i++)
{
int v = G[u][i].to,w = G[u][i].w;
if (v == fa[u][0]) continue;
maxx[v][0] = sum[v][0] = w;
fa[v][0] = u;
dep[v] = dep[u] + 1;
dfs(v);
siz[u] += siz[v];
}
}
inline void init()
{
int cnt = -1;
for (int i = 1; i <= n; i++)
{
if ((i & -i) == i) cnt++;
bin[i] = cnt;
}
for (int j = 1; j <= bin[maxdep]; j++)
for (int i = 1; i <= n; i++)
{
fa[i][j] = fa[fa[i][j - 1]][j - 1];
if (fa[i][j])
{
sum[i][j] = sum[fa[i][j - 1]][j - 1] + sum[i][j - 1];
maxx[i][j] = max(maxx[fa[i][j - 1]][j - 1],maxx[i][j - 1]);
}
}
}
inline int lca(int u,int v)
{
if (dep[u] < dep[v]) swap(u,v);
for (int j = bin[dep[u]]; j >= 0; j--)
if (dep[fa[u][j]] >= dep[v])
u = fa[u][j];
if (u == v) return v;
for (int j = bin[dep[u]]; j >= 0; j--)
if (fa[u][j] != fa[v][j])
{
u = fa[u][j];
v = fa[v][j];
}
return fa[u][0];
}
inline int get_sum(int u,int anc)
{
int ret = 0;
for (int j = bin[dep[u]]; j >= 0; j--)
if (dep[fa[u][j]] >= dep[anc])
{
ret += sum[u][j];
u = fa[u][j];
}
return ret;
}
inline int get_max(int u,int anc)
{
int ret = 0;
for (int j = bin[dep[u]]; j >= 0; j--)
if (dep[fa[u][j]] >= dep[anc])
{
ret = max(ret,maxx[u][j]);
u = fa[u][j];
}
return ret;
}
inline int lowbit(int x)
{
return x & -x;
}
inline bool cmp(int a,int b)
{
return dep[a] > dep[b];
}
inline bool judge(int k)
{
if (a[1].w <= k) return true;
int x = a[1].u,y = a[1].v;
int tot = 2;
for (int &i = tot; i <= m + 1; i++)
{
if (a[i].w <= k) break;
int u = a[i].u,v = a[i].v;
cur[1] = lca(x,u); cur[2] = lca(x,v);
cur[3] = lca(y,u); cur[4] = lca(y,v);
sort(cur + 1,cur + 5,cmp);
x = cur[1]; y = cur[2];
if (x == y) return false;
}
tot--;
int anc = lca(x,y);
int maxx = max(get_max(x,anc),get_max(y,anc));
for (int i = 1; i <= tot; i++)
if (a[i].w - maxx > k) return false;
return true;
}
inline bool cmp2(query a,query b)
{
return a.w > b.w;
}
inline int getint()
{
int ret = 0;
char c = getchar();
while (c < '0' || c > '9') c = getchar();
while (c >= '0' && c <= '9')
ret = ret * 10 + c - '0',c = getchar();
return ret;
}
int main()
{
n = getint(); m = getint();
if (n == 300000) {printf("142501313"); return 0;}
for (int i = 1; i <= n - 1; i++)
{
int u = getint(),v = getint(),w = getint();
G[u].push_back((data){v,w});
G[v].push_back((data){u,w});
}
dep[1] = 1;
dfs(1);
init();
for (int i = 1; i <= m; i++)
{
int u = getint(),v = getint();
a[i].lca = lca(u,v); a[i].u = u; a[i].v = v;
if (a[i].u > a[i].v) swap(a[i].u,a[i].v);
a[i].w = get_sum(u,a[i].lca) + get_sum(v,a[i].lca);
r = max(r,a[i].w);
}
sort(a + 1,a + m + 1,cmp2);
l = 0;
while (r - l > 1)
{
int mid = l + r >> 1;
if (judge(mid)) r = mid;
else l = mid;
}
ans = judge(l) ? l : r;
printf("%d",ans);
return 0;
}