树上倍增
问题描述
给定一棵树,树的大小为
n
n
n且根节点为节点1,求有 多少组
(
i
,
j
,
k
)
(i,j,k)
(i,j,k)使得节点
k
k
k是节点
j
j
j的祖先且节点
j
j
j是节点
i
i
i的祖先。由于这个数可能很大,所以请输出对998244353取模后的结果。
输入
第一行包含一个正整数
n
(
1
≤
n
≤
1
0
6
)
n(1\le n \le 10^6)
n(1≤n≤106),表示树的节点个数
接下来
n
−
1
n-1
n−1行每行包括两个整数
u
,
v
u,v
u,v,表示节点
u
u
u和节点
v
v
v相连
输出
一行一个整数表示答案对998244353取模后的值
样例输入
4
1 2
2 3
2 4
样例输出
2
提示
有两个组满足答案,为
(
4
,
2
,
1
)
(4,2,1)
(4,2,1)和
(
3
,
2
,
1
)
(3,2,1)
(3,2,1)
AC代码
#include<bits/stdc++.h>
using namespace std;
using ll = long long;
const int maxn = 1e6 + 5;
const int mod = 998244353;
vector<int> mp[maxn];
int dep[maxn], size[maxn];
int n;
ll ans;
void dfs(int u, int f) {
dep[u] = dep[f] + 1;
size[u] = 1;
for(auto v : mp[u]) {
if(v == f) continue;
dfs(v, u);
size[u] += size[v];
}
ans += 1ll * (size[u] - 1) * (dep[u] - 1) % mod;
ans %= mod;
}
int main() {
scanf("%d", &n);
for(int i = 1; i < n; i++) {
int u, v;
scanf("%d%d", &u, &v);
mp[u].push_back(v);
mp[v].push_back(u);
}
dfs(1, 0);
printf("%lld", ans);
return 0;
}
代码分析
#include<bits/stdc++.h>
using namespace std;
using ll = long long;
const int maxn = 1e6 + 5;
const int mod = 998244353;
vector<int> mp[maxn];
int dep[maxn], size[maxn];
int n;
ll ans;
void dfs(int u, int f) {//u为当前节点,f为上一层节点
//初始化节点1
dep[u] = dep[f] + 1;
size[u] = 1;
for(auto v : mp[u]) {//横向遍历当前节点的子节点
if(v == f) continue;
dfs(v, u);//纵向深入下一层
size[u] += size[v];
}
ans += 1ll * (size[u] - 1) * (dep[u] - 1) % mod;
ans %= mod;
}
int main() {
scanf("%d", &n);
for(int i = 1; i < n; i++) {
int u, v;
scanf("%d%d", &u, &v);
//用向量组存储双向边
mp[u].push_back(v);
mp[v].push_back(u);
}
dfs(1, 0);
printf("%lld", ans);
return 0;
}