这道题一开始在想可以枚举每个点对,尝试删除其间的边,因为有
O(n2)
个点对,所以要
O(1)
更新答案
后来发现,因为是树,所以只有
O(n)
个点对是有用的(这么显然的结论一开始没发现,看来还是我太弱了),然后就可以每次
O(n)
判断
首先定义在一棵树
x
中,对于点
对于一条边连接的两颗子树
在所有边的值取出最小的,就是答案。
对于一棵树,如何求出
F(A,x)
最小的
x
?
为什么
x
一定就是直径中点?
证明:如果
如果
因为总共有
O(n)
条边,找直径也是
O(n)
,所以整个算法的复杂度就是
O(n2)
,但找直径的常数较大,所以我加了个强力剪枝:只尝试计算原树直径上的边。如果删的边不在直径上,那么对答案肯定没影响。
可是,当原树为链时,剪枝就失效了,所以我对于链特判了一下,没有重新找直径。
最后,我去bzoj上交了一发,好像是rank2.rank1怎么那么快?
#include<cstdio>
#include<cstring>
#include<vector>
#include<algorithm>
using namespace std;
typedef unsigned int ui;
struct edge{
int to,d;
}x;
const int N=5010;
vector<edge> g[N];
int n,i,u,v,d,ans,y,z,a[N],w,j,ff,sss,ttt,c[N],ww,cc[N],ss,tt,aa[N];
bool b[N],bb[N];
ui k;
void dfs1(int x,int dep,int fa){
if(dep>d)d=dep,v=x;
for(ui k=0;k<g[x].size();++k)
if(!b[g[x][k].to] && g[x][k].to!=fa)
dfs1(g[x][k].to,dep+g[x][k].d,x);
}
void dfs2(int x,int dep,int fa){
if(dep>y)y=dep,z=x;
for(ui k=0;k<g[x].size();++k)if(!b[g[x][k].to] && g[x][k].to!=fa)dfs2(g[x][k].to,dep+g[x][k].d,x);
}
bool dfs3(int x,int dep,int fa){
a[++w]=dep;
c[w]=x;
if(x==z)return 1;
for(ui k=0;k<g[x].size();++k)if(!b[g[x][k].to] && g[x][k].to!=fa)if(dfs3(g[x][k].to,dep+g[x][k].d,x))return 1;
--w;
return 0;
}
inline int got(int u){
memset(bb,0,sizeof bb);
dfs1(v=u,d=0,0);
dfs2(v,y=0,w=0);
dfs3(v,0,0);
for(j=1;j<=w;++j)
if(a[j]<=y>>1 && a[j+1]>=y>>1){
ff+=min(max(a[j],y-a[j]),max(a[j+1],y-a[j+1]));
break;
}
return y;
}
int main(){
scanf("%d",&n);
for(i=1;i<n;++i){
scanf("%d%d%d",&u,&v,&d);
g[u].push_back((edge){v,d});
g[v].push_back((edge){u,d});
}
ans=1<<30;
dfs1(u,d=0,0);
dfs2(v,y=0,w=0);
dfs3(v,0,0);
for(i=1,ww=w;i<=w;++i)cc[i]=c[i],aa[i]=a[i];
if(w==n){
for(i=2;i<=ww;++i){
ss=cc[i-1];
tt=cc[i];
ff=0;
y=a[ss];
for(j=1;j<=ss;++j)
if(a[j]<=y>>1 && a[j+1]>=y>>1){
ff+=min(max(a[j],y-a[j]),max(a[j+1],y-a[j+1]));
break;
}
y=a[n]-a[tt];
for(j=tt;j<=n;++j)
if(a[j]-a[tt]<=y>>1 && a[j+1]-a[tt]>=y>>1){
ff+=min(max(a[j]-a[tt],y-a[j]+a[tt]),max(a[j+1]-a[tt],y-a[j+1]+a[tt]));
break;
}
ans=min(ans,max(max(a[ss],a[n]-a[tt]),ff+aa[i]-aa[i-1]));
b[i]=0;
}
return printf("%d\n",ans),0;
}
for(i=2;i<=ww;++i){
ss=cc[i-1];
tt=cc[i];
ff=0;
b[ss]=1;
sss=got(tt);
b[ss]=0;
b[tt]=1;
ttt=got(ss);
ans=min(ans,max(max(sss,ttt),ff+aa[i]-aa[i-1]));
b[i]=0;
}
return printf("%d\n",ans),0;
}