此题无链接
题目描述
前言
我TM拿肾肝!(大きい玉螺旋丸!)
由于题解的“根据目前情况”几个字省略的关键信息和过多细节,导致我拿肾肝了一个下午都没肝出来。
题解
贪心。容易发现,由于保留某个节点意味着祖先节点都要保留,所以令先序遍历最优或中序遍历最优是一样的,不妨按先序遍历最优来贪心。
考虑先序遍历枚举某个点是否可以保留。从当前点向上,每当遇到自己是左儿子时,根据目前情况计算右子树至少要留下多少节点。这样可以统计出需要留下整棵树至少多大,如果 ≤ k \le k ≤k 则可以保留。
怎么根据目前情况计算呢?首先我们能得到的肯定只有某个子树的最大深度至少为多少,所以需要预先做一个DP求得深度为某个值的AVL树至少有多少个节点:
d
p
[
i
]
=
d
p
[
i
−
1
]
+
d
p
[
i
−
2
]
+
1
dp[i]=dp[i-1]+dp[i-2]+1
dp[i]=dp[i−1]+dp[i−2]+1
然后怎么确定这个最小深度呢?
如果你仅仅根据当前左子树的深度来确定右子树的深度,那么可以喜提 26 26 26 分。
这样做有明显的反例:
假设通过之前的遍历已经限定这个右子树深度至少为4,那么显然只有保留圈中的节点是最优的,而当前剩下需要保留的节点数正好是7,
然而在检查是否保留圈红的这个节点时,向上枚举到这个子树的树根,会认为右子树深度至少为1,只需保留1个节点,从而判定可以把自己保留。这样显然会导致遍历到最深的那个必须保留的节点时,因为节点数不够用而抛弃它。
所以我们对每个右子树确定的深度下限,必须还要参考前面的限制。我们可以给每个节点设置一个参数 m h mh mh( m a x h e i g h t max\,height maxheight,用 m d md md 不舒服)表示这棵子树的最大深度至少为多少,那么每次检查确定右子树大小时,要把这棵右子树的当前的 m h mh mh 值也考虑上。
m h mh mh 值是需要下传的。如果左子树的最大深度可以满足要求,显然应该传给左子树,否则只能传给右子树(别忘了同时把 m h − 1 mh-1 mh−1 下传给另一个子树)。一个节点必须经过祖先节点下传 m h mh mh 值后才能开始检查是否保留,而由于我们是先序遍历,正好可以边遍历边下传。
代码
#include<cstdio>//JZM yyds!!
#include<cstring>
#include<iostream>
#include<algorithm>
#include<cmath>
#include<ctime>
#include<vector>
#include<queue>
#include<stack>
#include<map>
#include<set>
#define ll long long
#define uns unsigned
#define MOD 1000000007ll
#define MAXN 500005
#define INF 1e18
#define lowbit(x) (x&(-x))
#define IF it->first
#define IS it->second
using namespace std;
inline ll read(){
ll x=0;bool f=1;char s=getchar();
while((s<'0'||s>'9')&&s>0)f^=(s=='-'),s=getchar();
while(s>='0'&&s<='9')x=(x<<1)+(x<<3)+(s^48),s=getchar();
return f?x:-x;
}
int n,k,root;
int fa[MAXN],ls[MAXN],rs[MAXN];
int ph[MAXN],mh[MAXN];
int h[MAXN],dep[MAXN],dp[45];
bool as[MAXN];
inline void dfs(int x){
dep[x]=dep[fa[x]]+1;
ph[x]=dep[x];
if(ls[x])dfs(ls[x]),ph[x]=max(ph[x],ph[ls[x]]);
if(rs[x])dfs(rs[x]),ph[x]=max(ph[x],ph[rs[x]]);
}
inline int ck(int x){
int y=max(dep[x],h[x]),num=0;
while(x){
if(!as[x])num++;
y=max(y,h[x]);
if(x<fa[x]&&!as[rs[fa[x]]])
num+=dp[max(y-1,mh[rs[fa[x]]])-dep[fa[x]]];
x=fa[x];
}
return num;
}
inline void add(int x){
h[x]=max(h[x],dep[x]);
int y=h[x];
while(x){
h[x]=max(h[x],y);
if(!as[x])as[x]=1,k--;
if(x<fa[x]&&!as[rs[fa[x]]]){
int v=rs[fa[x]];
if(v)mh[v]=max(mh[v],h[x]-1);
}
x=fa[x];
}
}
inline void solve(int x){
if(ck(x)<=k)add(x);
if(ls[x]&&rs[x]){
if(ph[ls[x]]<mh[x])
mh[rs[x]]=max(mh[rs[x]],mh[x]),
mh[ls[x]]=max(mh[ls[x]],mh[x]-1);
else mh[ls[x]]=max(mh[ls[x]],mh[x]),
mh[rs[x]]=max(mh[rs[x]],mh[x]-1);
solve(ls[x]),solve(rs[x]);
}else{
if(ls[x]){
mh[ls[x]]=max(mh[ls[x]],mh[x]);
solve(ls[x]);
}
if(rs[x]){
mh[rs[x]]=max(mh[rs[x]],mh[x]);
solve(rs[x]);
}
}
}
signed main()
{
freopen("avl.in","r",stdin);
freopen("avl.out","w",stdout);
dp[0]=0,dp[1]=1;
for(int i=2;i<=40;i++)dp[i]=dp[i-1]+dp[i-2]+1;
n=read(),k=read();
for(int i=1;i<=n;i++){
fa[i]=read();
if(fa[i]<0)fa[i]=0,root=i;
else if(fa[i]<i)rs[fa[i]]=i;
else ls[fa[i]]=i;
}
dfs(root),solve(root);
for(int i=1;i<=n;i++)putchar(as[i]?'1':'0');
putchar('\n');
return 0;
}