现在感觉这题还是挺好想的：自底向上处理每个节点，把它的儿子按（处理完子树后的）val 从小到大排序，能并入父亲就并入，一旦放不下就停止。
Code:
#include <bits/stdc++.h>
#define N 2000006
#define ll long long
#define setIO(s) freopen(s".in","r",stdin)
using namespace std;
// v[x]: children of x, collected by dfs and sorted there by val.
vector<int> v[N];
// n: number of nodes read in main; m: upper bound checked in dfs.
int n, m;
// edges: number of edges stored so far; ans: number of merges performed by dfs.
int edges, ans;
// val[x]: read from input, then updated by dfs; son[x]: child count read in main.
int val[N], son[N];
// Forward-star adjacency list: hd = first edge of a node, to = edge target, nex = next edge.
int hd[N], to[N], nex[N];
void add(int u,int v)
{
nex[++edges]=hd[u],hd[u]=edges,to[edges]=v;
}
/* Sort predicate: node i orders before node j when val[i] is strictly smaller. */
bool cmp(int i,int j)
{
    const int lhs = val[i];
    const int rhs = val[j];
    return rhs > lhs;
}
/*
 * Finalise node x; must be called only after every child of x is finalised.
 * Sorts x's children (v[x]) by val ascending, adds the child count to val[x],
 * then merges children in that order while val[child] + val[x] - 1 <= m,
 * adding val[child] - 1 to val[x] and incrementing ans for each merge.
 * Stops at the first child that does not fit: val[x] no longer changes at
 * that point and the remaining children have val no smaller (sorted order),
 * so none of them would fit either.
 */
static void absorb_children(int x)
{
    std::sort(v[x].begin(), v[x].end(), cmp);
    val[x] += static_cast<int>(v[x].size());
    for (const int cur : v[x])
    {
        if (val[cur] + val[x] - 1 > m) break;
        val[x] += val[cur] - 1;
        ++ans;
    }
}

/*
 * Post-order traversal of the subtree rooted at x. Each node's children are
 * appended to v[node] and the node is finalised by absorb_children only after
 * all of its children have been finalised.
 *
 * Uses an explicit stack instead of recursion. With N around 2e6, a
 * path-shaped tree would otherwise recurse about n levels deep and overflow
 * the call stack.
 *
 * Stack entry: (node, next edge index of that node still to visit; 0 = done).
 */
void dfs(int x)
{
    std::vector<std::pair<int, int>> stk;
    stk.emplace_back(x, hd[x]);
    while (!stk.empty())
    {
        const int u = stk.back().first;
        const int e = stk.back().second;
        if (e)
        {
            const int child = to[e];
            // Advance u's edge cursor before pushing: emplace_back may reallocate.
            stk.back().second = nex[e];
            v[u].push_back(child);
            stk.emplace_back(child, hd[child]);
        }
        else
        {
            // All children of u are finalised; now finalise u itself.
            stk.pop_back();
            absorb_children(u);
        }
    }
}
/*
 * Reads n and m, then n values into val[1..n], then for each node 1..n its
 * child count followed by that many child ids. Child ids in the input are
 * 0-based and are shifted by +1, so node 0 of the input is node 1 here.
 * Runs dfs from node 1 and prints ans.
 */
int main()
{
    // setIO("input");   // uncomment to read from input.in when debugging locally
    scanf("%d%d",&n,&m);
    for (int node = 1; node <= n; ++node)
        scanf("%d",&val[node]);
    for (int node = 1; node <= n; ++node)
    {
        scanf("%d",&son[node]);
        for (int k = 0; k < son[node]; ++k)
        {
            int child;
            scanf("%d",&child);
            add(node, child + 1);   // 0-based input id -> 1-based node index
        }
    }
    dfs(1);
    printf("%d\n",ans);
    return 0;
}