一道很有意思的树上点对 数量计算的问题。。
他实际上是普通的经典 树上点对计算框架的 一个功能扩展。。
也就是加入了同值的概念。。仅此而已。亮点在于深度的存储方式 给计算带来的便利
如果代码看不懂 可以去找一些基本的 树上点对问题统计的题目
实际上树上直径dp。。本质上也是树上点对问题 。
因为直径可以看做是寻找树上最远的两点之间的距离。
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define ll __int128_t
#define ar array<int, 2>
#define arr array<int, 3>
int n, m, k, inf = 1LL << 61, mod = 998244353;// 1e9+7;
const int N = 5e5 + 50;
map<int, ar>f[N];//节点子树内的 [颜色种类:[数量:深度和]]
vector<int>mp[N];
int a[N], ans;
void dfs(int u, int p, int d) {
f[u][a[u]][0]++;
f[u][a[u]][1] += d;
for (int v : mp[u]) {
if (v == p)
continue;
dfs(v, u, d + 1);
if (f[u].size() < f[v].size())
swap(f[u], f[v]);
for (auto[t, x] : f[v]) {
auto &y = f[u][t];
ans += x[0] * (y[1] - y[0] * d);
ans += y[0] * (x[1] - x[0] * d);
// 左点数* 右距离和。。右点数*左距离和。。
y[0] += x[0];
y[1] += x[1];
}
}
};
void solve() {
cin >> n;
for (int i = 1; i < n; ++i) {
int x, y;
cin >> x >> y;
mp[x].push_back(y);
mp[y].push_back(x);
}
for (int i = 1; i <= n; ++i)
cin >> a[i];
dfs(1, 0, 0);
cout << ans;
};
//实际上是很普通的。。启发式合并+树上点对问题的计算方式
// 主要是存储的方式。很有新意 f数组的结构。
signed main() {
ios::sync_with_stdio(false);
cin.tie(0);
cout << fixed << setprecision(15);
#ifdef DEBUG
freopen("../1.in", "r", stdin);
#endif
//init_f();
//init();
//expr();
// int T; cin >> T; while(T--)
solve();
return 0;
}
增:今天刷题看到一个简单版本的的题目 也就是没有颜色区别 单纯求任意两点的距离和。
https://atcoder.jp/contests/typical90/tasks/typical90_am
039 - Tree Distance
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define ll __int128_t
#define ar array<int, 2>
#define arr array<int, 3>
int n, m, k, inf = 1LL << 61, mod = 998244353;// 1e9+7;
const int N = 5e5 + 50;
vector<int>mp[N];
int ans, cnt[N], len[N];
void dfs(int u, int p, int d) {
cnt[u] = 1;
len[u] += d;
for (int v : mp[u]) {
if (v == p)
continue;
dfs(v, u, d + 1);
ans += cnt[u] * (len[v] - cnt[v] * d);
ans += cnt[v] * (len[u] - cnt[u] * d);
cnt[u] += cnt[v];
len[u] += len[v];
//也可以用边的贡献来计算 树上距离和 更简单。 不过上面这套。。更适用于上面那道题的扩展。
//下面还有一道题。。就用下面这种边的贡献来求 更好
//ans += cnt[v] * (n - cnt[v]);
//cnt[u] += cnt[v];
}
};
void solve() {
cin >> n;
for (int i = 1; i < n; ++i) {
int x, y;
cin >> x >> y;
mp[x].push_back(y);
mp[y].push_back(x);
}
dfs(1, 0, 0);
cout << ans;
};
signed main() {
ios::sync_with_stdio(false);
cin.tie(0);
cout << fixed << setprecision(15);
#ifdef DEBUG
freopen("../1.in", "r", stdin);
#endif
//init_f();
//init();
//expr();
// int T; cin >> T; while(T--)
solve();
return 0;
}
下面这题 又是另外一个方向的 距离和的扩展。。又兴趣的自己去了解 他这个题解写的特别好 虽然短短几句。。但是都说到关键的点子上了。
不过他代码没有我简洁 哈哈哈
https://www.cnblogs.com/wscqwq/p/17591019.html
G - Avoid Straight Line
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define ll __int128_t
#define ar array<int, 2>
#define arr array<int, 3>
int n, m, k, inf = 1LL << 61, mod = 998244353;// 1e9+7;
const int N = 5e5 + 50;
vector<int>mp[N];
int ans, cnt[N], len[N];
int C(int n, int m) {
int ans = 1;
for (int i = 1; i <= m; i++) {
ans = ans * (n - m + i) / i; // 注意一定要先乘再除
}
return ans;
}
//https://blog.csdn.net/m0_37149062/article/details/122522676
void dfs(int u, int p, int d) {
cnt[u] = 1;
for (int v : mp[u]) {
if (v == p)
continue;
dfs(v, u, d + 1);
ans += cnt[v] * (n - cnt[v]);
cnt[u] += cnt[v];
}
};
void solve() {
cin >> n;
for (int i = 1; i < n; ++i) {
int x, y;
cin >> x >> y;
mp[x].push_back(y);
mp[y].push_back(x);
}
dfs(1, 0, 0);
cout << C(n, 3) + C(n, 2) - ans;
};
// 经过u点作为第三个点 的所有两点的组合。。。可以转化成任意两点的距离和问题 len-1 ..
// 为什么要-1
//因为 [1,2,3] 1-3 距离是2 但是实际上中间只能放一个点。。怎么求这个多余的这个点的个数呢。。
// 他实际上又等价于 任意两点 点对的个数 也就是 C(n,2)
signed main() {
ios::sync_with_stdio(false);
cin.tie(0);
cout << fixed << setprecision(15);
#ifdef DEBUG
freopen("../1.in", "r", stdin);
#endif
//init_f();
//init();
//expr();
// int T; cin >> T; while(T--)
solve();
return 0;
}