最长道路tree
Description
H城很大,有N个路口(从1到N编号),路口之间有N-1边,使得任意两个路口都能互相到达,这些道路的长度我们视作一样。每个路口都有很多车辆来往,所以每个路口i都有一个拥挤程度v[i],我们认为从路口s走到路口t的痛苦程度为s到t的路径上拥挤程度的最小值,乘上这条路径上的路口个数所得的积。现在请你求出痛苦程度最大的一条路径,你只需输出这个痛苦程度。
Simple Description
给定一棵N个点的树,求树上一条链使得链的长度乘链上所有点中的最小权值所得的积最大。其中链长度定义为链上点的个数。
Input
第一行N
第二行N个数分别表示1~N的点权v[i]
接下来N-1行每行两个数x、y,表示一条连接x和y的边
Output
一个数,表示最大的痛苦程度。
Sample Input
3
5 3 5
1 2
1 3
Sample Output
10
样例解释
选择从1到3的路径,痛苦程度为min(5,5)*2=10
HINT
100%的数据n<=50000
其中有20%的数据树退化成一条链
所有数据点权<=65536
Hint:建议答案使用64位整型
思路
首先我们看一下数据范围,想一想点分治似乎就哈哈了,所以我们需要换一个思路。我们考虑一下,知道只有路径上的最小值才能对答案有贡献,所以我们可以把点的权值从大到小排序,这样我们就可以在插点的同时,维护经过当前点的最长路径,从而更新路径最大值就可以了。为什么呢?因为我们是按照权值由大到小的顺序进行的建树,所以每一次统计路径时,当前点就是最小值。
下面我们想,用什么维护路径长度呢?树链剖分。用什么维护最长长度呢?树的直径加并查集。所以这些放在一起,就是AC。
代码
#include <stdio.h>
#include <algorithm>
using namespace std;
#define N 50001
int n,idx,cnt;
int head[N];
int to[N<<1];
int nxt[N<<1];
int val[N],son[N];
int fa[N],top[N];
int f[N],need[N];
int level[N],size[N];
long long ans;int dis[N];
int root1[N],root2[N];
bool vis[N];
bool cmp(const int &a,const int &b)
{return val[a]>val[b];}
int find_anc(int x)
{return (f[x]==x)?x:f[x]=find_anc(f[x]);}
void add(int a,int b)
{
nxt[++idx]=head[a];
head[a]=idx;
to[idx]=b;
}
void dfs(int p,int from)
{
level[p]=level[from]+1;
size[p]=1,fa[p]=from;
for(int i=head[p];i;i=nxt[i])
if(to[i]!=from)
{
dis[to[i]]=dis[p]+1;dfs(to[i],p);
size[p]+=size[to[i]];
if(size[son[p]]<size[to[i]]) son[p]=to[i];
}
}
void dfs2(int p,int from)
{
if(son[p]) dfs2(son[p],from);
top[p]=from;
for(int i=head[p];i;i=nxt[i])
if(to[i]!=fa[p]&&to[i]!=son[p])
dfs2(to[i],to[i]);
}
int find_lca(int a,int b)
{
while(top[a]!=top[b])
{
if(level[top[a]]>level[top[b]])
swap(a,b);
b=fa[top[b]];
}
return (level[a]<level[b])?a:b;
}
int find_dis(int a,int b)
{
int tmp=find_lca(a,b);
return dis[a]+dis[b]-2*dis[tmp];
}
void merge(int x,int y)
{
int fx=find_anc(x);
int fy=find_anc(y);
if(fx==fy) return;
f[fy]=fx;int &p1=root1[fx],&p2=root2[fx];
int r1=root1[fx],r2=root2[fx],r3=root1[fy],r4=root2[fy];
int l1=find_dis(r1,r2),l2=find_dis(r1,r3),l3=find_dis(r1,r4);
int l4=find_dis(r2,r3),l5=find_dis(r2,r4),l6=find_dis(r3,r4);
int mx=max(l1,max(l2,max(l3,max(l4,max(l5,l6)))));
if(l1==mx) p1=r1,p2=r2;
else if(l2==mx) p1=r1,p2=r3;
else if(l3==mx) p1=r1,p2=r4;
else if(l4==mx) p1=r2,p2=r3;
else if(l5==mx) p1=r2,p2=r4;
else if(l6==mx) p1=r3,p2=r4;
ans=max(ans,1ll*(mx+1)*val[x]);
}
void add_point(int p)
{
vis[p]=true;
for(int i=head[p];i;i=nxt[i])
if(vis[to[i]]) merge(p,to[i]);
}
int main()
{
//freopen("Choose.in","r",stdin);
//freopen("Choose.out","w",stdout);
scanf("%d",&n);
for(int i=1;i<=n;i++)
{
scanf("%d",&val[i]);
need[i]=f[i]=root1[i]=root2[i]=i;
}
sort(need+1,need+n+1,cmp);
for(int i=1;i<n;i++)
{
int a,b;
scanf("%d%d",&a,&b);
add(a,b),add(b,a);
}
dfs(1,0),dfs2(1,1);
for(int i=1;i<=n;i++) add_point(need[i]);
printf("%I64d",ans);
}