题目大意:
给你一棵n个节点的树, 带有边权c[i], 定义路径{
e1,e2…ek
e
1
,
e
2
…
e
k
}的费用是
(e1−e2)2+(e2−e3)2+⋯+(ek−1−ek)2
(
e
1
−
e
2
)
2
+
(
e
2
−
e
3
)
2
+
⋯
+
(
e
k
−
1
−
e
k
)
2
。 求每个节点距自身最远点的距离。
(n≤105,ci≤105,∑n≤106)
(
n
≤
10
5
,
c
i
≤
10
5
,
∑
n
≤
10
6
)
题目思路:
类似于求树的直径做树形dp, 先选定1为根, 用f[i]表示向下走的答案, g[i]表示向上走的答案。 由于费用的更新需要用到两条边, 故扩展一下用f_ch[u][v]表示从u往下走第一步到v的答案, v是u的孩子, 这样复杂度还是O(n)的解决f。
然后考虑向上走的情况g。 考虑已经求出了g[u], 现在要用u来求出他的所有孩子g[v], 对于一个点v来说,
对于第二个max是个经典的dp斜率优化的问题, 将e(u, v)排序后, 维护上凸包+单调队列, 正着做一遍反着做一遍即可。
PS: 关于dp斜率优化
考虑dp:
f[i]=max{f[j]+(e[i]−e[j])2}
f
[
i
]
=
max
{
f
[
j
]
+
(
e
[
i
]
−
e
[
j
]
)
2
}
对与某个转移j, 将式子移项, 分离变量, 只和i有关的部分、 只和j有关的部分、 和i,j均有关的部分。
将 f[i]−e[i]2 f [ i ] − e [ i ] 2 看作截距b, f[j]+e[j]2 f [ j ] + e [ j ] 2 看作y, 2∗e[i] 2 ∗ e [ i ] 看作斜率k, e[j] e [ j ] 看作x。
上式可以看作线性函数b = y - kx。
每个j对应一个坐标(x,y), 一系列的j在图上就是一些点, 对于一个i就是一个询问, 每个i对应一个斜率k, 每个i求一个斜率为k的经过图中某个点的最大截距。
这里是取最大值故维护上凸包(取min则维护下凸包), 在本题中, 考虑将e[i]从小到打排序, 先正过来求一遍, 即每个i都会考虑一遍小于它的j。 维护一个单调队列, 对于询问i, 由于询问的斜率是递增的, 按上凸包顺时针方向看, 相邻点构成的斜率递减, 询问i的取最大值的点满足其向下一个点斜率小于询问i的斜率,向上一个点的斜率大于询问i的斜率, 又考虑到询问i的斜率是递增的来询问的, 凸包上的点也是按x坐标递增来加入的, 故应从单调队列的尾端扫描, 根据斜率的比较关系, 斜率越大的询问取最大值的点越靠前, 如果队尾的上一个点由于队尾, 说明对于一个更大的斜率也会优于队尾的, 故弹出队尾元素。 再将i对应的坐标点加入凸包中, 可以用向量叉积判断凸包走向来决定是否删除队尾元素 。 反向求一遍同理。
Code:
#include <map>
#include <set>
#include <map>
#include <bitset>
#include <cmath>
#include <queue>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <algorithm>
#define ll long long
#define db double
#define pw(x) ((x) * (x))
#define fi first
#define se second
#define mp(x, y) make_pair(x, y)
using namespace std;
const int N = (int)1e5 + 10;
int n;
int cnt, lst[N], nxt[N * 2], to[N * 2]; ll c[N * 2], pre[N];
map<int, ll> f_ch[N];
map<int, ll> :: iterator it;
ll f[N], g[N];
void add(int u, int v, int w){
nxt[++ cnt] = lst[u]; lst[u] = cnt; to[cnt] = v; c[cnt] = w;
nxt[++ cnt] = lst[v]; lst[v] = cnt; to[cnt] = u; c[cnt] = w;
}
void dfs(int u, int fa){
for (int j = lst[u]; j; j = nxt[j]){
int v = to[j];
if (v == fa) continue;
pre[v] = c[j];
dfs(v, u);
ll &x = f_ch[u][v];
x = 0;
for (it = f_ch[v].begin(); it != f_ch[v].end(); it ++){
x = max(x, it->se + pw(c[j] - pre[it->fi]));
f[u] = max(f[u], x);
}
}
}
pair <ll, int > tmp[N]; int sz;
pair <ll, ll > que[N]; int head, tail;
pair <ll, ll> operator-(pair<ll, ll> a, pair<ll, ll> b){
return mp(a.fi-b.fi, a.se-b.se);
}
ll operator*(pair<ll, ll> a, pair<ll, ll> b){
return a.fi * b.se - a.se * b.fi;
}
ll cross(pair<ll, ll> a, pair<ll, ll> b, pair<ll, ll> c){
return (a - b) * (b - c);
}
ll count(pair<ll, ll > x, ll e){
return x.se-2*e*x.fi+pw(e);
}
void dfs2(int u, int fa){
sz = 0;
for (int j = lst[u]; j; j = nxt[j]){
int v = to[j];
if (v == fa) continue;
tmp[++ sz] = mp(pre[v], v);
}
sort(tmp + 1, tmp + sz + 1);
que[head = tail = 1] = mp(tmp[1].fi, pw(tmp[1].fi) + f_ch[u][tmp[1].se]);
for (int i = 2; i <= sz; i ++){
int v = tmp[i].se; ll e = tmp[i].fi;
while (head < tail && count(que[tail], e) <= count(que[tail - 1], e)) tail --;
g[v] = max(g[v], count(que[tail], e));
pair<ll, ll> p = mp(e, f_ch[u][v] + pw(e));
while (head < tail && cross(p, que[tail], que[tail - 1]) <= 0)
tail --;
que[++ tail] = p;
}
que[head = tail = 1] = mp(tmp[sz].fi, pw(tmp[sz].fi) + f_ch[u][tmp[sz].se]);
for (int i = sz - 1; i >= 1; i --){
int v = tmp[i].se; ll e = tmp[i].fi;
while (head < tail && count(que[tail], e) <= count(que[tail - 1], e)) tail --;
g[v] = max(g[v], count(que[tail], e));
pair<ll, ll> p = mp(e, f_ch[u][v] + pw(e));
while (head < tail && cross(p, que[tail], que[tail - 1]) >= 0)
tail --;
que[++ tail] = p;
}
if (fa){
for (int i = 1; i <= sz; i ++){
int v = tmp[i].se; ll e = tmp[i].fi;
g[v] = max(g[v], g[u] + pw(pre[u] - e));
}
}
for (int j = lst[u]; j; j = nxt[j]){
int v = to[j];
if (v == fa) continue;
dfs2(v, u);
}
}
int getint(){
int ret = 0; char c = getchar();
while (c > '9' || c < '0') c = getchar();
while (c <= '9' && c >= '0'){
ret = ret * 10 + c - '0';
c = getchar();
}
return ret;
}
int main(){
while (scanf("%d", &n) != EOF){
for (int i = 2, u, v, w; i <= n; i ++){
u = getint(), v = getint(), w = getint();
add(u, v, w);
}
dfs(1, 0);
dfs2(1, 0);
for (int i = 1; i <= n; i ++){
printf("%lld\n", max(f[i], g[i]));
lst[i] = 0;
f[i] = g[i] = 0;
f_ch[i].clear();
}
cnt = 0;
}
return 0;
}