看了别人的题解。。
题意:
有N个顶点的树,节点间有权值, 节点分为黑点和白点。 找一条最长路径使得 路径上黑点数量不超过K个
F[ i , j ] 表示它的第 I 个子树中经过的不超过 J 个黑点的路径中,最长的一条的长度是多少,这样可以保证 F[ I , J ] 的递增性。要求出F[ I , J ] ,我们只要对所有子树进行一次DFS即可,复杂度是O( N )的。不过如果要保存这样的状态对于某些数据可能有些困难,因为数据范围太大了,我们可以通过以下方法来优化:我们要求的 F[ i , j ] 是把所有做过的子树全部保存起来,不过我们要用的只是对于每个 J 的最大值!所以我们可以根据这一点进行一个对空间的优化。把 F[ i , j ] 变成一维的 F[ i ] 表示当前已经计算过的子树中,经过黑点数不超过 i 个的路径中最长的长度是多少,对于当前所计算的子树,用 G[ i ] 表示当前子树中经过黑色点数严格为 i 个的路径中最长的路径的长度是多少。可以比较G[ i ] 和 F[ i ] 的大小来更新 F[ i ] ,依然可以保证 F[ i ] 的递增。
不过每次更新一次 F[ i ] 的时候的复杂度是 F[ i ] 和 G[ i ] 深度的最大值,对于某些数据可能达到O( n^2 ) ,所以在进行更新之前对子树进行一次排序,关键字是该子树中路径经过最多数量的黑色节点的数量。
我就是没有看懂怎么从n^2变成logn的,所以特别解释一下:
因为最多只有n个黑色的节点,所以枚举黑色节点的个数进行更新即可,即对于下标相同的F数组,不用区分,直接合并取最大值。
我都把别人的题解抄了一遍了,始终RE,不知道哪错了
#include<iostream>
#include<cstdio>
#include<string>
#include<cstring>
#include<vector>
#include<cmath>
#include<queue>
#include<stack>
#include<map>
#include<set>
#include<algorithm>
using namespace std;
const int maxn=450010;
const int INF=1e8;
int g[maxn],mg[maxn];
int N,K,M;
int black[maxn];
struct Node
{
int v,next,w;
}edge[1100000];
int head[maxn],tot;
bool vis[maxn];
int sz[maxn];
int dep[maxn];
int num[1100000];
int maxv[maxn];
int Max,root;
int ans;
void init()
{
tot=ans=0;
memset(head,-1,sizeof(head));
memset(vis,0,sizeof(vis));
memset(black,0,sizeof(black));
}
void add_edge(int u,int v,int w)
{
edge[tot].v=v;
edge[tot].w=w;
edge[tot].next=head[u];
head[u]=tot++;
}
//处理子树的大小
void dfssz(int u,int f)
{
sz[u]=1;
maxv[u]=0;
for(int i=head[u];i!=-1;i=edge[i].next)
{
int v=edge[i].v;
if(v==f||vis[v])continue;
dfssz(v,u);
sz[u]+=sz[v];
if(sz[v]>maxv[u])maxv[u]=sz[v];
}
}
//找重心
void dfsroot(int r,int u,int f)
{
if(sz[r]-sz[u]>maxv[u])//sz[r]-sz[u]是u上面部分的树的尺寸,跟u的最大孩子比,找到最大孩子的最小差值节点
maxv[u]=sz[r]-sz[u];
if(maxv[u]<Max)Max=maxv[u],root=u;
for(int i=head[u];i!=-1;i=edge[i].next)
{
int v=edge[i].v;
if(v==f||vis[v])continue;
dfsroot(r,v,u);
}
}
void dfsdep(int u,int fa)
{
dep[u]=black[u];
int sum=0;
for(int i=head[u];i!=-1;i=edge[i].next)
{
int v=edge[i].v;
if(v==fa||vis[v])continue;
dfsdep(v,u);
sum=max(sum,dep[v]);
}
dep[u]+=sum;
}
bool cmp(int x,int y)
{
return dep[edge[x].v]<dep[edge[y].v];
}
void dfsg(int u,int fa,int d,int c)
{
g[c]=max(g[c],d);
for(int i=head[u];i!=-1;i=edge[i].next)
{
int v=edge[i].v;
if(v==fa||vis[v])continue;
dfsg(v,u,d+edge[i].w,c+black[v]);
}
}
void dfs(int u,int fa)
{
Max=INF;
dfssz(u,fa);
dfsroot(u,u,fa);
int rt=root,cnt=0;
vis[rt]=true;
for(int i=head[rt];i!=-1;i=edge[i].next)
{
int v=edge[i].v;
if(!vis[v])dfs(v,rt);
}
for(int i=head[rt];i!=-1;i=edge[i].next)
{
int v=edge[i].v;
if(!vis[v])
{
dfsdep(v,rt);
num[++cnt]=i;
}
}
sort(num+1,num+1+cnt,cmp);
for(int i=0;i<=dep[edge[num[cnt]].v];i++)mg[i]=-INF;
for(int i=1;i<=cnt;i++)
{
int v=edge[num[i]].v,d=dep[v];
int val=edge[num[i]].w;
for(int j=0;j<=d;j++)g[j]=-INF;
dfsg(v,rt,val,black[v]);
if(i!=1)
{
for(int j=0;j<=K-black[rt]&&j<=d;j++)
{
int tmp=min(dep[edge[num[i-1]].v],K-black[rt]-j);
if(mg[tmp]==-INF)break;
if(g[j]!=-INF)ans=max(ans,mg[tmp]+g[j]);
}
}
for(int j=0;j<=d;j++)
{
mg[j]=max(mg[j],g[j]);
if(j)mg[j]=max(mg[j-1],mg[j]);
if(j+black[rt]<=K)ans=max(ans,mg[j]);
}
}
vis[rt]=false;//注意这个
}
int main()
{
int u,x,v,w;
scanf("%d%d%d",&N,&K,&M);
init();
for(int i=1;i<=M;i++)
{
scanf("%d",&x);
black[x]=1;
}
for(int i=1;i<N;i++)
{
scanf("%d%d%d",&u,&v,&w);
add_edge(u,v,w);
add_edge(v,u,w);
}
ans=0;
dfs(1,0);
printf("%d\n",ans);
return 0;
}