题意
DZY有一棵n个结点的无根树,结点按照 1∼n 标号。
DZY喜欢树上的连通集。一个连通集SS是由一些结点组成的集合,满足SS中任意两个结点u,v能够用树上的路径连通,且路径上不经过S之外的结点。显然,单独一个结点的集合也是连通集。
一个连通集的大小定义为它包含的结点个数,DZY想知道所有连通集的大小之和是多少。你能帮他数一数吗?
答案可能很大,请对 109+7 取模后输出。
思路
其他的就不说了,主要说说如果采用 dfs 方法的话,这个关键的递推式是什么意思。
ans[u] = ans[u] * (sum[v] + 1) + ans[v] * sum[u]
其中 u 是当前结点,v 是子结点。
ans[u] 表示 u 在 u 结点处产生的所有集合里的元素个数。
sum[u] 表示以 u 为根节点的集合个数。
所以答案是 ∑ni=1ans[i]
ans[u] 由两部分构成。
- u 结点处新产生的元素的贡献
- v 结点原来的元素的贡献
先说ans[u] * (sum[v] + 1)
。这里得出的就是第一部分的值。
其实这里应该写为ans[u]*sum[v] + ans[u]
因为每个元素都能在子结点的集合里出现一次,所以 ans[u]*sum[v]
就得出了增加了结点 v 之后又新增的贡献,还要加上原来的。
ans[v]*sum[u]
计算的是后一部分的贡献。
每个子结点的元素都能在 u 的集合里出现一次。
我也是马后炮推出来的。
代码
#include <stack>
#include <cstdio>
#include <list>
#include <cassert>
#include <set>
#include <iostream>
#include <string>
#include <sstream>
#include <vector>
#include <queue>
#include <functional>
#include <cstring>
#include <algorithm>
#include <cctype>
//#pragma comment(linker, "/STACK:102400000,102400000")
#include <string>
#include <map>
#include <cmath>
using namespace std;
#define LL long long
#define ULL unsigned long long
#define SZ(x) (int)x.size()
#define Lowbit(x) ((x) & (-x))
#define MP(a, b) make_pair(a, b)
#define MS(p, num) memset(p, num, sizeof(p))
#define PB push_back
#define X first
#define Y second
#define ROP freopen("input.txt", "r", stdin);
#define MID(a, b) (a + ((b - a) >> 1))
#define LC rt << 1, l, mid
#define RC rt << 1|1, mid + 1, r
#define LRT rt << 1
#define RRT rt << 1|1
#define FOR(i, a, b) for (int i=(a); (i) < (b); (i)++)
#define FOOR(i, a, b) for (int i = (a); (i)<=(b); (i)++)
const double PI = acos(-1.0);
const int INF = 0x3f3f3f3f;
const double eps = 1e-8;
const int MAXN = 2e5 + 10;
const int MOD = 1e9 + 7;
const int dir[][2] = { {-1, 0}, {1, 0}, {0, -1}, {0, 1} };
const int seed = 131;
int cases = 0;
typedef std::pair<int, int> pii;
LL dp[MAXN][2];
std::vector<int> G[MAXN];
void dfs(int u, int fa) {
for (int v : G[u]) if (v != fa) {
dfs(v, u);
dp[u][1] = (dp[u][1] * (dp[v][0] + 1)) % MOD + dp[u][0]*dp[v][1]%MOD;
dp[u][1] %= MOD;
dp[u][0] = dp[u][0] * (dp[v][0] + 1) % MOD;
}
}
int main() {
// ROP;
int T;
scanf("%d", &T);
while (T--) {
int n;
scanf("%d", &n);
for (int i = 1; i <= n; ++i) {
G[i].clear();
dp[i][0] = dp[i][1] = 1;
}
for (int i = 2; i <= n; ++i) {
int tmp;
scanf("%d", &tmp);
G[i].push_back(tmp);
G[tmp].push_back(i);
}
dfs(1, -1);
LL ans = 0;
for (int i = 1; i <= n; ++i) {
ans += dp[i][1];
if (ans > MOD) ans -= MOD;
}
std::cout << ans << std::endl;
}
return 0;
}