题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=2196
解题思路:
先思考一下暴力的方法:
建树(双向建边),如果有n个点,每次选择那个点作为根节点,dfs,记录最深深度
这样的话n个点,每次dfs遍历剩下n-1的点,复杂度O(N²),我试了一下,超时了。
于是我去参考了一下大佬的代码,发现了树形DP这样的东西。
树形DP做法:
1.先用链式前向星建边
typedef long long ll;
struct node
{
int to,last;
ll w,sum;
}edge[20000+5];
int head[10000+5];
int id;
void add(int u,int v,ll w){
edge[id].to = v;
edge[id].last = head[u];
edge[id].w = w;
head[u] = id++;
}
edge[i].to记录这条边终点,
edge[i].last记录起点相同的上一条输入的边是edge[某某] (也就是edge[edge[i].last]),
edge[i].w记录权值
edge[i].sum记录从起点开始沿着这条边延伸最远的距离
2.dfs获得根节点(随便找一个,就找1好了)向外扩展的边的sum值
代码:
void dfs(int now,int pre)///pre表示now的父节点,向外拓展,不能回去
{
for (int i=head[now];i!=0;i=edge[i].last){
if (edge[i].to!=pre){
dfs(edge[i].to,now);
dp[now] = max(dp[now],dp[edge[i].to]+edge[i].w);
edge[i].sum = edge[i].w + dp[edge[i].to];
}
}
}
dp[i] = max(dp[子节点] + 两点之间的w值)
edge[i].sum = dp[边终点] + 边权值
举个例子:
以①为根节点
⑤-④-①-②-③,假设中间每条边权值都为1,dfs求出dp[2] = 1,dp[4] =1,dp[1] = 2(其他两个点dp=0)
不过事实上dp[2] =3,不过dfs只负责向外扩展,正确的dp值在下一步求出向内扩展的值后,比较就能得出(有些点第一遍的dp值可能就是最大的)
3.treedp函数求得正确答案
代码:
主函数
for (int i=head[1];i!=0;i=edge[i].last){
treedp(1,edge[i].to);
}
treedp函数
void treedp(int now,int son)
{
ll maxx=0;
for (int i = head[now];i!=0;i=edge[i].last){
if (edge[i].to!=son) maxx = max(maxx,edge[i].sum);
}
for (int i = head[son];i!=0;i=edge[i].last){
if (edge[i].to==now) {edge[i].sum = edge[i].w + maxx; break;}
}
for (int i = head[son];i!=0;i=edge[i].last){
dp[son] = max(dp[son],edge[i].sum);
if (edge[i].to!=now){
treedp(son,edge[i].to);
}
}
}
emm,具体实现自己看吧,怕自己说错误导你们。我还是举个例子说一下自己的理解。
以①为根节点
⑤-④-①-②-③,假设中间每条边权值都为1,还是这个例子吧
第2步不是求出了每条向外拓展的边的sum值嘛,同时dp[2]=1,事实上dp[2]=3
那么怎么求出来呢?第2步我们求出了①->④.sum = 2(向外拓展的),
dp[2]就是①②中间边的权值+2=3 和 原先的 1 取最大值,于是就等于3,同时,②->①.sum = 3
之后继续深搜搞②③。。。emm就这样。
完整代码:(对了这题是多组输入的,不要问本菜鸡为什么debug了一晚上)
#include<cstdio>
#include<algorithm>
#include<cstring>
#define ll long long
#define debug(x) printf("----Line #x ----\n",x)
using namespace std;
struct node
{
int to,last;
ll w,sum;
}edge[20000+5];
int head[20000+5];
ll dp[10000+5];
int n,id;
void add(int u,int v,ll w){edge[id].to = v; edge[id].last = head[u]; edge[id].w = w; head[u] = id++;}
void dfs(int now,int pre)///pre表示now的父节点,向外拓展,不能回去
{
for (int i=head[now];i!=0;i=edge[i].last){
if (edge[i].to!=pre){
dfs(edge[i].to,now);
dp[now] = max(dp[now],dp[edge[i].to]+edge[i].w);
edge[i].sum = edge[i].w + dp[edge[i].to];
}
}
}
void treedp(int now,int son)
{
ll maxx=0;
for (int i = head[now];i!=0;i=edge[i].last){
if (edge[i].to!=son) maxx = max(maxx,edge[i].sum);
}
for (int i = head[son];i!=0;i=edge[i].last){
if (edge[i].to==now) {edge[i].sum = edge[i].w + maxx; break;}
}
for (int i = head[son];i!=0;i=edge[i].last){
dp[son] = max(dp[son],edge[i].sum);
if (edge[i].to!=now){
treedp(son,edge[i].to);
}
}
}
int main()
{
while (~scanf("%d",&n)){
///init
memset(head,0,sizeof head);
memset(dp,0,sizeof dp);
id = 1;
for (int i=2;i<=n;i++){
int v;
ll w;
scanf("%d %lld",&v,&w);
add(i,v,w);
add(v,i,w);
}
dfs(1,0);
for (int i=head[1];i!=0;i=edge[i].last){
treedp(1,edge[i].to);
}
/*
for (int i = 1;i<=n;i++){
for (int j=head[i];j;j=edge[j].last){
printf("%d->%d,w=%lld,sum=%lld ",i,edge[j].to,edge[j].w,edge[j].sum);
}
printf("dp[%d]=%lld\n",i,dp[i]);
}
*/
for (int i=1;i<=n;i++)
printf("%lld\n",dp[i]);
}
return 0;
}