// Code:
#include <cstdio>
#include <algorithm>
#define N 200005
#define mod 998244353
#define ll long long
#define setIO(s) freopen(s".in","r",stdin)
using namespace std;
// Modular exponentiation by repeated squaring: returns base^k modulo `mod`.
//
// base : any value; it is reduced into [0, mod) before use, so large or
//        negative bases are handled and the result is never negative.
// k    : exponent, expected to be non-negative. k == 0 yields 1. A negative
//        k makes the loop body never run, so 1 is returned as well.
// Returns a value in [0, mod).
ll qpow(ll base,ll k)
{
    // The loop forms base*base and tmp*base in 64-bit arithmetic; that is only
    // overflow-free when both factors are already below mod, so normalise first.
    base%=mod;
    if(base<0) base+=mod;
    ll tmp=1;
    for(;k>0;k>>=1)
    {
        if(k&1) tmp=tmp*base%mod;   // this bit of k is set: fold in the current power
        base=base*base%mod;         // advance to base^(2^(j+1))
    }
    return tmp;
}
// Modular multiplicative inverse of k modulo `mod`, computed as k^(mod-2)
// through qpow (Fermat's little theorem, so `mod` has to be prime).
// When k is a multiple of mod no inverse exists and the result is 0.
ll inv(ll k)
{
    const ll exponent=mod-2;
    const ll result=qpow(k,exponent);
    return result;
}
// Number of vertices, read by main; vertices are numbered 1..n and dfs starts at 1.
int n;
// Number of edge slots handed out so far; add() pre-increments it, so slots start at 1.
int edges;
// Per-vertex DP values computed by dfs(), stored modulo `mod`.
ll f[N];
ll g[N];
// Adjacency lists as linked lists over edge slots:
// hd[u] is the first slot of u (0 when u has no outgoing edge),
// to[e] is the endpoint of slot e, nex[e] is the next slot in the same list (0 ends it).
int hd[N];
int to[N];
int nex[N];
// Number of vertices in the subtree of each vertex, filled in by dfs().
int size[N];
// Records the directed edge u -> v by pushing a new slot onto the front of
// u's adjacency list.
void add(int u,int v)
{
    const int slot=++edges;   // slots are numbered from 1; 0 is the list terminator
    to[slot]=v;
    nex[slot]=hd[u];          // the previous head becomes the successor
    hd[u]=slot;
}
void dfs(int u)
{
int i;
size[u]=1;
ll sum=1;
f[u]=g[u]=1;
for(i=hd[u];i;i=nex[i])
{
int v=to[i];
dfs(v);
size[u]+=size[v];
sum=sum*f[v]%mod;
f[u]=f[u]*((f[v]+g[v])%mod)%mod;
}
g[u]=f[u];
if(size[u]>1)
{
g[u]=(g[u]+mod-sum)%mod;
for(i=hd[u];i;i=nex[i])
{
int v=to[i];
ll tmp=inv(f[v])*g[v]%mod;
tmp=tmp*sum%mod;
f[u]=(f[u]+mod-tmp)%mod;
}
}
}
int main()
{
int i,j;
// setIO("input");
scanf("%d",&n);
for(i=2;i<=n;++i)
{
int a;
scanf("%d",&a),add(a,i);
}
dfs(1);
printf("%lld\n",f[1]);
return 0;
}