// Code:
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <utility>
#include <vector>
#define setIO(s) freopen(s".in","r",stdin)
#define maxn 200004
#define mod 998244353
#define ll long long
using namespace std;
// Modular addition: returns (a + b) % mod.
// Intended for operands already in [0, mod); with C++ truncating %, a
// negative sum would produce a negative result.
inline ll add(ll a,ll b) {
    const ll sum = a + b;
    return sum % mod;
}
// Modular subtraction: returns (a - b) mod `mod`.
// Both operands are reduced with % first, and `mod` is added before the final
// reduction so the result is non-negative for non-negative inputs.
inline ll de(ll a,ll b) {
    const ll lhs = a % mod;
    const ll rhs = b % mod;
    return (lhs + mod - rhs) % mod;
}
// Number of vertices, and number of edges inserted so far by add(int,int).
int n, edges;
// Linked adjacency list: hd[u] = first edge id out of u (0 = none),
// to[e] = target of edge e, nex[e] = next edge id with the same source.
int hd[maxn], to[maxn], nex[maxn];
// Rooted-tree data: dep = depth (root gets dep[0] + 1 == 1), f = parent,
// siz = subtree size, son = child with the largest subtree (0 = none),
// top = head of the heavy chain containing the vertex.
int dep[maxn], f[maxn], siz[maxn], son[maxn], top[maxn];
// Disjoint-set union: p = parent pointer, sz = size (valid at roots).
int p[maxn], sz[maxn];
// inv[i] = modular inverse of i; mul = product over DSU sets of (size + 1);
// ans = accumulated answer. NOTE(review): g is not referenced in the visible code.
ll inv[maxn], mul, g[maxn], ans;
// A pair of vertices (u, v) whose DSU sets are to be united during the sweep.
struct Node {
    int u, v;
    Node(int first = 0, int second = 0) : u(first), v(second) {}
};
vector<Node>vi[maxn];
vector<int>G[maxn];
// Binary exponentiation: returns base^k mod `mod`.
// base is first normalized into [0, mod) so that base * base and re * base
// stay far below the ll limit even when the caller passes base >= mod or a
// negative base (the unreduced product would overflow, which is UB).
// k is expected to be non-negative; the loop condition k > 0 guarantees
// termination (with `while (k)`, a negative k never reaches 0 under >>=).
ll qpow(ll base,ll k) {
    base %= mod;
    if (base < 0) base += mod;
    ll re = 1;
    while (k > 0) {
        if (k & 1) re = re * base % mod;
        base = base * base % mod;
        k >>= 1;
    }
    return re;
}
void add(int u,int v) {
nex[++edges]=hd[u],hd[u]=edges,to[edges]=v;
}
// First HLD pass over the subtree of u (f[] must already hold parents):
// sets dep[x] = dep[f[x]] + 1, siz[x] = subtree size, son[x] = child with the
// largest subtree, and appends every vertex to G[dep[x]] in DFS pre-order.
// Uses an explicit stack instead of recursion: the recursion depth equals the
// tree height, which reaches n (~2e5) on a path and can overflow the call
// stack. Children are scanned in adjacency-list order and each finished child
// is folded into its parent before the next sibling starts, so G[] order and
// son[] tie-breaking match the recursive formulation exactly.
void dfs1(int u) {
    std::vector<int> nodeStk, edgeStk;   // vertex, next adjacency edge to scan
    dep[u] = dep[f[u]] + 1;
    siz[u] = 1;
    G[dep[u]].push_back(u);
    nodeStk.push_back(u);
    edgeStk.push_back(hd[u]);
    while (!nodeStk.empty()) {
        const int x = nodeStk.back();
        const int e = edgeStk.back();
        if (e != 0) {
            edgeStk.back() = nex[e];      // resume after this edge next time
            const int v = to[e];
            if (v == f[x]) continue;
            dep[v] = dep[f[v]] + 1;
            siz[v] = 1;
            G[dep[v]].push_back(v);
            nodeStk.push_back(v);
            edgeStk.push_back(hd[v]);
        } else {
            // x's subtree is complete: merge its size into the parent frame.
            nodeStk.pop_back();
            edgeStk.pop_back();
            if (!nodeStk.empty()) {
                const int par = nodeStk.back();
                siz[par] += siz[x];
                if (siz[x] > siz[son[par]]) son[par] = x;
            }
        }
    }
}
void dfs2(int u,int tp) {
top[u]=tp;
if(son[u]) dfs2(son[u],tp);
for(int i=hd[u];i;i=nex[i]) {
int v=to[i];
if(v==f[u]||v==son[u]) continue;
dfs2(v,v);
}
}
// Lowest common ancestor via heavy-light decomposition. While x and y lie on
// different chains, the one whose chain head is deeper jumps to the parent of
// that head. Once both are on the same chain, the shallower one is the LCA
// (y on a depth tie). Requires dep, f and top to be filled by dfs1/dfs2.
int LCA(int x,int y) {
    while (top[x] != top[y]) {
        if (dep[top[x]] > dep[top[y]]) {
            x = f[top[x]];
        } else {
            y = f[top[y]];
        }
    }
    if (dep[x] < dep[y]) return x;
    return y;
}
// Resets the DSU: every index in [0, maxn) becomes a singleton set of size 1.
void init() {
    for (int i = maxn - 1; i >= 0; --i) {
        sz[i] = 1;
        p[i] = i;
    }
}
// Returns the representative (root) of x's DSU set and compresses the path so
// every vertex visited on the way points directly at the root.
// Done in two iterative passes: parent chains can grow to ~n links when
// unions are unbalanced, and the recursive form would then recurse that deep.
// The resulting p[] is identical to full recursive path compression.
int find(int x) {
    int root = x;
    while (p[root] != root) root = p[root];
    while (p[x] != root) {
        const int up = p[x];
        p[x] = root;
        x = up;
    }
    return root;
}
// Unites the DSU sets containing x and y and keeps mul equal to the product of
// (size + 1) over all sets. The factors of the two old sets are divided out
// through the inv table, and the factor of the merged set is multiplied in.
// Union by size hangs the smaller tree under the larger one, bounding tree
// height by O(log n). mul depends only on set sizes, so the choice of the
// surviving root does not change it.
void merge(int x,int y) {
    int fx = find(x), fy = find(y);
    if (fx == fy) return;
    if (sz[fx] > sz[fy]) std::swap(fx, fy);
    mul = mul * inv[sz[fx] + 1] % mod * inv[sz[fy] + 1] % mod;
    p[fx] = fy;
    sz[fy] += sz[fx];
    mul = mul * (sz[fy] + 1) % mod;
}
void Initialize() {
init(),inv[1]=1,mul=qpow(2,n);
for(int i=2;i<maxn;++i) inv[i]=(ll)(mod-mod/i)*inv[mod%i]%mod;
}
int main() {
// setIO("input");
scanf("%d",&n);
for(int i=2;i<=n;++i) scanf("%d",&f[i]), add(f[i],i);
Initialize(),dfs1(1),dfs2(1,1);
for(int i=2;i<=n;++i) vi[dep[i]-1].push_back(Node(i,f[i]));
for(int i=1;i<=n;++i) {
for(int j=1;j<G[i].size();++j) {
vi[dep[G[i][j]] - dep[LCA(G[i][j], G[i][j-1])]].push_back(Node(G[i][j], G[i][j-1]));
}
}
ans=de(mul,n+1);
for(int i=1;i<=n;++i) {
for(int j=0;j<vi[i].size();++j)
merge(vi[i][j].u, vi[i][j].v);
ans=add(ans,de(mul,n+1));
}
printf("%lld\n",ans);
return 0;
}