洛谷 P2634 [国家集训队]聪聪可可
题意
给一棵 $n$ 个节点的树,边带权。任选两点(可相同),求两点之间简单路径长度恰好是 $3$ 的倍数的概率。
解法
树上路径询问?立即推:点分治!
- 枚举每个点作为 lca,维护子树的点到 lca 距离模 $3$ 的数量,然后依次遍历各子树,先更新答案再更新数量即可。复杂度为 $O(n \log n)$。
- 当然这题也可以直接树形 dp,码量一下子就小了很多,原理也是类似的。复杂度为 $O(kn)$,其中 $k$ 为模数,本题 $k=3$。
代码
点分治:
#pragma region
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <map>
#include <queue>
#include <set>
#include <vector>
using namespace std;
typedef long long ll;
#define rep(i, a, n) for (int i = a; i <= n; ++i)
#define per(i, a, n) for (int i = n; i >= a; --i)
#pragma endregion
// Capacity of the per-vertex arrays (20000 vertices plus a little slack).
const int maxn = 2e4 + 5;
// Number of vertices in the tree; filled in by main().
int n;
// Adjacency list: g[u] stores (neighbour, edge weight) pairs for vertex u.
vector<pair<int, int>> g[maxn];
// vis[u] is set by work() once u has been used as a decomposition root; the
// DFS helpers never step onto a vertex whose flag is set.
bool vis[maxn];
// sz[u]: subtree size computed by dfs_rt() (work() also stores component sizes here).
// d[u]:  path length from the current decomposition root to u.
// rt:    vertex most recently selected by dfs_rt().
// c[r]:  number of vertices recorded by dfs_c() with d % 3 == r.
int sz[maxn], d[maxn], rt, c[3];
// Computes subtree sizes inside the not-yet-visited component that contains u
// and stores in `rt` a vertex whose removal leaves no piece larger than half
// of `tot`.
//   u   - vertex being explored
//   f   - vertex we arrived from (never re-entered)
//   tot - number of vertices in the component being searched
void dfs_rt(int u, int f, int tot) {
  int heaviest = 0;  // largest piece that would remain if u were removed
  sz[u] = 1;
  for (const auto &edge : g[u]) {
    const int child = edge.first;
    if (child != f && !vis[child]) {
      dfs_rt(child, u, tot);
      sz[u] += sz[child];
      if (sz[child] > heaviest) heaviest = sz[child];
    }
  }
  // Everything outside u's subtree forms one more piece.
  const int above = tot - sz[u];
  if (above > heaviest) heaviest = above;
  if (2 * heaviest <= tot) rt = u;
}
// Number of vertices entered by dfs_ans() since the caller last reset it.
int cnt;
// Walks the subtree hanging below the current decomposition root and, for
// every vertex met, adds to `ans` the number of already-recorded vertices
// (counters in c[]) whose distance complements this one to a multiple of 3,
// plus one when the vertex itself is at a multiple of 3 from the root.
// d[u] must be set by the caller; d[] of the descendants is filled in here.
//   u   - vertex being explored
//   f   - vertex we arrived from
//   ans - running pair count, updated in place
void dfs_ans(int u, int f, int &ans) {
  ++cnt;
  const int rem = d[u] % 3;
  ans += c[(3 - rem) % 3];
  if (rem == 0) ++ans;  // path from the root itself to u
  for (const auto &edge : g[u]) {
    const int to = edge.first;
    if (to == f || vis[to]) continue;
    d[to] = d[u] + edge.second;
    dfs_ans(to, u, ans);
  }
}
// Records every vertex of the subtree rooted at u in the residue counters:
// c[d[x] % 3] is incremented for each reachable, unvisited vertex x.
// Relies on d[] having been filled beforehand.
//   u - vertex being explored
//   f - vertex we arrived from
void dfs_c(int u, int f) {
  ++c[d[u] % 3];
  for (const auto &edge : g[u]) {
    const int to = edge.first;
    if (to != f && !vis[to]) dfs_c(to, u);
  }
}
int work(int u, int f, int tot) {
dfs_rt(u, f, tot);
u = rt, vis[u] = 1, d[u] = 0;
int ans = 0;
for (auto e : g[u]) {
int v = e.first, w = e.second;
if (vis[v]) continue;
cnt = 0, d[v] = w;
dfs_ans(v, u, ans);
sz[v] = cnt;
dfs_c(v, u);
}
c[0] = c[1] = c[2] = 0;
for (auto e : g[u]) {
int v = e.first;
if (vis[v]) continue;
ans += work(v, u, sz[v]);
}
return ans;
}
int main() {
scanf("%d", &n);
rep(i, 1, n - 1) {
int u, v, w;
scanf("%d%d%d", &u, &v, &w);
g[u].push_back({v, w});
g[v].push_back({u, w});
}
int ans = work(1, 0, n) * 2 + n, sum = n * n;
int g = __gcd(ans, sum);
printf("%d/%d\n", ans / g, sum / g);
}
树形dp:
#pragma region
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <map>
#include <queue>
#include <set>
#include <vector>
using namespace std;
typedef long long ll;
#define rep(i, a, n) for (int i = a; i <= n; ++i)
#define per(i, a, n) for (int i = n; i >= a; --i)
#pragma endregion
// Capacity of the per-vertex arrays (200000 vertices plus a little slack).
const int maxn = 2e5 + 5;
// Number of vertices in the tree; filled in by main().
int n;
// Adjacency list: g[u] stores (neighbour, edge weight) pairs for vertex u.
vector<pair<int, int>> g[maxn];
int dp[maxn][3], ans;
void dfs(int u, int f) {
dp[u][0] = 1;
for (auto e : g[u]) {
int v = e.first, w = e.second;
if (v == f) continue;
dfs(v, u);
rep(i, 0, 2) ans += 2 * (dp[v][i] * dp[u][((3 - i - w) % 3 + 3) % 3]);
rep(i, 0, 2) dp[u][(i + w) % 3] += dp[v][i];
}
}
int main() {
scanf("%d", &n);
rep(i, 1, n - 1) {
int u, v, w;
scanf("%d%d%d", &u, &v, &w);
g[u].push_back({v, w});
g[v].push_back({u, w});
}
dfs(1, 0);
int sum = n * n;
ans += n;
int g = __gcd(ans, sum);
printf("%d/%d\n", ans / g, sum / g);
}