题意:给一个n节点的树,有k个叶子结点,给每个叶子结点附上1到k的值,每个节点可以是他所有的子节点的最大值(第二行输入为1)或者最小值(0);问根节点1的最大值是多少;
用vis[i]表示节点i要从子节点取最大值或者最小值;
思路:数组d[i]保存节点i要浪费多少数字,(比如节点i有3个叶子结点,且vis[i]=0,则节点i只能取第3大的那个儿子,即浪费了2个数字);
所以当vis[i]=1,d[i]=min(d[y])(y为所有i的儿子);
d[叶子结点]=0;
当vis[i]=0,d[i] = sum(d[y])+(i的儿子数量-1);
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 3e5 + 50;
int d[maxn];
int vis[maxn], fa[maxn];
vector<int> son[maxn];
void dfs(int x)
{
int len = son[x].size();
for(int i = 0; i < len; i++)
{
int y = son[x][i];
if(son[y].size() > 0) dfs(y);
}
if(vis[x])
{
int y = son[x][0]; d[x] = d[y];
for(int i = 1; i < len; i++)
{
y = son[x][i];
d[x] = min(d[x], d[y]);
}
}
else
{
d[x] = len - 1;
bool ok = false; int y;
for(int i = 0; i < len; i++)
{
y = son[x][i];
d[x] += d[y];
}
}
}
int main()
{
// freopen("in.txt", "r", stdin);
int n; scanf("%d", &n);
for(int i = 1; i <= n; i++)
scanf("%d", &vis[i]);
for(int i = 2; i <= n; i++)
{
scanf("%d", &fa[i]);
son[fa[i]].push_back(i);
}
dfs(1);
int sum = 0;
for(int i = 1; i <= n; i++)
if(son[i].size() == 0) sum++;
printf("%d\n", sum - d[1]);
return 0;
}