题意:给你一棵 $n$ 个节点的树(无环连通图),每个点有一个值。定义 $f(i,x)=\sum_{j=1}^{n} a_{ij}\,x^{j-1}$,其中 $a_{ij}$ 表示从节点 $i$ 到节点 $j$ 的路径上不同点值的个数。现在要求对每个 $i$ 求出 $f(i,19560929) \bmod (10^9+7)$ 和 $f(i,19560929) \bmod (10^9+9)$。
解法:由于 $\sum n \leq 5000$,$O(n^2)$ 暴力即可:对每个 $i$,用一次 $O(n)$ 的遍历求出对应的 $f(i,19560929) \bmod (10^9+7)$ 和 $f(i,19560929) \bmod (10^9+9)$。
AcCode:
#include <cstdio>
#include <algorithm>
#include <iostream>
#include <map>
#include <set>
#include <utility>
#include <vector>
// `int` is widened to 64 bits for the whole file, so a product of two residues
// below ~1e9 (up to ~1e18) fits without overflow. `signed main` and the "%lld"
// scanf/printf formats elsewhere in the file rely on this macro.
#define int long long
// Pair of residues: (value mod mod1, value mod mod2).
using pint = std::pair<long long, long long>;
constexpr int N = 10100;            // array capacity; vertex ids (and values, which index vis) must be < N
constexpr int mod1 = 1000000007;    // first output modulus (1e9 + 7)
constexpr int mod2 = 1000000009;    // second output modulus (1e9 + 9)
constexpr int x = 19560929;         // the point at which f(i, x) is evaluated
std::vector<int> vec[N];            // adjacency lists; main stores every tree edge in both directions
int value[N];                       // value attached to each vertex
int vis[N];                         // vis[v] = number of vertices with value v on the current DFS path
int inv1[N], inv2[N];               // despite the names, main fills these with powers: x^i mod mod1 / mod2
// Modular exponentiation by repeated squaring: returns a^b mod p, always in [0, p).
//   a : base; may be negative (it is normalised into [0, p) first)
//   b : exponent; must be >= 0. A negative exponent is not supported and
//       yields 1 % p instead of looping forever.
//   p : modulus, p >= 1; p * p must fit in `int` (with this file's
//       `#define int long long`, any p below ~3e9 is fine).
inline int quick_pow(int a, int b, int p) {
    int ans = 1 % p;      // `1 % p` so that p == 1 gives 0 even when b == 0
    a %= p;
    if (a < 0) a += p;    // C++ `%` keeps the dividend's sign; move into [0, p)
    while (b > 0) {       // `> 0`, not `!= 0`: `b >>= 1` never reaches 0 for negative b
        if (b & 1) ans = ans * a % p;
        b >>= 1;
        a = a * a % p;
    }
    return ans;
}
// Depth-first walk starting at `now`, never stepping back into `pre`.
//   now : current vertex
//   pre : vertex we arrived from (the initial caller passes 0)
//   num : number of distinct values on the path from the walk's start vertex
//         up to, but not including, `now`
// Returns, over every vertex j reached in this walk, the sums
//   ( a(start, j) * inv1[j - 1] mod mod1 , a(start, j) * inv2[j - 1] mod mod2 ),
// where a(start, j) is the count of distinct values on the start..j path.
// vis[v] counts the vertices with value v on the current path. It is raised on
// entry and lowered on exit, so vis is unchanged once the call returns.
// NOTE: vis is indexed directly by value[now], so vertex values must lie in [0, N).
inline pint dfs(int now, int pre, int num) {
    int &onPath = vis[value[now]];
    // A value not yet on the path adds one new distinct value.
    const int distinct = (onPath == 0) ? num + 1 : num;
    ++onPath;

    pint acc(1LL * inv1[now - 1] * distinct % mod1,
             1LL * inv2[now - 1] * distinct % mod2);

    for (int child : vec[now]) {
        if (child == pre) {
            continue;  // do not walk back up the edge we came down
        }
        const pint sub = dfs(child, now, distinct);
        acc.first = (acc.first + sub.first) % mod1;
        acc.second = (acc.second + sub.second) % mod2;
    }

    --onPath;  // leaving `now`: drop its value from the path multiset
    return acc;
}
signed main() {
int t; scanf("%lld", &t);
inv1[0] = 1;
inv2[0] = 1;
for (int i = 1; i <= 5005; i++) {
inv1[i] = inv1[i - 1] * x % mod1;
//int temp = quick_pow(x, i, mod1);
//if (temp != inv1[i]) printf("error inv1 %d\n", i);
inv2[i] = inv2[i - 1] * x % mod2;
//temp = quick_pow(x, i, mod2);
//if (temp != inv2[i]) printf("error inv %d\n", i);
}
/*int out1 = inv1[0]%mod1 + 2 * inv1[1]%mod1 + 2 * inv1[2]%mod1 + 2 * inv1[3]%mod1 + 2 * inv1[5] % mod1;
std::cout << out1 << std::endl;*/
while (t--) {
int n; scanf("%lld", &n);
for (int i = 1; i <= n; i++) vec[i].clear();
for (int i = 2; i <= n; i++) {
int pre; scanf("%lld", &pre);
vec[i].push_back(pre);
vec[pre].push_back(i);
}
for (int i = 1; i <= n; i++) scanf("%lld", &value[i]);
for (int i = 1; i <= n; i++) {
//for (int j = 1; j <= n; j++) vis[j] = 0;
pint ans = dfs(i, 0, 0);
printf("%lld %lld\n", (ans.first%mod1+mod1)%mod1, (ans.second%mod2+mod2)%mod2);
}
}
}