题意:树上随便选两个路径,有点权,只被一个路径覆盖的点权有效,求最大值
两种情况 一种是只有一个交点,一种无交点,无交点我们可以通过断一条树边求分别两边的最大值,记为f与g f为断u与fa后 u子树的最大值,这个比较简单,就是套路的最大次大, 然后是g
我是打死想不到O(n)的转移,没脑子 ,考虑g什么时候最简单,可以初始化 答案是 为1连着的节点时g最简单 此时情况就是断掉这个子树,将其他的最长次长链求一个和,然后取其他f[v]最大值即可,这个东西用multiset维护一下
接着是转移考虑父亲到子节点的变化 我们会多出来除了一个节点外的其他子嗣,需要将这些子嗣最长的直的链跟往上最大的链取一个最大值,这个往上最大的链子其实不一定就只到一,它可能穿过其他儿子 ,这个东西又可以用multiset维护 然后g[v]大概就完了
最后对于空心的四条链子 这个东西还是要用multiset 搞一下
最后给个图吧 对于g的转移没图还是说不清的
如图,它会多出来除了v以下的x其他的子嗣
#include <bits/stdc++.h>
using namespace std;
#define int long long
const int N=2e5+10;
int head[N],a[N],f[N],g[N],dp[N],dp2[N];
int n;int ans=0;
int k;vector<int> s[N];vector<int> s1[N];
struct p{
int next,v;
}st[N*2];
void add(int a,int b){
st[++k].next=head[a];
st[k].v=b;
head[a]=k;
}
int longest[N];
int dfs0(int x,int fa){
int mx=0;
for(int i=head[x];i;i=st[i].next){
int v=st[i].v;
if(v==fa)continue;
int tem=dfs0(v,x)+a[v];
mx=max(tem,mx);
}
return longest[x]=mx;
}
int dfs(int x,int fa,int val){
multiset<int> s3;multiset<int>::iterator it;
if(x==1){
for(int i=head[x];i;i=st[i].next){
int v=st[i].v;
if(v==fa)continue;
s3.insert(a[v]+longest[v]);
}}
int mx=0;s[x].push_back(val);
for(int i=head[x];i;i=st[i].next){
int v=st[i].v;
if(v==fa)continue;
if(x==1){
s3.erase(s3.find(a[v]+longest[v]));
it=s3.end();
if(s3.size()>=1)val=*(--it);
s3.insert(a[v]+longest[v]);
}
int t2=a[v]+dfs(v,x,val+a[x]);
mx=max(mx,t2);
s[x].push_back(t2);
}
sort(s[x].rbegin(),s[x].rend());int tot=0;
int m=s[x].size();
if(m>=4)
for(int i=0;i<4;i++)tot+=s[x][i];
if(m==3)
for(int i=0;i<3;i++)tot+=s[x][i];
if(m==2)
for(int i=0;i<2;i++)tot+=s[x][i];
if(m<=3)tot+=a[x];
ans=max(ans,tot);
return mx;
}
int dfs2(int x,int fa){
int mx=0;f[x]=a[x];
for(int i=head[x];i;i=st[i].next){
int v=st[i].v;
if(v==fa)continue;int tem=longest[v]+a[v];
dfs2(v,x);
f[x]=max(f[x],f[v]);
f[x]=max(f[x],mx+tem+a[x]);
mx=max(mx,tem);
}
return mx;
}
int t[N];int top=0;
int dfs3(int x,int fa,int val){
//printf("x=%d %d val=%d\n",x,fa,val);
multiset<int> s3;multiset<int>::iterator it;
for(int i=head[x];i;i=st[i].next){
int v=st[i].v;
if(v==fa)continue;
s3.insert(f[v]);
}
for(int i=head[x];i;i=st[i].next){
int v=st[i].v;
if(v==fa)continue;
g[v]=max(g[x],g[v]);
s3.erase(s3.find(f[v]));
if(s3.size()>=1){
it=s3.end();it--;
g[v]=max(g[v],*it);}
g[v]=max(g[v],a[x]);
s3.insert(f[v]);
}
s3.clear();
for(int i=head[x];i;i=st[i].next){
int v=st[i].v;
if(v==fa)continue;
s3.insert(longest[v]+a[v]);
}s3.insert(val);
for(int i=head[x];i;i=st[i].next){
int v=st[i].v;
if(v==fa)continue;
s3.erase(s3.find(longest[v]+a[v]));
int all=0;int k3=2;it=s3.end();it--;
for(;;it--){
k3--;all+=*it;
if(k3==0||it==s3.begin())break;
}
// if(x==2)printf("%d %d\n",all,a[x]);
g[v]=max(g[v],all+a[x]);
it=s3.end();
int t2=max(val,*(--it));
s3.insert(longest[v]+a[v]);
dfs3(v,x,t2+a[x]);
}
}
signed main()
{
// freopen("in.in", "r", stdin); //读入数据生成器造出来的数据
//freopen("baoli.txt", "w", stdout); //输出答案
cin>>n;
for(int i=1;i<=n;i++)scanf("%lld",&a[i]);
for(int i=1;i<n;i++){
int aa,b;
scanf("%lld%lld",&aa,&b);
add(aa,b);
add(b,aa);
}//add(1,n+1);add(n+1,1);
dfs0(1,0);
dfs(1,0,0);
//part 1 end;
dfs2(1,0);
//part 2 end;
g[1]=0;a[n+1]=0;
//part 3 end
dfs3(1,0,0);
if(n==1){
printf("0\n");
return 0;
}
for(int i=2;i<=n;i++)ans=max(ans,f[i]+g[i]);
cout<<ans<<endl;
}