容易看出来是一个树形dp,并且有一个非常显然的状态转移方程:
,其中v是树上从u到根节点路径上的点。
但是显然这样的时间复杂度在树退化成链的时候会达到,需要想办法来进行优化。尝试进行变形:
如果状态v和w都可以转移到状态u,那么在这种情况下,从状态v转移会更优:
我们发现上式变成了一个斜率的形式。考虑将(dis[i], f[i])的点绘制出来,如果出现了下面的情况:
,那么通过枚举各种情况,我们可以分析出来 j 处必不可能是较优的点。
也就是说,有可能作为最优解进行转移的状态,它们的点必然是在一个下凸壳上的。
每次在得到一个新的状态的时候,由于dis[]的单调性,它的位置必然是在这个半凸壳的右端处。由于dis[]的单调性,f[]也是满足单调递增,这样就可以用一个单调队列来维护半凸壳上的点。
对于每个新的状态u,具体的维护方法为:
1. 检查队头的两个元素q[l]和q[l+1],通过上面的斜率检查,如果q[l+1]比q[l]更优,那么就把q[l]出队。
2. 直接取队头的元素为目标状态,进行状态转移,计算出f[u]。
3. 将u插入队尾。插入之前需要检查三个状态q[r-1], q[r], u是否满足斜率单调递增,若不满足则将q[r]出队。
这样就将整个DP的时间复杂度优化到了。
需要注意的是,由于每个节点可能有多个子节点,因此每次转移之后要将队尾恢复为原来的元素。
#include <cstdio>
#include <cstring>
#include <cmath>
#include <cstdlib>
#include <queue>
#include <map>
#include <vector>
#include <iostream>
#include <algorithm>
using namespace std;
typedef long long ll;
const int maxn = 100050;
const ll INF = (1LL << 62) - 1;
const double eps = 1e-8;
const ll mod = 2147493647;
const double pi = acos(-1.0);
int t, n, a, b, l, r;
int no, head[maxn], q[maxn];
ll dp[maxn], dis[maxn], ans, m, x;
struct node
{
int to, nxt;
ll w;
}e[maxn << 1];
void add(int a, int b, ll x)
{
e[no].to = b;
e[no].nxt = head[a];
e[no].w = x;
head[a] = no++;
}
ll gety(int u, int v)
{
return dp[u] + dis[u]*dis[u] - dp[v] - dis[v]*dis[v];
}
ll getx(int u, int v) {return dis[u] - dis[v];}
ll getdp(int u, int v) {return dp[v] + m + (dis[u] - dis[v])*(dis[u] - dis[v]);}
void pre(int u, int fa)
{
for(int i = head[u];i != -1;i = e[i].nxt)
{
int v = e[i].to;
if(v == fa) continue;
dis[v] = dis[u] + e[i].w;
pre(v, u);
}
}
void dfs(int u, int fa, int l, int r)
{
int pre = -1;
while(l < r && gety(q[l+1], q[l]) <= 2*dis[u]*getx(q[l+1], q[l])) l++;
dp[u] = min(dp[u], getdp(u, q[l]));
while(l < r && getx(u, q[r])*gety(q[r], q[r-1]) >= getx(q[r], q[r-1])*gety(u, q[r])) r--;
pre = q[++r], q[r] = u;
ans = max(ans, dp[u]);
for(int i = head[u];i != -1;i = e[i].nxt)
{
int v = e[i].to;
if(v == fa) continue;
dfs(v, u, l, r);
}
if(pre != -1) q[r] = pre;
}
void init()
{
no = r = 0, l = 1;
memset(dis, 0, sizeof(dis));
memset(head, -1, sizeof(head));
ans = dp[1] = q[0] = 0;
}
int main()
{
scanf("%d", &t);
while(t--)
{
scanf("%d%lld", &n, &m);
init();
for(int i = 1;i < n;i++)
{
scanf("%d%d%lld", &a, &b, &x);
add(a, b, x), add(b, a, x);
}
pre(1, -1);
for(int i = 1;i <= n;i++)
dp[i] = dis[i]*dis[i];
dfs(1, -1, 1, 0);
printf("%lld\n", ans);
}
return 0;
}