题意:
有N个顶点的树,节点间有权值, 节点分为黑点和白点。 找一条最长路径使得 路径上黑点数量不超过K个。
思路:
设当前节点x到根的路径上黑色节点数量为deep[x],路径长度为dis[x]
将子节点按照最大deep倒序处理,利用启发式合并使得合并复杂度降为nlogn
#include<bits/stdc++.h>
template <class T1, class T2>inline void gmax(T1 &a, T2 b) { if (b>a)a = b; }
template <class T1, class T2>inline void gmin(T1 &a, T2 b) { if (b<a)a = b; }
using namespace std;
const int N=2e5+100;
int K,head[N],tot;
struct Edge{
int to,next,w;
}e[N*2];
int size[N],f[N],Count,root,Color[N];
bool Del[N];
void addedge(int from,int to,int w){
e[tot]=(Edge){to,head[from],w};
head[from]=tot++;
}
void getroot(int u,int fa){
size[u]=1,f[u]=0;
for(int i=head[u];i!=-1;i=e[i].next){
int v=e[i].to;
if(v!=fa&&!Del[v]){
getroot(v,u);
size[u]+=size[v],gmax(f[u],size[v]);
}
}
gmax(f[u],Count-size[u]);
if(f[root]>f[u]) root=u;
}
int ans;
struct node{
int first,second;
}st[N];
int deep[N],dis[N],deep_mx,tmp[N],mx[N];
void getdeep(int u,int fa,int dep,int d){
deep[u]=dep+Color[u],dis[u]=d;
gmax(deep_mx,deep[u]);
size[u]=1;
for(int i=head[u];i!=-1;i=e[i].next){
int v=e[i].to;
if(!Del[v]&&v!=fa){
getdeep(v,u,deep[u],d+e[i].w);
size[u]+=size[v];
}
}
}
void getmx(int u,int pre){
gmax(tmp[deep[u]],dis[u]);
for(int i=head[u];i!=-1;i=e[i].next){
int v=e[i].to;
if(!Del[v]&&v!=pre) getmx(v,u);
}
}
bool cmp(const node &u,const node &v){
if(u.first!=v.first) return u.first<v.first;
return u.second<v.second;
}
void work(int u){
Del[u]=true;
int cnt=0;
K-=Color[u];
for(int i=head[u];i!=-1;i=e[i].next){
int v=e[i].to;
if(!Del[v]){
deep_mx=0;
getdeep(v,u,0,e[i].w);
st[cnt++]=(node){deep_mx,v};
}
}
sort(st,st+cnt,cmp);
for(int i=0;i<cnt;i++){
getmx(st[i].second,u);
int now=0;
if(i!=0)
for(int j=st[i].first;j>=0;j--){
while(now+j<K&&now<st[i-1].first)
++now,gmax(mx[now],mx[now-1]);
if(now+j<=K) gmax(ans,mx[now]+tmp[j]);
}
if(i!=cnt-1)
for(int j=st[i].first;j>=0;j--)
gmax(mx[j],tmp[j]),tmp[j]=0;
else
for(int j=st[i].first;j>=0;j--){
gmax(mx[j],tmp[j]);
if(j<=K) gmax(ans,mx[j]);
tmp[j]=mx[j]=0;
}
}
K+=Color[u];
for(int i=head[u];i!=-1;i=e[i].next){
int v=e[i].to;
if(!Del[v]){
Count=f[0]=size[v];
getroot(v,root=0);
work(root);
}
}
}
int read(){
int x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
int main(){
int n,m;
n=read(),K=read(),m=read();
tot=0;
memset(head,-1,sizeof(head));
int u,v,w;
for(int i=1;i<=m;i++)
u=read(),Color[u]=1;
for(int i=1;i<n;i++){
u=read(),v=read(),w=read();
addedge(u,v,w),addedge(v,u,w);
}
Count=f[0]=n;
getroot(1,root=0);
ans=0;
work(root);
printf("%d\n",ans);
return 0;
}