题目链接:点击打开链接
题意:
给定n个点的树,任意拆掉一条边,得到2个子树,再用刚拆掉的边把这两个子树连起来。
得到新的树,这个树的权值为任意两个点间的距离和。
使得新的树权值最小。输出这个权值。
枚举拆掉的边(u,v)
得到2个以u为根的子树和以v为根的子树
计算每条边对答案的贡献,拆掉的边贡献就是siz[u]*siz[v]*edge[u,v].dis
剩下的就是计算如何连接2个子树使得权值和最小。
对于子树中的一条边x, y,若已知两端的节点数为i,j,则这条边对答案的贡献就是 i*j*edge[x,y].dis
新建的边实际上只有连接x点方向或者y点方向。
所以就能得到dp数组
dp[u][0]表示新建边连接u的子树时u这个子树的贡献。
dp[u][1]表示新建边连接的不是u的子树时的贡献。
java还是tle。。。
#include<iostream>
#include<stdio.h>
#include<string.h>
#include<queue>
#include<math.h>
template <class T>
inline bool rd(T &ret) {
char c; int sgn;
if (c = getchar(), c == EOF) return 0;
while (c != '-' && (c<'0' || c>'9')) c = getchar();
sgn = (c == '-') ? -1 : 1;
ret = (c == '-') ? 0 : (c - '0');
while (c = getchar(), c >= '0'&&c <= '9') ret = ret * 10 + (c - '0');
ret *= sgn;
return 1;
}
template <class T>
inline void pt(T x) {
if (x <0) {
putchar('-');
x = -x;
}
if (x>9) pt(x / 10);
putchar(x % 10 + '0');
}
using namespace std;
typedef long long ll;
const int N = 5050;
const int M = N * 2;
const ll inf64 = 1e18;
struct Edge{
int from, to, dis, nex;
}edge[M << 1];
int head[N], edgenum;
void init_edge(){ for (int i = 0; i < N; i++)head[i] = -1; edgenum = 0; }
void add(int u, int v, int dis){
Edge E = { u, v, dis, head[u] };
edge[edgenum] = E;
head[u] = edgenum++;
}
ll dp[N][2];
int siz[N];
int n;
void justgo(int u, int fa){
siz[u] = 1;
for (int i = head[u]; i != -1; i = edge[i].nex){
int v = edge[i].to; if (v == fa)continue;
justgo(v, u);
siz[u] += siz[v];
}
}
void dfs(int u, int fa, int root){
dp[u][0] = inf64;
dp[u][1] = 0;
for (int i = head[u]; i != -1; i = edge[i].nex){
int v = edge[i].to; if (v == fa)continue;
dfs(v, u, root);
dp[u][1] += dp[v][1] + (ll)edge[i].dis*siz[v] * (n - siz[v]);
}
for (int i = head[u]; i != -1; i = edge[i].nex){
int v = edge[i].to; if (v == fa)continue;
dp[u][0] = min(dp[u][0], dp[u][1] - (ll)edge[i].dis*siz[v] * (n - siz[v]) - dp[v][1] + min(dp[v][1], dp[v][0]) + (ll)edge[i].dis*(siz[root] - siz[v])*(n - siz[root] + siz[v]));
}
}
int main(){
init_edge();
rd(n);
for (int i = 1, u, v, d; i < n; i++){
rd(u); rd(v); rd(d);
add(u, v, d); add(v, u, d);
}
ll ans = inf64;
for (int i = 0, u, v; i < edgenum; i += 2){
u = edge[i].from; v = edge[i].to;
justgo(u, v); justgo(v, u);
dfs(u, v, u);
dfs(v, u, v);
ans = min(ans, min(dp[u][0], dp[u][1]) +
min(dp[v][0], dp[v][1]) + (ll)edge[i].dis*siz[u] * siz[v]);
}
pt(ans);
return 0;
}