vjudge题面传送门:https://cn.vjudge.net/problem/SPOJ-FTOUR2
题目大意:给出一棵n个节点的树,将其黑白染色。求经过不超过k个黑点的路径权值和的最大值。 n < = 200000 n<=200000 n<=200000。
题目分析:这题大概有三种做法。第一种是用
f
[
n
o
d
e
]
[
n
u
m
]
f[node][num]
f[node][num]表示节点node向下走,经过刚好num个黑点的最大权值,转移的时候显然有:
f
[
n
o
d
e
]
[
n
u
m
]
=
max
(
f
[
s
o
n
]
[
n
u
m
]
)
+
1
f[node][num]=\max(f[son][num])+1
f[node][num]=max(f[son][num])+1(node为白色)
f
[
n
o
d
e
]
[
n
u
m
]
=
max
(
f
[
s
o
n
]
[
n
u
m
−
1
]
)
+
1
f[node][num]=\max(f[son][num-1])+1
f[node][num]=max(f[son][num−1])+1(node为黑色)
将小的平衡树转移到大的,转移前先统计答案。计算贡献则在平衡树上维护区间最大值,区间加法和右移标记即可。时间为
O
(
n
log
2
(
n
)
)
O(n\log^2(n))
O(nlog2(n))。由于区间加法和右移标记都是个全局标记,因此可以记下来,然后将平衡树改成线段树。
第二种方法是点分治。先将每棵子树按链上黑色节点数的最大值从小到大排序,然后分别进行DFS。DFS的时候记录分治中心到该点的长度dis,黑色节点数num,用dis更新DP数组g[num]。记数组f[]为前几棵子树的DP数组,则用g更新f,并在途中求答案即可。更新完之后扫一遍f数组以保持其单调性。由于已经将子树排序,故f数组的长度不超过g数组,时间和g数组长度成正比,于是该次递归总时间和当前连通块大小成正比,时间复杂度为 O ( n log ( n ) ) O(n\log(n)) O(nlog(n))。如果不将子树排序,每次扫f数组的代价就会很高,无法保证时间,可能还要用一些数据结构来维护前缀最大值。另外,排序不会影响时间复杂度,因为我们只对分治中心的儿子排了序而不是对整个连通块,势能分析一下就知道排序的总时间也是 O ( n log ( n ) ) O(n\log(n)) O(nlog(n))的。
第三种方法是我一开始看到这题时想到的,同样是点分治,网上好像还没有人写。点分治的关键在于递归的时候要保证产生贡献的两个点不能来自同一棵子树。既然如此我们可以DFS连通块一遍,并记录该点来自哪棵子树,然后记f1[num]和f2[num]表示经过num个黑点的路径长度的最大和次大值,并记录下它们来自哪棵子树,强制要求这两棵子树不同。更新的时候分类讨论一下即可,时间依旧是 O ( n log ( n ) ) O(n\log(n)) O(nlog(n))。实际运行570ms,比上面那种方法快100ms。
我一开始先写了第二种方法,然后再写第三种方法。写第二种方法的时候debug了很久,大概就是以下这样:
第一次提交,自信满满,结果WA。
回来查错,重读题面,发现一开始ans要初始化为-oo。再交,WA。
再查错,发现Calc函数中num>k要返回k。再交,WA。
再查错,发现没有初始化g[0]~g[ col[root]+col[ Son[j] ]-1 ],还发现两个变量重名了。再交,WA。
再查错,改了句val=min(val,k)。然后上网看了一下别人怎么错的,发现题面有误,一开始ans要初始化为0。再交,WA。
最后肉眼查不出错,直接对拍,发现链的一端为根的时候会有问题。
再拍,发现col[root]+col[ Son[j] ]>k的时候,会令nf变为-1。
再拍,发现要单独算g数组的贡献。
再拍,终于发现最主要的问题。其实根本原因是我没有将对应下标与nf取min。
终于AC。
原计划15min调试完,结果花了1.5h……QAQ
总结:考场上写点分,一定要对拍。
注意一下写第三种方法的时候,f数组的大小是连通块大小和k取min。
CODE(第二种方法):
#include<iostream>
#include<string>
#include<cstring>
#include<cmath>
#include<cstdio>
#include<cstdlib>
#include<stdio.h>
#include<algorithm>
using namespace std;
const int maxn=200100;
const int oo=2100000000;
struct edge
{
int obj,len;
edge *Next;
} e[maxn<<1];
edge *head[maxn];
int cur=-1;
int tree[maxn];
int Size[maxn];
int max_Size[maxn];
int cnt;
int f[maxn];
int g[maxn];
int nf,ng;
bool vis[maxn];
int max_num[maxn];
int dis[maxn];
int Son[maxn];
int col[maxn];
int n,k,m;
int ans=0;
void Add(int x,int y,int z)
{
cur++;
e[cur].obj=y;
e[cur].len=z;
e[cur].Next=head[x];
head[x]=e+cur;
}
void Dfs(int node,int fa)
{
tree[++cnt]=node;
Size[node]=1;
max_Size[node]=0;
for (edge *p=head[node]; p; p=p->Next)
{
int son=p->obj;
if ( son==fa || vis[son] ) continue;
Dfs(son,node);
Size[node]+=Size[son];
max_Size[node]=max(max_Size[node],Size[son]);
}
}
int Calc(int node,int fa,int num)
{
int val=num;
for (edge *p=head[node]; p; p=p->Next)
{
int son=p->obj;
if ( son==fa || vis[son] ) continue;
val=max(val,Calc(son,node,num+col[son]));
}
val=min(val,k);
return val;
}
bool Comp(int x,int y)
{
return max_num[x]<max_num[y];
}
void Work(int node,int fa,int num,int Dis)
{
if (num>k) return;
if (num<=ng) g[num]=max(g[num],Dis);
else ng=num,g[num]=Dis;
for (edge *p=head[node]; p; p=p->Next)
{
int son=p->obj;
if ( son==fa || vis[son] ) continue;
Work(son,node,num+col[son],Dis+p->len);
}
}
void Solve(int node)
{
cnt=0;
Dfs(node,node);
if (cnt==1) return;
int root=tree[1];
for (int i=1; i<=cnt; i++)
{
int now=tree[i];
max_Size[now]=max(max_Size[now],Size[node]-Size[now]);
if (max_Size[now]<max_Size[root]) root=now;
}
cnt=0;
for (edge *p=head[root]; p; p=p->Next)
{
int son=p->obj;
if (vis[son]) continue;
Son[++cnt]=son;
max_num[son]=Calc(son,root,col[son]);
dis[son]=p->len;
}
sort(Son+1,Son+cnt+1,Comp);
nf=col[root];
for (int i=0; i<col[root]; i++) f[i]=-oo;
f[nf]=0;
for (int j=1; j<=cnt; j++)
{
if (col[root]+col[ Son[j] ]>k) continue;
ng=-1;
for (int i=0; i<col[root]+col[ Son[j] ]; i++) g[i]=-oo;
Work(Son[j],root,col[root]+col[ Son[j] ],dis[ Son[j] ]);
for (int i=0; i<=ng; i++) ans=max(ans,f[ min(nf,k-i+col[root]) ]+g[i]);
for (int i=0; i<=nf; i++) f[i]=max(f[i],g[i]);
for (int i=nf+1; i<=ng; i++) f[i]=g[i];
nf=ng;
for (int i=1; i<=nf; i++) f[i]=max(f[i],f[i-1]);
}
vis[root]=true;
for (edge *p=head[root]; p; p=p->Next)
{
int son=p->obj;
if (!vis[son]) Solve(son);
}
}
int main()
{
freopen("c.in","r",stdin);
freopen("c.out","w",stdout);
scanf("%d%d%d",&n,&k,&m);
for (int i=1; i<=m; i++)
{
int x;
scanf("%d",&x);
col[x]=1;
}
for (int i=1; i<=n; i++) head[i]=NULL;
for (int i=1; i<n; i++)
{
int x,y,z;
scanf("%d%d%d",&x,&y,&z);
Add(x,y,z);
Add(y,x,z);
}
Solve(1);
printf("%d\n",ans);
return 0;
}
CODE(第三种方法):
#include<iostream>
#include<string>
#include<cstring>
#include<cmath>
#include<cstdio>
#include<cstdlib>
#include<stdio.h>
#include<algorithm>
using namespace std;
const int maxn=200100;
const int oo=2100000000;
struct edge
{
int obj,len;
edge *Next;
} e[maxn<<1];
edge *head[maxn];
int cur=-1;
int tree[maxn];
int Size[maxn];
int max_Size[maxn];
int cnt;
int f1[maxn];
int f2[maxn];
int id1[maxn];
int id2[maxn];
bool vis[maxn];
int col[maxn];
int n,k,m;
long long ans=0;
void Add(int x,int y,int z)
{
cur++;
e[cur].obj=y;
e[cur].len=z;
e[cur].Next=head[x];
head[x]=e+cur;
}
void Dfs(int node,int fa)
{
tree[++cnt]=node;
Size[node]=1;
max_Size[node]=0;
for (edge *p=head[node]; p; p=p->Next)
{
int son=p->obj;
if ( son==fa || vis[son] ) continue;
Dfs(son,node);
Size[node]+=Size[son];
max_Size[node]=max(max_Size[node],Size[son]);
}
}
void Update(int dis,int id,int num)
{
if (dis>f1[num])
if (id!=id1[num])
{
f2[num]=f1[num];
id2[num]=id1[num];
f1[num]=dis;
id1[num]=id;
}
else f1[num]=dis;
else
if ( dis>f2[num] && id!=id1[num] )
{
f2[num]=dis;
id2[num]=id;
}
}
void Work(int node,int fa,int num,int dis,int id)
{
if (num>k) return;
Update(dis,id,num);
for (edge *p=head[node]; p; p=p->Next)
{
int son=p->obj;
if ( vis[son] || son==fa ) continue;
Work(son,node,num+col[son],dis+p->len,id);
}
}
void Solve(int node)
{
cnt=0;
Dfs(node,node);
if (cnt==1) return;
int root=tree[1];
for (int i=1; i<=cnt; i++)
{
int now=tree[i];
max_Size[now]=max(max_Size[now],Size[node]-Size[now]);
if (max_Size[now]<max_Size[root]) root=now;
}
cnt=min(cnt,k);
for (int i=0; i<=cnt; i++) f1[i]=f2[i]=-oo,id1[i]=id2[i]=0;
f1[ col[root] ]=0;
for (edge *p=head[root]; p; p=p->Next)
{
int son=p->obj;
if (vis[son]) continue;
Work(son,root,col[root]+col[son],p->len,son);
}
for (int i=0; i<cnt; i++) Update(f1[i],id1[i],i+1),Update(f2[i],id2[i],i+1);
for (int i=0; i<=cnt; i++)
{
int j=k-i+col[root];
j=min(j,cnt);
if (id1[i]!=id1[j]) ans=max(ans,(long long)f1[i]+f1[j]);
else ans=max(ans,(long long)f1[i]+f2[j]);
}
vis[root]=true;
for (edge *p=head[root]; p; p=p->Next)
{
int son=p->obj;
if (!vis[son]) Solve(son);
}
}
int main()
{
freopen("c.in","r",stdin);
freopen("c.out","w",stdout);
scanf("%d%d%d",&n,&k,&m);
for (int i=1; i<=m; i++)
{
int x;
scanf("%d",&x);
col[x]=1;
}
for (int i=1; i<=n; i++) head[i]=NULL;
for (int i=1; i<n; i++)
{
int x,y,z;
scanf("%d%d%d",&x,&y,&z);
Add(x,y,z);
Add(y,x,z);
}
Solve(1);
printf("%d\n",ans);
return 0;
}