Free tour II
题意:有一棵N个节点的树,每个节点要么被染成黑色,要么被染成白色,要求找出一条权值和最长的路径,使得路径上的黑色点的个数不超过k;
在看《分治算法在树的路径问题上的应用》这篇论文时看到的题,就拿来做了,搞了一天,才做出来,论文中此题解法只是大致谈了一下思想;看了半天,只看了个一知半解,感觉无处下手, 尝试着写一篇博客,可能也将不太明白,就算是让自己加深下理解了;
此题是求符合题意的树上路径,那么,最好的方法应该是点分治;想到这里,题目才起步;
点分治模板就不再多解释了;下面只看对于一棵平衡树怎么做,因为无论这棵树是什么样子的,我们都可以通过点分治将其尽可能的转换成平衡树结构,这样搜完是log的复杂度;
要找一条路径,一定会有两个端点,而这条路径一定经过两端点的最近公共祖先;
计算根节点的每个子节点到叶子的路径中包含的黑点的最大值,用num[v]数组记录;
将每条路经对应的num值排序,再有两个数组:F[i], G[i];F[i]表示路径中最多有i个黑点时的最长路径,G[i]表示路径中有i个黑点是的路径,前者是黑点数不大于i,后者是黑点数准确的等于i;那么前一条路的F[i]+后一条路的G[j]就是路径总长;并且要保证i+j<=k;
#include <bits/stdc++.h>
#define INF 0x3f3f3f3f
using namespace std;
const int maxn=2e5+10;
int N, M, K;
struct node{
int v, nxt, w;
}edge[maxn<<1];
int cnt, head[maxn];
void add(int u, int v, int w){
edge[cnt]=node{v, head[u], w};
head[u]=cnt++;
}
int r, min1, _size[maxn], allnode, vis[maxn];
void getroot(int u, int fa){
int max1=-1;
_size[u]=1;
for(int i=head[u]; i!=-1; i=edge[i].nxt){
int v=edge[i].v;
if(v==fa || vis[v]) continue;
getroot(v, u);
_size[u]+=_size[v];
max1=max(max1, _size[v]);
}
max1=max(max1, allnode-_size[u]);
if(max1<min1) min1=max1, r=u;
}
int num[maxn], G[maxn], F[maxn], col[maxn];
void getnum(int u, int fa){
num[u]=col[u];
for(int i=head[u]; i!=-1; i=edge[i].nxt){
int v=edge[i].v;
if(v==fa || vis[v]) continue;
getnum(v, u);
num[u]=max(num[u], num[v]+col[u]);
}
}
void getG(int u, int fa, int x, int val){
G[x]=max(G[x], val);
for(int i=head[u]; i!=-1; i=edge[i].nxt){
int v=edge[i].v, w=edge[i].w;
if(v==fa || vis[v]) continue;
getG(v, u, x+col[v], w+val);
}
}
struct Node{
int v, num, w;
}T[maxn];
bool cmp(Node a, Node b){
return a.num<b.num;
}
int ans;
void solve(int u){
vis[u]=1;
for(int i=head[u]; i!=-1; i=edge[i].nxt){
int v=edge[i].v;
if(vis[v]) continue;
min1=INF;
allnode=_size[v];
getroot(v, -1);
solve(r);
}
int tot=0;
for(int i=head[u]; i!=-1; i=edge[i].nxt){
int v=edge[i].v, w=edge[i].w;
if(vis[v]) continue;
getnum(v, -1);
T[tot++]=Node{v, num[v], w};
}
sort(T, T+tot, cmp);
int limit=K-col[u];
for(int i=0; i<=T[tot-1].num; i++) F[i]=-INF;
for(int i=0; i<tot; i++){
for(int j=0; j<=T[i].num; j++) G[j]=-INF;
int v=T[i].v, w=T[i].w;
getG(v, u, col[v], w);
if(i){
for(int j=0; j<=T[i].num&&j<=limit; j++){
int temp=min(T[i-1].num, limit-j);
if(F[temp]==-INF) continue;
ans=max(ans, F[temp]+G[j]);
}
}
for(int j=0; j<=T[i].num&&j<=limit; j++){
F[j]=max(F[j], G[j]);
if(j) F[j]=max(F[j], F[j-1]);
ans=max(ans, F[j]);
}
}
vis[u]=0;
}
void init(){
cnt=0;
memset(head, -1, sizeof(head));
memset(vis, 0, sizeof(vis));
memset(col, 0, sizeof(col));
}
int main(){
while(~scanf("%d%d%d", &N, &K, &M)){
init();
for(int i=0; i<M; i++){
int x;
scanf("%d", &x);
col[x]=1;
}
for(int i=1; i<N; i++){
int u, v, w;
scanf("%d%d%d", &u, &v, &w);
add(u, v, w);
add(v, u, w);
}
ans=0;
min1=INF;
allnode=N;
getroot(1, -1);
solve(r);
printf("%d\n", ans);
}
return 0;
}