You are given a rooted tree with nn nodes, labeled from 11 to nn. The tree is rooted at node 11. The parent of the ii-th node is pipi. A leaf is node with no children. For a given set of leaves LL, let f(L)f(L) denote the smallest connected subgraph that contains all leaves LL.
You would like to partition the leaves such that for any two different sets x,yx,y of the partition, f(x)f(x) and f(y)f(y) are disjoint.
Count the number of ways to partition the leaves, modulo 998244353998244353. Two ways are different if there are two leaves such that they are in the same set in one way but in different sets in the other.
Input
The first line contains an integer nn (2≤n≤200000) — the number of nodes in the tree.
The next line contains n−1n−1 integers p2,p3,…,pn (1≤pi<i).
Output
Print a single integer, the number of ways to partition the leaves, modulo 998244353998244353.
Examples
input
Copy
5 1 1 1 1
output
Copy
12
input
Copy
10 1 2 3 4 5 6 7 8 9
output
Copy
1
Note
In the first example, the leaf nodes are 2,3,4,5. The ways to partition the leaves are in the following image
In the second example, the only leaf is node 10 so there is only one partition. Note that node 1 is not a leaf.
思路 对于一个结点有两种状态 和它的父节点连接或者和它的父节点不连接 分别用 dp[i][1] 和 dp[i][0]表示 我们这样考虑先求出总的方案数然后用总方案数减去那些不合法的方案数 那么一个父节点的贡献就是它的子结点的所有方案的积
1.如果父节点与它的父节点连接的话那么对于合法的情况它至少和它的一个子节点连接 所以不合法的情况就是它和它的所有子节点都不连接
2.如果父节点与它的父节点不连接的话那么对于合法的情况它不能只和它的一个子节点连接 所以不合法的情况就是它只和它的一个子节点连接
对于叶子节点要特判
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int MAXN = 200010,MAX = 0x3f3f3f3f,mod = 998244353;
ll dp[MAXN][2];
inline ll ksm ( ll a, ll b ) { ll res = 1; while ( b ) { if ( b & 1 ) res = res * a % mod; a = a * a % mod; b >>= 1; } return res; }
inline ll inv ( ll a ) { return ksm(a, mod - 2); }
inline ll add ( ll a, ll b ) { return (a + b) % mod; }
inline ll sub ( ll a, ll b ) { return ((a - b) % mod + mod) % mod; }
int main()
{
ios::sync_with_stdio(false);
cin.tie(0),cout.tie(0);
int n,m;
cin >> n;
vector<int> v[n+1];
for(int i = 2;i <= n; i++){
cin >> m;
v[m].push_back(i);
}
function<void(int,int)> dfs = [&] (int now,int fath){
ll s0 = 1,s1 = 1,s2 = 0;
for(int x : v[now]){
if(x == fath) continue;
dfs(x,now);
s0 = (s0*(dp[x][0]+dp[x][1]))%mod;
s1 = (s1*dp[x][0])%mod;
}
for(int x : v[now]){
s2 = (s2+s1*inv(dp[x][0])%mod*dp[x][1])%mod;
}
dp[now][0] = sub(s0,s2);
dp[now][1] = sub(s0,s1);
if(v[now].size() == 0) dp[now][1] = 1;
};
dfs(1,0);
cout << dp[1][0];
return 0;
}