官方题解:
本题有多种解法。首先是点分治的思想,在点分治的时候,我们每一次选取一个中心,先统计过中心的路径最大值,然后删掉中心,递归处理其它子树。统计过中心的路径最大值,我们以中心为根深度搜索一遍,一个需要注意的地方是路径的两个端点不能在同一子树内,因为这样可能会重复统计。所以我们把路径按子树分类,然后点权排序以后更新路径按子树分类的最大值和次大值,之和与当前点权的乘积就是答案。
本题还可以用并查集来解决。将所有点按照权值从大到小排序,对于将当前点和与其相连的所有点依次合并到一个集合中。并查集需要维护当前集合中的最长路径长度和对应的两个端点。在合并两个集合后,最终集合的最长路一定只有两类情况:一类是其中一个集合的最长路,一共有 2 种;一类是由两个集合的最长路的端点互相连接而成,一共有 2×2=4 种。需要用到最近公共祖先的算法预处理求两点在树上的距离,离线处理即可。每次合并并查集之后用当前点的权值乘以最长路的总长度来更新最优结果即可。即使这个点不在当前合并后的集合的最长路上也是没有问题的,因为如果这样的话,必然已经在之前得到了对应的结果,这次合并不会对最终结果产生影响。
#include<cstdio>
#include<cstdlib>
#include<algorithm>
#include<functional>
#include<cstring>
#define cl(x) memset(x,0,sizeof(x))
using namespace std;
typedef long long ll;
typedef pair<int,int> abcd;
inline char nc(){
static char buf[100000],*p1=buf,*p2=buf;
if (p1==p2) { p2=(p1=buf)+fread(buf,1,100000,stdin); if (p1==p2) return EOF; }
return *p1++;
}
inline void read(int &x){
char c=nc(),b=1;
for (;!(c>='0' && c<='9');c=nc()) if (c=='-') b=-1;
for (x=0;c>='0' && c<='9';x=x*10+c-'0',c=nc()); x*=b;
}
const int N=100005;
const int K=21;
struct edge{
int u,v,w,next;
}G[N<<1];
int head[N],inum;
inline void add(int u,int v,int w,int p){
G[p].u=u; G[p].v=v; G[p].w=w; G[p].next=head[u]; head[u]=p;
}
int father[N][K],depth[N];
ll dis[N];
#define V G[p].v
inline void dfs(int u,int fa){
father[u][0]=fa; depth[u]=depth[fa]+1;
for (int k=1;k<K;k++) father[u][k]=father[father[u][k-1]][k-1];
for (int p=head[u];p;p=G[p].next)
if (V!=fa)
dis[V]=dis[u]+G[p].w,dfs(V,u);
}
inline int LCA(int u,int v){
if (depth[u]<depth[v]) swap(u,v);
for (int k=K-1;~k;k--)
if ((depth[u]-depth[v])>>k&1)
u=father[u][k];
if (u==v) return u;
for (int k=K-1;~k;k--)
if (father[u][k]!=father[v][k])
u=father[u][k],v=father[v][k];
return father[u][0];
}
inline ll Dis(int u,int v){
return dis[u]+dis[v]-2*dis[LCA(u,v)];
}
int fat[N];
int s[N],t[N],mind[N]; ll len[N];
inline void init(int n){
for (int i=1;i<=n;i++) fat[i]=s[i]=t[i]=i,len[i]=0;
}
inline int Fat(int u){
return u==fat[u]?u:fat[u]=Fat(fat[u]);
}
int n,rnk[N];
abcd w[N];
ll Ans;
int main(){
int T,iu,iv,iw,x,y;
freopen("t.in","r",stdin);
freopen("t.out","w",stdout);
read(T);
while (T--){
read(n); init(n);
for (int i=1;i<=n;i++) read(w[i].first),w[i].second=i,mind[i]=w[i].first;
for (int i=1;i<n;i++)
read(iu),read(iv),read(iw),add(iu,iv,iw,++inum),add(iv,iu,iw,++inum);
dfs(1,0);
sort(w+1,w+n+1,greater<abcd>());
for (int i=1;i<=n;i++) rnk[w[i].second]=i;
Ans=0;
for (int i=1;i<=n;i++){
x=w[i].second;
for (int p=head[x];p;p=G[p].next){
y=Fat(V);
if (rnk[y]>=rnk[x]) continue;
ll maxv=-1,l,r,tmp;
if (len[x]>maxv) maxv=len[x],l=s[x],r=t[x];
if (len[y]>maxv) maxv=len[y],l=s[y],r=t[y];
if ((tmp=Dis(s[x],s[y]))>maxv) maxv=tmp,l=s[x],r=s[y];
if ((tmp=Dis(t[x],t[y]))>maxv) maxv=tmp,l=t[x],r=t[y];
if ((tmp=Dis(s[x],t[y]))>maxv) maxv=tmp,l=s[x],r=t[y];
if ((tmp=Dis(t[x],s[y]))>maxv) maxv=tmp,l=t[x],r=s[y];
fat[y]=x; s[x]=l; t[x]=r; len[x]=maxv;
mind[x]=min(mind[x],mind[y]);
}
Ans=max(Ans,len[x]*w[i].first);
}
printf("%lld\n",Ans);
cl(head); inum=0;
}
return 0;
}