【HDU 5886】Tower Defence(树的直径+树形DP)
题目大意:
带边权的树,随机删掉一条边,保留剩下两棵树中较大的中线。
求删除一条边后中线大小的期望*(n-1)
说白了就是统计删除每条边的情况下两棵树较大的中线,求个和。
对于原树,两次dfs可以找出来中线,同时可以标记中线上的点。
那么对于任何一条边
u−v
,如果两个点有一个不在中线上,删除后对结果没有影响,较大的中线仍为原树的中线。
如果两个点都在中线上,呢么就要对于两个新生成的树分别求中线取最大。
对于原树中线两端为 a,b ,两次树DP分别求出a为根时,所有子树的中线以及b为根时所有子树的中线。
这样查询的时候 判断一下u,v分别相对a,b的距离(也就是深度)
搞清楚切割后两颗树谁是相对于a的子树,谁是相对于b的子树。
取个最大即可。
代码如下:
#include <iostream>
#include <cmath>
#include <vector>
#include <cstdlib>
#include <cstdio>
#include <climits>
#include <ctime>
#include <cstring>
#include <queue>
#include <stack>
#include <list>
#include <algorithm>
#include <map>
#include <set>
#define LL long long
#define Pr pair<int,int>
#define fread(ch) freopen(ch,"r",stdin)
#define fwrite(ch) freopen(ch,"w",stdout)
using namespace std;
const int INF = 0x3f3f3f3f;
const int msz = 10000;
const int mod = 1e9+7;
const double eps = 1e-8;
const int maxn = 100010;
struct Edge
{
int v,w,next;
} eg[maxn<<1];
bool vis[maxn],on[maxn];
int dis[maxn],cost[2][3][maxn],depth[2][maxn];
int head[maxn];
int tp;
int n;
void Add(int u,int v,int w)
{
eg[tp].v = v;
eg[tp].w = w;
eg[tp].next = head[u];
head[u] = tp++;
}
void dfs1(int u,int pre)
{
int v,w;
for(int i = head[u]; i != -1; i = eg[i].next)
{
v = eg[i].v;
w = eg[i].w;
if(v == pre) continue;
dis[v] = dis[u] + w;
dfs1(v,u);
}
}
int st,en;
bool dfs2(int u,int pre)
{
int v;
on[u] = u==en;
for(int i = head[u]; i != -1; i = eg[i].next)
{
v = eg[i].v;
if(v == pre) continue;
if(dfs2(v,u)) on[u] = 1;
}
return on[u];
}
void dfs3(int pos,int u,int pre,int dep)
{
depth[pos][u] = dep;
cost[pos][0][u] = cost[pos][1][u] = cost[pos][2][u] = 0;
int v,w;
for(int i = head[u]; i != -1; i = eg[i].next)
{
v = eg[i].v;
w = eg[i].w;
if(v == pre) continue;
dfs3(pos,v,u,dep+1);
int tmp = cost[pos][1][v]+w;
if(tmp > cost[pos][1][u])
{
cost[pos][2][u] = cost[pos][1][u];
cost[pos][1][u] = tmp;
}
else if(tmp > cost[pos][2][u]) cost[pos][2][u] = tmp;
cost[pos][0][u] = max(cost[pos][0][u],cost[pos][0][v]);
}
cost[pos][0][u] = max(cost[pos][0][u],cost[pos][1][u]+cost[pos][2][u]);
}
void init()
{
dis[1] = 0;
dfs1(1,1);
st = 1;
for(int i = 1; i <= n; ++i)
if(dis[i] > dis[st]) st = i;
dis[st] = 0;
memset(on,0,sizeof(on));
dfs1(st,st);
en = 1;
for(int i = 1; i <= n; ++i)
if(dis[i] > dis[en]) en = i;
dfs2(st,st);
dfs3(0,st,st,0);
dfs3(1,en,en,0);
}
int cal(int u,int v)
{
if(!on[u] || !on[v]) return dis[en];
if(depth[0][u] > depth[0][v])
{
return max(cost[0][0][u],cost[1][0][v]);
}
return max(cost[0][0][v],cost[1][0][u]);
}
int main()
{
//fread("");
//fwrite("");
int t,u,v,w;
scanf("%d",&t);
while(t--)
{
memset(head,-1,sizeof(head));
tp = 0;
scanf("%d",&n);
for(int i = 1; i < n; ++i)
{
scanf("%d%d%d",&u,&v,&w);
Add(u,v,w);
Add(v,u,w);
}
init();
LL ans = 0;
for(u = 1; u <= n; ++u)
for(int i = head[u]; i != -1; i = eg[i].next)
{
ans += cal(u,eg[i].v);
}
printf("%lld\n",ans/2);
}
return 0;
}