赛时没有考虑清楚细节,死活做不出来。
赛后看了题解,修正了做法:
f[u]:以u为根的子树的直径
h[u]:除掉u为根的子树,剩下的子树的直径
最后for循环扫一遍,f[]与h[]取max即可
赛时就少了h数组啊!!
#pragma comment(linker,"/STACK:1024000000,1024000000")
#include<algorithm>
#include<cstdio>
#include<cstring>
#define N 100100
using namespace std;
struct edge{
int to,next,w;
}e[N*2];
struct{
int a[3];
int from[3];
}ans[N],ff[N];
int n,o;
int fa[N],head[N],f[N],h[N],ed[N];
bool v[N];
void add(int x,int y,int w)
{
e[o].to=y;
e[o].w=w;
e[o].next=head[x];
head[x]=o++;
}
void update_ans(int u,int v,int w)
{
if (w>ans[u].a[0])
{ ans[u].a[2]=ans[u].a[1]; ans[u].from[2]=ans[u].from[1];
ans[u].a[1]=ans[u].a[0]; ans[u].from[1]=ans[u].from[0];
ans[u].a[0]=w; ans[u].from[0]=v;
}
else if (w>ans[u].a[1])
{ ans[u].a[2]=ans[u].a[1]; ans[u].from[2]=ans[u].from[1];
ans[u].a[1]=w; ans[u].from[1]=v;
}
else if (w>ans[u].a[2])
{ ans[u].a[2]=w; ans[u].from[2]=v;
}
}
void update_ff(int u,int v,int w)
{ if (w>ff[u].a[0])
{ ff[u].a[1]=ff[u].a[0]; ff[u].from[1]=ff[u].from[0];
ff[u].a[0]=w; ff[u].from[0]=v;
}
else if (w>ff[u].a[1])
{ ff[u].a[1]=w; ff[u].from[1]=v;
}
}
void DFS0(int now)
{ v[now]=1;
for (int k=head[now];k!=-1;k=e[k].next)
if (!v[e[k].to])
{int j=e[k].to;
ed[j]=k; fa[j]=now;
DFS0(j);
update_ans(now,j,ans[j].a[0]+1);
update_ff(now,j,f[j]);
f[now]=max(f[now],f[j]);
}
f[now]=ans[now].a[0]+ans[now].a[1];
}
void DFS1(int now)
{ v[now]=1;
for (int k=head[now];k!=-1;k=e[k].next)
if (!v[e[k].to])
{int j=e[k].to;
h[j]=max(h[j],h[now]);
if (j==ff[now].from[0]) h[j]=max(h[j],ff[now].a[1]);else h[j]=max(h[j],ff[now].a[0]);
update_ans(j,now,(ans[now].from[0]!=j)?ans[now].a[0]+1:ans[now].a[1]+1);
if (j==ans[now].from[0]) h[j]=max(h[j],ans[now].a[1]+ans[now].a[2]);
else if (j==ans[now].from[1]) h[j]=max(h[j],ans[now].a[0]+ans[now].a[2]);
else h[j]=max(h[j],ans[now].a[0]+ans[now].a[1]);
DFS1(j);
}
}
void doit()
{ int x,y,w;
memset(head,255,sizeof(head)); o=0;
memset(ans,0,sizeof(ans));
memset(ff,0,sizeof(ff));
scanf("%d",&n);
for (int i=1;i<=n-1;i++)
{scanf("%d%d%d",&x,&y,&w);
add(x,y,w);add(y,x,w);
}
memset(f,0,sizeof(f));
memset(h,0,sizeof(h));
memset(v,0,sizeof(v));
DFS0(1);
h[1]=0;
memset(v,0,sizeof(v));
DFS1(1);
int anss=-1;
int ans2,tmp;
for (int i=2;i<=n;i++)
{tmp=e[ed[i]].w*max(f[i],h[i]);
if (tmp<anss||anss==-1) {anss=tmp;ans2=ed[i];}
else if (tmp==anss) ans2=min(ans2,ed[i]);
}
printf("%d\n",ans2/2+1);
}
int main()
{ int cas,i=0;
scanf("%d",&cas);
while (cas--) {i++;printf("Case #%d: ",i);doit();}
return 0;
}