这题写的蛋疼,如果不是吴大爷给我讲了一下也许就gg了。。。
边分治为了防止被菊花卡需要建一些虚点和虚边,把一个点的儿子搞成完全二叉树,像这样:
其中红色的是虚点和虚边。其中虚点的点权是x的点权。
这题路径长度的定义是路径上点的个数, 如果这样的话对于虚点的权值无法处理。
把路径长度转移到边上,原边的长度为1,虚边的长度为0。
分治时找一条两边size最大值最小的边分治。强制边两边的两个终点不为虚点。那么路径长度就为两边长度和+1+当前分治边的长度。
将两边的点按路径上最小点权排个序,互相更新一下就行了。
指针扫能做到O(nlogn),我写的O(nlog^2n)
#include <bits/stdc++.h>
using namespace std;
#define N 110000
#define ll long long
const int mx=65536;
int n,tot,tn;
ll ans;
int v[N],head[N],nex[N<<1],to[N<<1],val[N<<1];
vector<int> a[N];
int size[N],root,sum,vis[N],f[N],top[2];
struct node
{
int v,l;
node(){}
node(int v,int l):v(v),l(l){}
friend bool operator < (const node &r1,const node &r2)
{
if(r1.v==r2.v)return r1.l<r2.l;
return r1.v<r2.v;
};
}st[2][N];
void add(int x,int y,int z)
{
tot++;
nex[tot]=head[x];head[x]=tot;
to[tot]=y;val[tot]=z;
}
void dfs1(int x,int y)
{
for(int i=head[x];i;i=nex[i])
if(to[i]!=y)
{
dfs1(to[i],x);
a[x].push_back(to[i]);
}
}
void dfs2(int x,int y)
{
size[x]=1;
for(int i=head[x];i;i=nex[i])
if(to[i]!=y&&!vis[i>>1])
{
dfs2(to[i],x);
size[x]+=size[to[i]];
}
}
void dfs3(int x,int y)
{
for(int i=head[x];i;i=nex[i])
if(to[i]!=y&&!vis[i>>1])
{
f[i>>1]=max(size[to[i]],sum-size[to[i]]);
root=f[i>>1]<f[root] ? i>>1:root;
dfs3(to[i],x);
}
}
void dfs4(int x,int y,int v1,int l,int type)
{
v1=min(v1,v[x]);
if(x<=tn)st[type][++top[type]]=node(v1,l);
for(int i=head[x];i;i=nex[i])
if(to[i]!=y&&!vis[i>>1])
dfs4(to[i],x,v1,l+val[i],type);
}
int cal(int x)
{
vis[x]=1;top[0]=top[1]=0;
dfs4(to[x<<1],0,mx,0,0);
dfs4(to[x<<1|1],0,mx,0,1);
for(int i=0;i<=1;i++)
{
sort(st[i]+1,st[i]+1+top[i]);
for(int j=top[i]-1;j>=1;j--)
st[i][j].l=max(st[i][j].l,st[i][j+1].l);
for(int j=1;j<=top[i^1];j++)
if(st[i][top[i]].v>=st[i^1][j].v)
{
int t=lower_bound(st[i]+1,st[i]+1+top[i],node(st[i^1][j].v,0))-st[i];
ans=max(ans,(ll)(st[i][t].l+st[i^1][j].l+1+val[x<<1])*st[i^1][j].v);
}
}
for(int i=0;i<=1;i++)
{
int t=to[(x<<1)+i];
dfs2(t,0);sum=size[t];
root=0;dfs3(t,0);
if(root)cal(root);
}
}
int main()
{
//freopen("tt.in","r",stdin);
scanf("%d",&n);tot=1;tn=n;
for(int i=1;i<=n;i++)
scanf("%d",&v[i]),ans=max(ans,(ll)v[i]);
for(int i=1,x,y;i<n;i++)
{
scanf("%d%d",&x,&y);
add(x,y,1);add(y,x,1);
}
dfs1(1,0);
tot=1;memset(head,0,sizeof(head));
for(int x=1;x<=n;x++)
{
int t=a[x].size();
if(t==0)continue;
if(t==1)
{
add(x,a[x][0],1);add(a[x][0],x,1);
continue;
}
for(int i=2;i<t<<1;i<<=1)
for(int j=0;j<t;j+=i)
{
add(++n,a[x][j],i==2 ? 1:0);
add(a[x][j],n,i==2 ? 1:0);
a[x][j]=n;v[n]=v[x];
if(j+(i>>1)<t)
{
add(n,a[x][j+(i>>1)],i==2 ? 1:0);
add(a[x][j+(i>>1)],n,i==2 ? 1:0);
}
}
add(n,x,0);add(x,n,0);
}
sum=n;dfs2(1,0);
root=0;f[0]=n+1;dfs3(1,0);
cal(root);
printf("%lld\n",ans);
return 0;
}