题意:根据那个伪代码去涂色,只要是有可能的都可以到达的,问你有多少种不同颜色的树。
因为只有2种颜色,所以很自然可以想到用奇偶性,dp[u][0]代表以u为根的树,节点(包括根)总数为偶数的方案数,dp[u][1]就是奇数了。对于子节点分2种情况遍历,一种是从左往右,一种是从右往左,用组合数的方法可以发觉这2种答案是一样的,所以只需要算一次乘2就行,但是如何去重就是个麻烦的地方。
2种情况会有重复,一种是遍历的几个子树节点数全是偶数,这样从左往右进去是白色,最后涂成的颜色和从右往左进去是白色是一样的。
类似,奇数个节点数是奇数的子树也是会有同样效果,但是这2者不能结合。
还有,因为这题的输入是第i个节点的根,根一定在前面,所以可以直接逆推不用DFS,这样当遍历到某个点,他下面的dp一定已经推出来了。
AC代码:
//#pragma comment(linker, "/STACK:102400000,102400000")
#include<cstdio>
#include<ctype.h>
#include<algorithm>
#include<iostream>
#include<cstring>
#include<vector>
#include<cstdlib>
#include<stack>
#include<cmath>
#include<queue>
#include<set>
#include<map>
#include<ctime>
#include<string.h>
#include<string>
#include<sstream>
#include<bitset>
using namespace std;
#define ll long long
#define ull unsigned long long
#define eps 1e-4
#define NMAX 200005
#define MOD 1000000007
#define lson l,mid,rt<<1
#define rson mid+1,r,rt<<1|1
#define PI acos(-1)
template<class T>
inline void scan_d(T &ret)
{
char c;
int flag = 0;
ret=0;
while(((c=getchar())<'0'||c>'9')&&c!='-');
if(c == '-')
{
flag = 1;
c = getchar();
}
while(c>='0'&&c<='9') ret=ret*10+(c-'0'),c=getchar();
if(flag) ret = -ret;
}
const int maxn = 100000+10;
vector<int>G[maxn];
ll dp[maxn][2];//dp是包括根节点和子节点的奇偶的方案个数
void solve(int n)
{
for(int u = n; u >= 1; u--)
{
dp[u][1] = 1;
dp[u][0] = 0;
int sz = G[u].size();
if(!sz) continue;
for(int i = 0; i < sz; i++)
{
int v = G[u][i];
ll a = (dp[u][0]*dp[v][0]%MOD + dp[u][1]*dp[v][1]%MOD)%MOD;
ll b = (dp[u][1]*dp[v][0]%MOD + dp[u][0]*dp[v][1]%MOD)%MOD;
dp[u][0] = (dp[u][0]+a)%MOD;
dp[u][1] = (dp[u][1]+b)%MOD;
}
dp[u][0] = (dp[u][0]*2LL)%MOD;
dp[u][1] = (dp[u][1]*2LL)%MOD;
ll p[3]= {1,1,0},tmp;
for(int i = 0; i < sz; i++)
{
int v = G[u][i];
p[0] = ((p[0]*dp[v][0])%MOD+p[0])%MOD;
tmp = p[2];
p[2] = ((p[1]*dp[v][1])%MOD+p[2])%MOD;
p[1] = ((tmp*dp[v][1])%MOD+p[1])%MOD;
}
dp[u][1] = (dp[u][1]-p[0]+MOD)%MOD;//子树要是偶数
dp[u][0] = (dp[u][0]-p[2]+MOD)%MOD;//子树要是奇数
}
}
int main()
{
#ifdef GLQ
freopen("input.txt","r",stdin);
// freopen("o4.txt","w",stdout);
#endif // GLQ
int n;
scanf("%d",&n);
for(int i = 2; i <= n; i++)
{
int p;
scanf("%d",&p);
G[p].push_back(i);
}
solve(n);
printf("%I64d\n",(dp[1][0]+dp[1][1])%MOD);
return 0;
}