题解:
将u -> v的有向路径拆成 u -> lca 和 lca -> v的有向路径,那么答案一定是三种情况之一:
- u -> lca路径上买,u -> lca路径上卖
- u -> lca路径上买,lca -> v路径上卖
- lca -> v路径上买,lca -> v路径上卖
利用树链剖分将u -> lca 和 lca -> v的路径拆成一段一段的链,现在考虑怎样合并两条链
假设从上到下的路径 a -> b -> c,那么a -> c有向路径的答案就一定是 max{a -> b的答案,b -> c的答案,b -> c的最大值减去a -> b路径上最小值},也就是:
a
n
s
a
−
>
c
=
m
a
x
(
a
n
s
a
−
>
b
,
a
n
s
b
−
>
c
,
m
a
x
v
a
l
b
−
>
c
−
m
i
n
v
a
l
a
−
>
b
)
ans_{a->c} = max(ans_{a->b}, ans_{b->c}, maxval_{b->c} - minval_{a->b})
ansa−>c=max(ansa−>b,ansb−>c,maxvalb−>c−minvala−>b)
从下到上的有向路径同理
因此我们树剖后需要在线段树中维护 5 个值:
- lans 从上到下的答案
- rans 从下到上的答案
- minval 路径中的最小值
- maxval 路径中的最大值
- add 区间加
写好pushup 和 pushdown了以后,其他应该也就没什么大问题了(多组数据,注意清空
代码:
/*
* @Author : Nightmare
*/
#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define ull unsigned long long
#define ld long double
#define ls 2 * rt
#define rs 2 * rt + 1
#define PII pair<int,int>
#define PDD pair<double, double>
#define gcd(a,b) __gcd(a,b)
#define lowbit(x) (x & (-x))
const int N = 5e4 + 5;
const int M = 2e5 + 5;
const int mod = 1e9 + 7;
int n, m, a[N], head[N], to[M], nxt[M], tot;
int fa[N], sz[N], dep[N], son[N], top[N], id[N], pos[N], dfn;
struct node{ int mi, mx, add, lans, rans; }t[N << 2];
void add_edge(int a, int b){
to[++tot] = b; nxt[tot] = head[a]; head[a] = tot;
to[++tot] = a; nxt[tot] = head[b]; head[b] = tot;
}
void dfs1(int u, int f){
fa[u] = f; dep[u] = dep[f] + 1; sz[u] = 1;
for(int i = head[u] ; i ; i = nxt[i]){
if(to[i] == f) continue;
dfs1(to[i], u);
sz[u] += sz[to[i]];
if(sz[to[i]] > sz[son[u]]) son[u] = to[i];
}
}
void dfs2(int u, int f){
top[u] = f; id[u] = ++dfn; pos[dfn]=u;
if(son[u]) dfs2(son[u], f);
for(int i = head[u] ; i ; i = nxt[i]) if(to[i] != son[u] && to[i] != fa[u]) dfs2(to[i], to[i]);
}
#define mid ((l + r) >> 1)
void pushup(int rt){
t[rt].lans = max({t[ls].lans, t[rs].lans, t[rs].mx - t[ls].mi});
// 从上到下的答案,我们用下面的mx - 上面的mi,下面的标号较大,因此用t[rs].mx - t[ls].mi
t[rt].rans = max({t[ls].rans, t[rs].rans, t[ls].mx - t[rs].mi});
// 从下到上的答案,我们用上面的mx - 下面的mi,上面的标号较小,因此用t[ls].mx - t[rs].mi
t[rt].mi = min(t[ls].mi, t[rs].mi);
t[rt].mx = max(t[ls].mx, t[rs].mx);
}
void pushdown(int rt){
if(t[rt].add){
t[ls].mi += t[rt].add; t[ls].mx += t[rt].add; t[ls].add += t[rt].add;
t[rs].mi += t[rt].add; t[rs].mx += t[rt].add; t[rs].add += t[rt].add;
t[rt].add = 0;
}
}
void build(int rt, int l, int r){
t[rt].add = t[rt].lans = t[rt].rans = 0;
if(l == r){ t[rt].mi = t[rt].mx = a[pos[l]]; return ; }
build(ls, l, mid); build(rs, mid + 1, r);
pushup(rt);
}
void change(int rt, int l, int r, int ql, int qr, int qv){
if(ql == l && r == qr){ t[rt].mx += qv; t[rt].mi += qv; t[rt].add += qv; return ; }
pushdown(rt);
if(qr <= mid) change(ls, l, mid, ql, qr, qv);
else if(ql > mid) change(rs, mid + 1, r, ql, qr, qv);
else change(ls, l, mid, ql, mid, qv), change(rs, mid + 1, r, mid + 1, qr, qv);
pushup(rt);
}
node query(int rt, int l, int r, int ql, int qr){
if(ql == l && r == qr) return t[rt];
pushdown(rt);
if(qr <= mid) return query(ls, l, mid, ql, qr);
else if(ql > mid) return query(rs, mid + 1, r, ql, qr);
else{
node lx = query(ls, l, mid, ql, mid), rx = query(rs, mid + 1, r, mid + 1, qr), res;
res.lans = max({lx.lans, rx.lans, rx.mx - lx.mi});
res.rans = max({lx.rans, rx.rans, lx.mx - rx.mi});
res.mi = min(lx.mi, rx.mi);
res.mx = max(lx.mx, rx.mx);
return res;
}
}
#undef mid
int query_path(int u, int v, int w){
// uans 维护从 u 向上到 lca 的路径的答案
// dans 维护从 lca 向下到 v 的路径的答案
node uans = {INT_MAX, 0, 0, 0, 0}, dans = {INT_MAX, 0, 0, 0, 0};
while(top[u] != top[v]){
if(dep[top[u]] > dep[top[v]]){
node cur = query(1, 1, n, id[top[u]], id[u]);
change(1, 1, n, id[top[u]], id[u], w);
uans.rans = max({uans.rans, cur.rans, cur.mx - uans.mi});
uans.mi = min(uans.mi, cur.mi);
uans.mx = max(uans.mx, cur.mx);
u = fa[top[u]];
}else{
node cur = query(1, 1, n, id[top[v]], id[v]);
change(1, 1, n, id[top[v]], id[v], w);
dans.lans = max({dans.lans, cur.lans, dans.mx - cur.mi});
dans.mi = min(dans.mi, cur.mi);
dans.mx = max(dans.mx, cur.mx);
v = fa[top[v]];
}
}
if(dep[u] > dep[v]){
node cur = query(1, 1, n, id[v], id[u]);
change(1, 1, n, id[v], id[u], w);
uans.rans = max({uans.rans, cur.rans, cur.mx - uans.mi});
uans.mi = min(uans.mi, cur.mi);
uans.mx = max(uans.mx, cur.mx);
}else{
node cur = query(1, 1, n, id[u], id[v]);
change(1, 1, n, id[u], id[v], w);
dans.lans = max({dans.lans, cur.lans, dans.mx - cur.mi});
dans.mi = min(dans.mi, cur.mi);
dans.mx = max(dans.mx, cur.mx);
}
return max({uans.rans, dans.lans, dans.mx - uans.mi}); // 三种情况取max
}
void solve(){
scanf("%d", &n); memset(head, 0, sizeof(head)); tot = dfn = 0;
for(int i = 1 ; i <= n ; i ++) scanf("%d", &a[i]), sz[i] = son[i] = 0;
for(int i = 1, u, v ; i < n ; i ++) scanf("%d %d", &u, &v), add_edge(u, v);
dfs1(1, 0); dfs2(1, 1); build(1, 1, n);
scanf("%d", &m);
for(int i = 1 ; i <= m ; i ++){
int u, v, w; scanf("%d %d %d", &u, &v, &w);
printf("%d\n", query_path(u, v, w));
}
}
signed main(){
#ifndef ONLINE_JUDGE
freopen("E:\\in.txt", "r", stdin);
#endif
int T; cin >> T; while(T --) solve();
#ifndef ONLINE_JUDGE
cerr << "Time elapsed: " << 1.0 * clock() / CLOCKS_PER_SEC << " s.\n";
#endif
return 0;
}