思路十分简单,答案只有 3 种可能,但是有一些细节需要额外注意一下.
code:
#include <bits/stdc++.h>
#define N 300002
#define setIO(s) freopen(s".in","r",stdin)
using namespace std;
int val[N],hd[N],to[N<<1],nex[N<<1],d1[N],d2[N],n,edges,maxx,mx,m2,cnt,uu;
void add(int u,int v)
{
nex[++edges]=hd[u],hd[u]=edges,to[edges]=v;
}
void dfs(int u,int ff)
{
if(val[u]==mx) d1[u]=0, uu=u;
for(int i=hd[u];i;i=nex[i])
{
int v=to[i];
if(v==ff) continue;
dfs(v,u);
if(d1[v]+1>d1[u])
{
d2[u]=d1[u],d1[u]=d1[v]+1;
}
else if(d1[v]+1>d2[u]) d2[u]=d1[v]+1;
}
maxx=max(d1[u]+d2[u], maxx);
}
int main()
{
int i,j;
// setIO("input");
mx=-1000300000;
m2=mx;
scanf("%d",&n);
for(i=1;i<=n;++i)
{
scanf("%d",&val[i]),mx=max(mx,val[i]);
}
for(i=1;i<=n;++i) if(val[i]<mx) m2=max(m2, val[i]);
for(i=1;i<=n;++i) if(val[i]==m2) ++cnt;
for(i=1;i<n;++i)
{
int u,v;
scanf("%d%d",&u,&v),add(u,v),add(v,u);
}
memset(d1,-0x3f,sizeof(d1));
memset(d2,-0x3f,sizeof(d2));
dfs(1,0);
if(maxx==0)
{
if(m2!=mx-1)
printf("%d\n",mx);
else
{
for(int i=hd[uu];i;i=nex[i])
{
int v=to[i];
if(val[v]==m2) --cnt;
}
if(cnt) printf("%d\n",mx+1);
else printf("%d\n",mx);
}
}
else if(maxx<=2) printf("%d\n",mx+1);
else printf("%d\n",mx+2);
return 0;
}