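// Problem: BZOJ 2870 -- https://ruanx.pw/bzojch/p/2870.html?tdsourcetag=s_pcqq_aiomsg
// Goal: maximise (number of nodes on a tree path) * (minimum node weight on that path).
// Approach: centroid decomposition. For each subtree hanging off the centroid, a map whose
// keys are sorted in descending order records, for every minimum weight of a path starting
// at the centroid, the longest such path (in nodes); paths from different subtrees are then
// combined by sweeping those minimums in descending order.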
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/hash_policy.hpp>
// rep(i, a, b): inclusive counting loop, i = a, a+1, ..., b.
#define rep(i, a, b) for(int i = (a); i <= (b); i++)
// pb: shorthand for push_back (not used in the code shown).
#define pb push_back
// ll: wide integer type used for node weights and path values.
#define ll long long
using namespace std;
using namespace __gnu_pbds;
// Array capacity: 51000 = 5e4 + 1000 slack over the node count.
const int N = 50000 + 1000;
// n: number of tree nodes read in main. k: declared but unused in the code shown.
int n, k;
// Forward-star adjacency storage. Every undirected edge is stored as two
// directed entries, hence 2*N slots.
struct node {
    int v;    // node this entry points to
    int nxt;  // index of the next entry in the same list, -1 terminates the list
};
node edge[2 * N];
int tot;      // number of entries used so far; entries are numbered from 1
int head[N];  // head[u]: first entry of u's list, or -1 when the list is empty
// Adds the directed entry u -> v at the front of u's adjacency list.
// Entries are numbered from 1, so tot is advanced before it is used.
void ae(int u,int v) {
    ++tot;
    edge[tot].v = v;
    edge[tot].nxt = head[u];
    head[u] = tot;
}
void init(int n) {
tot = 0;
rep(i, 1, n) head[i] = -1;
}
int siz[N];   // siz[u]: subtree size of u in the most recent GetRoot pass
int Root;     // best centroid candidate found by GetRoot; 0 means "none yet"
int wt[N];    // wt[u]: size of the largest piece left if u were removed
int Tsiz;     // number of nodes in the component GetRoot is searching
bool vis[N];  // vis[u]: u has already served as a centroid and is removed
ll ans;       // best path value found so far
ll s[N];      // s[u]: weight of node u (1-based)
// Arm maps keyed by arm minimum weight (descending); value = longest arm (nodes).
map<ll,ll,greater<ll> > Tnum;  // arms accumulated from earlier children in calc
map<ll,ll,greater<ll> > num;   // arms of the child subtree currently scanned by dfs
// Searches the component containing u (Tsiz nodes, removed nodes skipped)
// for the node whose largest remaining piece is smallest, storing it in Root.
// Also fills siz[] with subtree sizes for this DFS rooting. f is u's DFS parent.
void GetRoot(int u,int f) {
    siz[u] = 1;
    int heaviest = 0;  // largest child subtree below u
    for (int e = head[u]; e != -1; e = edge[e].nxt) {
        const int to = edge[e].v;
        if (to == f || vis[to]) continue;
        GetRoot(to, u);
        siz[u] += siz[to];
        heaviest = max(heaviest, siz[to]);
    }
    // The part "above" u holds every node of the component outside u's subtree.
    wt[u] = max(heaviest, Tsiz - siz[u]);
    if (wt[u] < wt[Root]) Root = u;
}
void dfs(int u,int f,ll dis,ll minn) {
num[minn] = max(num[minn],dis);
for(int i = head[u]; ~i ; i = edge[i].nxt) {
int v = edge[i].v;
if(v==f||vis[v]) continue;
dfs(v,u,dis+1,min(minn,s[v]));
}
}
void calc(int u) {
Tnum.clear(); //注意,这里并没有考虑u->u这条路径
Tnum[s[u]] = 1;
ans = max(ans,s[u]);
for(int i = head[u]; ~i ; i = edge[i].nxt) {
int v = edge[i].v;
if(vis[v]) continue;
num.clear();
dfs(v,u,2,min(s[u],s[v]));
auto it = num.begin();
ll maxx = 0;
for(auto x:Tnum) {
while(it!=num.end() && (*it).first>=x.first) {
maxx = max((*it).second,maxx);
it++;
}
ans = max(ans,x.first*(x.second+maxx-1));
}
it = Tnum.begin();
maxx = 0;
for(auto x:num) {
while(it!=Tnum.end() && (*it).first>=x.first) {
maxx = max((*it).second,maxx);
it++;
}
ans = max(ans,x.first*(x.second+maxx-1));
}
for(auto x:num)
Tnum[x.first] = max(Tnum[x.first],x.second);
}
}
// Centroid-decomposition driver: scores all paths through centroid u, removes
// u, then recurses into each remaining component via that component's centroid.
// Precondition: Tsiz is the size of u's component, and siz[] holds subtree
// sizes from the GetRoot pass that selected u (rooted at some node of it).
void divide(int u) {
    const int total = Tsiz;  // Tsiz is overwritten by the recursive calls below
    calc(u);
    vis[u] = 1;  // remove u from the tree
    for (int i = head[u]; ~i; i = edge[i].nxt) {
        int v = edge[i].v;
        if (vis[v]) continue;
        // siz[] was computed with another node as DFS root. A neighbour below u
        // in that DFS has siz[v] < siz[u] and siz[v] is its component size; the
        // neighbour above u (u's DFS parent) owns the rest: total - siz[u].
        // Other components keep their siz[] values because GetRoot and the
        // recursion never cross the removed node u.
        Root = 0;
        Tsiz = (siz[v] < siz[u]) ? siz[v] : total - siz[u];
        GetRoot(v, 0);
        divide(Root);
    }
}
// Reads n, the n node weights and the n-1 undirected edges (1-based
// endpoints), runs centroid decomposition and prints the best path value.
int main() {
    scanf("%d",&n);
    rep(i, 1, n) scanf("%lld",&s[i]);
    init(n);
    rep(i, 1, n-1) {
        int u,v;
        scanf("%d%d",&u,&v);
        ae(u,v);
        ae(v,u);
    }
    // vis[] is a zero-initialised global, so no reset is needed for this single run.
    wt[0] = 1e9;  // sentinel weight: any real node beats Root == 0 in GetRoot
    Tsiz = n;
    Root = 0;
    GetRoot(1,0);
    ans = 0;
    divide(Root);
    printf("%lld\n",ans);  // terminate the output line
    return 0;
}