注意:双指针扫描时,mxdep 的初始值应设为 -1,因为另一侧可能没有任何点满足最小值限制(即其路径最小值不小于当前点的路径最小值)。
code:
#include <cstdio>
#include <string>
#include <map>
#include <cstring>
#include <vector>
#include <set>
#include <algorithm>
#define N 2000007
#define fi first
#define se second
#define ll long long
#define inf 0x3f3f3f3f
#define mk(x,y) make_pair(x,y)
using namespace std;
namespace IO {
// Redirects stdin to the file "<s>.in".
// The matching stdout redirection to "<s>.out" is intentionally left disabled.
void setIO(string s)
{
    const string inputPath = s + ".in";
    const string outputPath = s + ".out";
    freopen(inputPath.c_str(), "r", stdin);
    // freopen(outputPath.c_str(),"w",stdout);
}
} // namespace IO
// n: vertex count read from the input.
// tot: running vertex count; starts at n and grows as Rebuild creates extra vertices.
int n;
int tot;

// val[v]: value attached to vertex v (extra vertices copy the value of the vertex they extend).
int val[N];

// One directed entry of a linked adjacency list.
struct Edge {
    int nex; // index of the next entry in the same list, 0 terminates the list
    int to;  // target vertex
    int w;   // weight of the entry
};

// e / hd: adjacency lists filled by add().
// edge / pre: adjacency lists filled by add_c().
Edge e[N<<1];
Edge edge[N<<1];

// Both counters start at 1, so the first stored entry has index 2 and the two
// directions of an undirected link inserted back to back are i and i^1.
int edges1 = 1;
int edges2 = 1;
int hd[N];
int pre[N];

// Best answer found so far.
ll ans;
void add(int x,int y,int z)
{
e[++edges1].nex=hd[x],hd[x]=edges1,e[edges1].to=y,e[edges1].w=z;
}
// Prepends a directed entry x -> y to the edge/pre adjacency list of x.
// The weight field is not written here.
void add_c(int x,int y)
{
    const int id = ++edges2;
    edge[id].to = y;
    edge[id].nex = pre[x];
    pre[x] = id;
}
// Copies the subtree of x (parent fa) from the edge/pre lists into the e/hd lists.
//
// The first non-parent neighbour is linked straight to x with weight 1.
// Every further neighbour gets a fresh vertex (++tot) that copies val[x]; the
// fresh vertex is linked to the previous attachment point with weight 0 and to
// the neighbour with weight 1. Recursion then continues into the neighbour.
void Rebuild(int x,int fa)
{
    int attachedTo = 0; // vertex the previous neighbour was hung on; 0 = none yet
    for (int it = pre[x]; it != 0; it = edge[it].nex)
    {
        const int child = edge[it].to;
        if (child == fa)
        {
            continue;
        }

        int hook;
        if (attachedTo == 0)
        {
            // First neighbour hangs directly on x.
            hook = x;
        }
        else
        {
            // Later neighbours hang on a new vertex chained behind the previous hook.
            hook = ++tot;
            val[hook] = val[x];
            add(attachedTo, hook, 0);
            add(hook, attachedTo, 0);
        }

        add(hook, child, 1);
        add(child, hook, 1);
        attachedTo = hook;

        Rebuild(child, x);
    }
}
// (minimum value, depth) records collected by dfs2 for the two sides of a split;
// both arrays are filled starting at index 1.
pair<int,int> ls[N];
pair<int,int> rs[N];

// Number of records currently stored in ls / rs.
int lsc;
int rsc;
// Strict ordering for sorting: ascending by first component, ties broken by
// ascending second component.
bool cmp(pair<int,int>A,pair<int,int>B)
{
    if (A.first != B.first)
    {
        return A.first < B.first;
    }
    return A.second < B.second;
}
// vis[i]: entry i of the e/hd lists has been cut and must be skipped by the traversals.
bool vis[N<<1];

// size[v]: subtree size computed by the latest dfs1 run.
// NOTE: with `using namespace std;` in effect, an unqualified use of `size`
// can clash with std::size when compiling as C++17 or later.
int size[N];

// mx: best (smallest) balance value found by the current dfs1 run.
int mx;

// rt1/rt2: endpoints of the entry chosen by dfs1, ed: its index,
// totsz: vertex count of the component currently being processed.
int rt1;
int rt2;
int ed;
int totsz;
void dfs1(int x,int fa)
{
size[x]=1;
for(int i=hd[x];i;i=e[i].nex)
{
int y=e[i].to;
if(y==fa) continue;
if(vis[i]) continue;
dfs1(y,x);
int now=max(size[y],totsz-size[y]);
if(now<mx)
{
mx=now;
rt1=x,rt2=y;
ed=i;
}
size[x]+=size[y];
}
}
void dfs2(int x,int fa,int mi,int dep,int typ)
{
if(typ==1)
{
ls[++lsc]=mk(mi,dep);
}
else
{
rs[++rsc]=mk(mi,dep);
}
for(int i=hd[x];i;i=e[i].nex)
{
int y=e[i].to;
if(y==fa||vis[i]) continue;
dfs2(y,x,min(mi,val[y]),dep+e[i].w,typ);
}
}
void Divide_And_Conquer(int x)
{
if(totsz==1) return;
rt1=rt2=ed=0,mx=inf;
dfs1(x,0);
vis[ed]=vis[ed^1]=1;
lsc=rsc=0;
dfs2(rt1,0,val[rt1],0,0);
dfs2(rt2,0,val[rt2],0,1);
sort(ls+1,ls+1+lsc,cmp);
sort(rs+1,rs+1+rsc,cmp);
int mxdep,ptr;
mxdep=-1,ptr=rsc;
for(int i=lsc;i>=1;--i)
{
while(ptr&&rs[ptr].fi>=ls[i].fi)
{
mxdep=max(mxdep,rs[ptr].se),--ptr;
}
ans=max(ans,(ll)((ll)mxdep+e[ed].w+ls[i].se+1)*ls[i].fi);
}
mxdep=-1,ptr=lsc;
for(int i=rsc;i>=1;--i)
{
while(ptr&&ls[ptr].fi>=rs[i].fi)
{
mxdep=max(mxdep,ls[ptr].se);
--ptr;
}
ll tmp=(ll)(1ll*mxdep+e[ed].w+rs[i].se+1)*rs[i].fi;
ans=max(ans,tmp);
}
int szrt1=totsz-size[rt2];
int szrt2=size[rt2];
int tmprt1=rt1,tmprt2=rt2;
totsz=szrt1;
Divide_And_Conquer(tmprt1);
totsz=szrt2;
Divide_And_Conquer(tmprt2);
}
int main()
{
// IO::setIO("input");
int i,j;
scanf("%d",&n);
for(i=1;i<=n;++i) scanf("%d",&val[i]),ans=max(ans,(ll)val[i]);
int x,y;
for(i=1;i<n;++i)
{
scanf("%d%d",&x,&y);
add_c(x,y),add_c(y,x);
}
tot=n;
Rebuild(1,0);
totsz=tot;
Divide_And_Conquer(1);
printf("%lld\n",ans);
return 0;
}