题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=2196
参考了网上的代码才做的,其实这个想到树形dp还有点不容易。。虽然是被用烂的例题,但对新手还是。。嗯
树形dp,其实和普通的dp也是很像的,只是转化到了树上而已,最重要的是要找出父亲和儿子结点的转化关系,写出转移方程。可以说想到了就很简单吧,dp的第二维代表什么,也是需要讲究的。
这道题就是,dp的第二维用三个状态来表示,j=0时代表,从征程的下往上方向时,结点i可以到达的最长距离,j=1代表,从下往上结点i可以到达的次长距离,注意这个次长距离,同时还要记录每个结点最长链的子节点,那么问题来了,很明显只从下往上的最长是局部的,我们还有从上往下的距离(因为到一个点最远的距离只有两种可能,第一种就是从下往上,第二种就是经过根节点的与另一侧的最远距离)。从上往下的距离我们就要分两种情况来讨论了,我们面向的是每个儿子结点做讨论(因为根结点从上往下自然就是0了),对于每个儿子,如果他是该父亲最长边的子节点,那么他网上除了加上该边权外,只能加max(父亲的反向最远距离,父亲的正向次长距离),因为最长距离经过儿子这个点,所以只能加父亲的次长距离,如果不是这个父亲的最长边的子节点,那么就是边权+max(父亲的反向最远距离,父亲的正向最长距离)。
然后一直更新就好啦。
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn=200005;
const int mod=1e9+7;
struct edge{
int to,next;
ll w;
}e[maxn<<1];
int head[maxn],now,n,m,lonch[maxn];
ll dp[maxn][3];
void add(int u,int v,ll w){
e[now].to=v; e[now].w=w;
e[now].next=head[u],head[u]=now++;
}
ll dfs1(int u,int fa){
if(dp[u][0]>=0ll) return dp[u][0];
dp[u][0]= dp[u][1]=lonch[u]=dp[u][2]=0;
for(int i=head[u];~i;i=e[i].next){
int v=e[i].to; ll w=e[i].w;
if(v==fa) continue;
ll aim=dfs1(v,u)+w;
if(dp[u][0]<aim){
lonch[u]=v; dp[u][1]=dp[u][0];
dp[u][0]=aim;
}
else if(dp[u][1]<aim){
dp[u][1]=aim;
}
}
return dp[u][0];
}
void dfs2(int u,int fa){
for(int i=head[u];~i;i=e[i].next){
int v=e[i].to; ll w=e[i].w;
if(v==fa) continue;
if(v==lonch[u]){
dp[v][2]=max(dp[u][2],dp[u][1])+w;
}
else{
dp[v][2]=max(dp[u][2],dp[u][0])+w;
}
dfs2(v,u);
}
}
int main(){
while(~scanf("%d",&n)){
memset(head,-1,sizeof(head)); now=0;
memset(dp,-1,sizeof(dp));
for(int i=2;i<=n;i++){
int u;ll l;
scanf("%d%lld",&u,&l);
add(i,u,l); add(u,i,l);
}
dfs1(1,-1);
dfs2(1,-1);
for(int i=1;i<=n;i++)
printf("%lld\n",max(dp[i][0],dp[i][2]));
}
return 0;
}
附上另一份代码。
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn=200005;
const int mod=1e9+7;
struct edge{
int to,next;
ll w;
}e[maxn<<1];
int head[maxn],now,n,m;
ll dis1[maxn],maxx[maxn],dis2[maxn];
bool isl[maxn];
void add(int u,int v,ll w){
e[now].to=v; e[now].w=w;
e[now].next=head[u],head[u]=now++;
}
void dfs(int u,int fa,ll *dis){
for(int i=head[u];~i;i=e[i].next){
int v=e[i].to;
if(v==fa) continue;
dis[v]=dis[u]+e[i].w;
maxx[v]=dis[v];
dfs(v,u,dis);
maxx[u]=max(maxx[u],maxx[v]);
}
}
void deal(int u,int fa){
isl[u]=1;
for(int i=head[u];~i;i=e[i].next){
int v=e[i].to;
if(v==fa) continue;
deal(v,u);
}
}
int main(){
while(~scanf("%d",&n)){
if(n==1){
printf("0\n");
continue;
}
memset(isl,0,sizeof(isl));
memset(head,-1,sizeof(head)); now=0; memset(dis2,0,sizeof(dis2));
memset(dis1,0,sizeof(dis1));
for(int i=2;i<=n;i++){
int u;ll l;
scanf("%d%lld",&u,&l);
add(i,u,l); add(u,i,l);
}
dfs(1,-1,dis1);
int idfi=-1,idse=-1;
for(int i=head[1];~i;i=e[i].next){
int v=e[i].to;
if(idfi==-1||maxx[v]>maxx[idfi]){
idse=idfi; idfi=v;
}
else if (idse==-1||maxx[v]>maxx[idse]){
idse=v;
}
}
ll max1=maxx[idfi],max2;
if(idse!=-1)
max2=maxx[idse];
else max2=0;
int farest=-1;
for(int i=1;i<=n;i++){
if(dis1[i]==maxx[idfi]){
farest=i;
}
}
dfs(farest,-1,dis2);
deal(idfi,1);
for(int i=1;i<=n;i++){
if(isl[i]){
printf("%d\n",max(dis2[i],max2+dis1[i]));
}
else{
printf("%d\n",max(max1+dis1[i],dis2[i]));
}
}
}
return 0;
}