题意:给定一棵带边权的树,其中有 m 个黑点,以及一个整数 K,求经过的黑点数不超过 K 的路径中,边权和的最大值。
分析:树上的点分治(按重心分解)。这题被卡了一整天,原因是宏定义/命名引发了一个难以察觉的编译错误。
#include<iostream>
#include<string>
#include<algorithm>
#include<cstdlib>
#include<cstdio>
#include<set>
#include<map>
#include<vector>
#include<cstring>
#include<stack>
#include<queue>
// INF, eps and MOD are not referenced anywhere in the visible code.
#define INF 0x3f3f3f3f
#define eps 1e-9
#define MOD 1000000007
// Capacity for per-vertex tables; the edge table uses 2*MAXN because main stores
// every undirected edge twice (once per direction).
#define MAXN 200005
using namespace std;
// n: vertex count; k: max number of black vertices allowed on a path; m: number of black vertices.
// x, y, val: scratch variables used by main while reading input.
// ans: best path weight found so far (main resets it to 0).
// tot_edge: number of directed edges stored so far; root: first centroid chosen by main.
// a[u]: index of the first edge in u's adjacency list (0 terminates the list).
// e[i] = {target vertex, weight, next edge index}; edge indices start at 1.
// temp[1..son]: edge indices of the current centroid's live neighbours (temp[0] is never written).
// f[j]: best weight of a path from the centroid into already-processed neighbours with
//       at most j black vertices (prefix maximum built in deal).
// dep[v]: value returned by dfs_deep for neighbour v of the current centroid.
// b[j]: best weight of a path from the centroid into the current neighbour's subtree
//       with exactly j black vertices (filled by dfs2).
// Size[v]: number of live vertices in the subtree hanging off neighbour v (from dfs2).
// size[u]: subtree sizes used while searching for a centroid (dfs).
// NOTE(review): with C++17 or later, `using namespace std` also brings std::size into
// scope, so an unqualified `size` at global scope can be ambiguous; `::size` is not.
int n,k,m,x,y,val,ans,tot_edge,root,a[MAXN],e[2*MAXN][3],temp[MAXN],f[MAXN],dep[MAXN],b[MAXN],Size[MAXN],size[MAXN];
// crowd[u]: u is a black vertex; jud[u]: u has already been processed as a centroid
// and is treated as removed by every traversal.
bool crowd[MAXN],jud[MAXN];
// Pushes the directed edge x -> y (weight val) onto the front of x's adjacency list.
// Edge slots are 1-based so that an index of 0 can mark the end of a list.
void Insert(int x,int y,int val)
{
    const int id = ++tot_edge;
    e[id][0] = y;     // target vertex
    e[id][1] = val;   // edge weight
    e[id][2] = a[x];  // old head becomes this edge's successor
    a[x] = id;
}
// Sort predicate over edge indices: orders edges by dep[] of their target vertex,
// smallest first.
bool camp(int lhs,int rhs)
{
    const int lhs_to = e[lhs][0];
    const int rhs_to = e[rhs][0];
    return dep[lhs_to] < dep[rhs_to];
}
int dfs(int u,int fa,int tot)
{
bool flag = true;
size[u] = 1;
for(int i = a[u];i;i = e[i][2])
{
int v = e[i][0];
if(v != fa && !jud[v])
{
int tmp = dfs(v,u,tot);
if(tmp) return tmp;
size[u] += size[v];
if(size[v] > tot/2) flag = false;
}
}
if(flag && tot - size[u] <= tot/2) return u;
else return 0;
}
// Walks the subtree rooted at u (entered from fa, skipping vertices marked in jud).
// `deep` is the black-vertex count of the path from the current centroid down to u,
// as supplied by the caller, and `fx` is that path's total weight. Whenever the path
// is admissible (deep <= k), its weight is offered both to the global answer and to
// b[deep]. Returns the number of vertices visited.
int dfs2(int u,int fa,int deep,int fx)
{
    if(!(deep > k))
    {
        ans = max(ans, fx);
        b[deep] = max(b[deep], fx);
    }
    int visited = 1;
    for(int edge = a[u]; edge != 0; edge = e[edge][2])
    {
        const int child = e[edge][0];
        if(child == fa || jud[child]) continue;
        visited += dfs2(child, u, deep + crowd[child], fx + e[edge][1]);
    }
    return visited;
}
// Returns the largest black-vertex count on any downward path starting at u (moving
// away from fa and skipping vertices marked in jud). `fx` is the count already
// accumulated on the path up to and including u. The result saturates at k: as soon
// as the count reaches or exceeds k, the walk stops and k is returned.
// NOTE(review): assumes k >= 0, as a path limit should be.
int dfs_deep(int u,int fa,int fx)
{
    // `>=` rather than `==`: the starting count may already exceed k (e.g. k == 0 and
    // u is black). Stopping only on equality would walk the whole subtree and return
    // a value above k. deal() sizes and indexes its per-count tables with this result
    // and relies on it never exceeding k.
    if(fx >= k) return k;
    int best = fx;
    for(int i = a[u];i;i = e[i][2])
    {
        int v = e[i][0];
        if(v != fa && !jud[v])
            best = max(best,dfs_deep(v,u,fx + crowd[v]));
    }
    return best;
}
void deal(int u)
{
int size = 0,son = 0;
for(int i = a[u];i;i = e[i][2])
{
int v = e[i][0];
if(!jud[v])
{
dep[v] = dfs_deep(v,u,crowd[u] + crowd[v]);
size = max(size,dep[v]);
temp[++son] = i;
}
}
sort(temp+1,temp+1+son,camp);
memset(f,0,sizeof(int)*(size + 2));
for(int i = 1;i <= son;i++)
{
int v = e[temp[i]][0],val = e[temp[i]][1];
memset(b,0,sizeof(int)*(dep[v]+2));
Size[v] = dfs2(v,u,crowd[u] + crowd[v],val);
for(int j = crowd[u] + crowd[v];j <= dep[v];j++)
ans = max(ans,f[min(k - j + crowd[u],dep[e[temp[i-1]][0]])] + b[j]);
for(int j = 0;j <= dep[v];j++)
{
if(j) f[j] = max(f[j],f[j-1]);
f[j] = max(f[j],b[j]);
}
}
}
// Centroid-decomposition driver. Accounts for all paths through u, marks u as
// removed, then recurses into the centroid of each remaining neighbouring component.
// The component sizes come from Size[], which deal(u) fills in.
void calc(int u)
{
    deal(u);
    jud[u] = true;
    for(int i = a[u]; i != 0; i = e[i][2])
    {
        const int v = e[i][0];
        if(jud[v]) continue;
        const int center = dfs(v,u,Size[v]);
        calc(center);
    }
}
// Reads n, k, m, then the m black vertices, then the n-1 weighted edges. Runs the
// centroid decomposition from a centroid of the whole tree and prints the best path
// weight.
int main()
{
    ans = 0;
    scanf("%d%d%d",&n,&k,&m);
    for(int c = 0; c < m; ++c)
    {
        scanf("%d",&x);
        crowd[x] = true;
    }
    for(int c = 1; c < n; ++c)
    {
        scanf("%d%d%d",&x,&y,&val);
        // Undirected edge: store both directions.
        Insert(x,y,val);
        Insert(y,x,val);
    }
    root = dfs(1,0,n);
    calc(root);
    printf("%d\n",ans);
}