这个思路还是非常巧妙的.
困难在于我们需要同时维护以 $x$ 为分治中心,延伸出颜色相同/不同的最大值.
不同的话直接将权和相加,相同的话还需要减掉重复部分,这就比较难办.
但是我们发现,当以 $x$ 为分治中心时,$x$ 每一个儿子为根的子树的延伸颜色都是相同的.
所以我们可以将每一个点的所有儿子按照延伸颜色排序,然后维护两颗线段树:相同与不同.
当我们在点分治时处理到下一个儿子,而下一个儿子与当前儿子颜色不同时,用线段树合并的方式将相同线段树合并到不同即可.
code:
#include <cstdio>
#include <string>
#include <vector>
#include <cstring>
#include <algorithm>
#define N 200006
#define ll long long
#define inf 1000000009
using namespace std;
void setIO(string s)
{
freopen((s+".in").c_str(),"r",stdin);
// freopen((s+".out").c_str(),"w",stdout);
}
int answer=-inf,mp=0;
namespace seg
{
int tot;
struct node { int ls,rs,maxx; }s[N*50];
void clr()
{
for(int i=1;i<=tot;++i) s[i].ls=s[i].rs=0,s[i].maxx=-inf;
tot=0;
}
int merge(int x,int y)
{
if(!x||!y) return x+y;
s[x].maxx=max(s[x].maxx,s[y].maxx);
s[x].ls=merge(s[x].ls,s[y].ls);
s[x].rs=merge(s[x].rs,s[y].rs);
return x;
}
void update(int &x,int l,int r,int p,int v)
{
if(!x) x=++tot,s[x].maxx=-inf;
s[x].maxx=max(s[x].maxx,v);
if(l==r) return;
int mid=(l+r)>>1;
if(p<=mid) update(s[x].ls,l,mid,p,v);
else update(s[x].rs,mid+1,r,p,v);
}
int query(int x,int l,int r,int L,int R)
{
if(!x||L>R) return -inf;
if(l>=L&&r<=R) return s[x].maxx;
int mid=(l+r)>>1,re=-inf;
if(L<=mid) re=max(re,query(s[x].ls,l,mid,L,R));
if(R>mid) re=max(re,query(s[x].rs,mid+1,r,L,R));
return re;
}
};
int n,m,_min,_max,root,sn,tot;
int val[N],size[N],mx[N],vis[N];
struct Edge
{
int to,w;
Edge(int to=0,int w=0):to(to),w(w){}
};
struct Dis
{
int d,v;
Dis(int d=0,int v=0):d(d),v(v){}
}A[N];
vector<Edge>F[N];
vector<Edge>G[N];
bool cmp(Edge a,Edge b) { return a.w<b.w; }
void getroot(int u,int ff)
{
size[u]=1,mx[u]=0;
for(int i=0;i<G[u].size();++i)
{
int v=G[u][i].to;
if(vis[v]||v==ff) continue;
getroot(v,u),size[u]+=size[v];
mx[u]=max(mx[u],size[v]);
}
mx[u]=max(mx[u],sn-size[u]);
if(mx[u]<mx[root]) root=u;
}
void dfs(int u,int ff,int d,int v,int pr)
{
A[++tot]=Dis(d,v);
for(int i=0;i<G[u].size();++i)
{
int y=G[u][i].to;
if(vis[y]||y==ff) continue;
dfs(y,u,d+1,v+val[G[u][i].w]*(G[u][i].w!=pr),G[u][i].w);
}
}
void solve(int u)
{
int rt_diff=0,rt_same=0;
F[u].clear();
for(int i=0;i<G[u].size();++i) if(!vis[G[u][i].to]) F[u].push_back(G[u][i]);
for(int i=0;i<F[u].size();++i)
{
int v=F[u][i].to;
tot=0;
dfs(v,u,1,val[F[u][i].w],F[u][i].w);
for(int j=1;j<=tot;++j)
{
if(A[j].d>_max) continue;
if(A[j].d>=_min&&A[j].d<=_max)
{
answer=max(answer,A[j].v),mp=max(mp,A[j].v);
}
answer=max(answer,A[j].v+seg::query(rt_diff,1,n,max(1,_min-A[j].d),_max-A[j].d));
answer=max(answer,A[j].v+seg::query(rt_same,1,n,max(1,_min-A[j].d),_max-A[j].d)-val[F[u][i].w]);
}
for(int j=1;j<=tot;++j) seg::update(rt_same,1,n,A[j].d,A[j].v);
if(i<F[u].size()-1&&F[u][i+1].w!=F[u][i].w)
{
rt_diff=seg::merge(rt_diff,rt_same);
rt_same=0;
}
}
seg::clr();
vis[u]=1;
for(int i=0;i<F[u].size();++i)
{
int v=F[u][i].to;
sn=size[v],root=0,getroot(v,u),solve(root);
}
}
int main()
{
// setIO("input");
int i,j;
scanf("%d%d%d%d",&n,&m,&_min,&_max);
for(i=1;i<=m;++i) scanf("%d",&val[i]);
for(i=1;i<n;++i)
{
int x,y,z;
scanf("%d%d%d",&x,&y,&z);
G[x].push_back(Edge(y,z));
G[y].push_back(Edge(x,z));
}
for(i=1;i<=n;++i) sort(G[i].begin(),G[i].end(),cmp);
sn=n,mx[root=0]=N,getroot(1,0),solve(root),printf("%d\n",answer);
return 0;
}