// Problem: https://codeforces.com/problemset/problem/1153/D
// Approach: see the referenced issue.
#include<cctype>
#include<cmath>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<iostream>
#include<map>
#include<queue>
#include<set>
#include<vector>
// Debug helper: prints "<expression text>=<value>" followed by a newline to cout.
// The argument is parenthesised so that expressions using operators with lower
// precedence than << (e.g. debug(x&y), debug(a==b)) print the expression's value
// instead of failing to compile or binding to the stream operator.
// The trailing ';' is kept so existing call sites written without one still compile.
#define debug(a) cout<<#a<<"="<<(a)<<endl;
using namespace std;
// Capacity of the node-indexed global arrays: 3e5 plus some slack.
constexpr int maxn = 300000 + 1000;
// Integer type used throughout the file (a plain int, despite the name).
using LL = int;
/// Reads one integer from stdin using getchar().
/// Skips every non-digit character first; any '-' seen while skipping makes the
/// result negative. Then it accumulates consecutive decimal digits.
/// Returns 0 if EOF is reached before any digit. The previous version spun
/// forever in that case.
inline LL read(){
    LL x=0,f=1;
    // int, not char: getchar() must be able to report EOF distinctly, and passing
    // a negative char value to isdigit is undefined behaviour.
    int ch=getchar();
    while(ch!=EOF&&!isdigit(ch)){
        if(ch=='-') f=-1;
        ch=getchar();
    }
    while(ch!=EOF&&isdigit(ch)){
        x=x*10+(ch-'0');
        ch=getchar();
    }
    return x*f;
}
// Per-node arrays, indexed 1..n.
// dp[i]: minimum number of leaves in i's subtree that must hold a value >= the
//        target (mid) for node i itself to reach that target; filled by dfs.
LL dp[maxn];
// c[i]: operation flag read from the input. In dfs, 1 -> take the minimum of the
//       children's dp, 0 -> sum them (per the problem statement, 1 = max, 0 = min).
LL c[maxn];
// out[i]: number of children of node i; 0 marks a leaf.
LL out[maxn];
// g[i]: children of node i.
vector<LL> g[maxn];
/// Fills dp[x] for every node x in the subtree rooted at u, using the child lists
/// in g and the operation flags in c.
/// dp[x] = minimum number of leaves of x's subtree that must hold a value >= a
/// target for x to reach that target:
///   - leaf:        dp = 1
///   - c[x] != 0:   dp = min over children  (max node: one good child suffices)
///   - c[x] == 0:   dp = sum over children  (min node: every child must be good)
/// @param u    root of the subtree to evaluate
/// @param mid  unused. The recurrence never depends on the target value; the
///             parameter is kept so existing callers still compile.
/// Implemented without recursion: a recursive DFS on a path-shaped tree with
/// ~3e5 nodes can overflow the call stack.
void dfs(LL u, LL mid){
    (void)mid;
    std::vector<LL> order;          // pre-order: every parent precedes its children
    std::vector<LL> pending(1, u);  // explicit stack replacing the recursion
    while(!pending.empty()){
        LL x=pending.back();
        pending.pop_back();
        order.push_back(x);
        for(LL v:g[x]) pending.push_back(v);
    }
    // Reverse pre-order processes every child before its parent,
    // so dp[v] is always ready when dp[x] needs it.
    for(auto it=order.rbegin();it!=order.rend();++it){
        LL x=*it;
        if(g[x].empty()){
            dp[x]=1;
            continue;
        }
        if(c[x]){
            LL best=dp[g[x][0]];
            for(LL v:g[x]) best=std::min(best,dp[v]);
            dp[x]=best;
        }else{
            LL total=0;
            for(LL v:g[x]) total+=dp[v];
            dp[x]=total;
        }
    }
}
/// Codeforces 1153D. Reads n, the operation flag c[i] of every node, and the
/// parents of nodes 2..n. Prints the largest value the root can reach when the
/// k leaves receive the labels 1..k.
int main(void){
    cin.tie(nullptr);
    std::ios::sync_with_stdio(false);
    LL n;
    cin>>n;
    for(LL i=1;i<=n;i++) cin>>c[i];
    for(LL i=2;i<=n;i++){
        LL fa;
        cin>>fa;
        g[fa].push_back(i);
        out[fa]++;
    }
    LL k=0;  // number of leaves
    for(LL i=1;i<=n;i++) if(out[i]==0) k++;
    // The root reaches a target T iff at least dp[1] leaves carry values >= T.
    // Exactly k-T+1 labels are >= T, so T is feasible iff k-T+1 >= dp[1].
    // dp does not depend on T, so a single DP pass gives the answer directly.
    // The former binary search recomputed the identical DP O(log n) times and
    // converged to this same value, since 1 <= dp[1] <= k.
    dfs(1,0);
    cout<<k-dp[1]+1<<"\n";
    return 0;
}