题意:
在一棵根节点为1的树上,一开始所有的路都是坏的,现在你可以修路。
让你找出以每个节点为首都,到达其他任意一个节点所经过的坏路不超过1条的方案数。
思路:
参考别人的。参考完之后又感觉思想狠简单。。。
dp1[i]:以i为根的子树,满足到达这棵子树的任意节点的坏路不超过1条的方案数。
dp2[i]:从该点出发,往父亲方向的满足要求的方案数。
首先先看dp1。假设当前节点为u,v为其子节点,则u~v之间的路只有修和不修两种。
修:方案数有dp1[v]; 不修:方案数只有1,下面的路都是要修的。
则 dp1[u] = (dp1[v1]+1)*(dp1[v2]+1)*……
转换根节点,现在看dp2,假设要求节点dp2[v1],u为v的父亲节点。(dp2[1] = 1)
v1的方案数来自父亲节点往上的方案数还有其他兄弟的方案数。
则dp2[v1] = (dp2[u]*(dp1[v2]+1)*(dp1[v3]+1)*……)+1;(加1的原因跟上面一样)
如果不能理解,就画个图感受一下。
1.有的同学可能会想到用乘法逆元来求其他兄弟的方案数,实际上这会WA10;
2.求出每个节点的前缀积和后缀积即可避免除法(逆元),如果每次都重新求兄弟方案数积TLE39.
code:
#include <bits/stdc++.h>
using namespace std;
const int N = 2e5+5;
const int MOD = 1e9+7;
typedef long long LL;
int n;
int head[N], cnt = 0;
LL dp1[N], dp2[N];
int f[N];
vector <LL> front[N];
vector <LL> back[N];
vector <LL> tt[N];
struct Edge {
int v, next;
}e[N<<1];
void addEdge(int u, int v) {
e[cnt] = (Edge){v, head[u]};
head[u] = cnt++;
}
inline void cal(LL &a, LL b) {
a = a*b%MOD;
}
void dfs(int u, int par) {
dp1[u] = 1;
f[u] = par;
int cnt = 0;
for(int i = head[u];i != -1; i = e[i].next) {
if(e[i].v == par) continue;
int v = e[i].v;
dfs(v, u);
cal(dp1[u], dp1[v]+1);
tt[u].push_back(dp1[v]+1);
cnt++;
}
LL tmp = 1, tmp2 = 1;
for(int i = 0, j = cnt-1;i < cnt; i++, j--) {
cal(tmp, tt[u][i]);
cal(tmp2, tt[u][j]);
front[u].push_back(tmp);
back[u].push_back(tmp2);
}
/*
cout<<"u = "<<u<<endl;
for(int i = 0;i < cnt; i++) {
cout<<front[u][i]<<" ";
}
cout<<endl;
for(int i = 0;i < cnt; i++) {
cout<<back[u][i]<<" ";
}
cout<<endl;
*/
}
/*
LL just(int v, int u) {
LL ret = 1;
for(int i = head[u];i != -1; i = e[i].next) {
if(e[i].v == v || e[i].v == f[u]) continue;
int tv = e[i].v;
cal(ret, dp1[tv]+1);
}
cal(ret, dp2[u]);
ret++;
return ret;
}
*/
LL just(int v, int u, int idx) {
LL ret = 1;
if(front[u].size() > 0) {
if(idx > 0) cal(ret, front[u][idx-1]);
if(idx < (back[u].size()-1)) {
int real = back[u].size()-idx-2;
cal(ret, back[u][real]);
}
}
cal(ret, dp2[u]);
ret++;
return ret;
}
void dfs2(int u, int par) {
int idx = 0;
for(int i = head[u];i != -1; i = e[i].next) {
if(e[i].v == par) continue;
int v = e[i].v;
dp2[v] = just(v, u, idx++);
dfs2(v, u);
}
}
void solve() {
dfs(1, -1);
dp2[1] = 1;
dfs2(1, -1);
for(int i = 1;i <= n; i++)
printf("%I64d%c", dp2[i]*dp1[i]%MOD, i == n?'\n':' ');
}
int main() {
scanf("%d", &n);
memset(head, -1, sizeof(head));
for(int i = 2;i <= n; i++) {
int v;
scanf("%d", &v);
addEdge(v, i);
addEdge(i, v);
}
solve();
return 0;
}