【题目链接】
【思路要点】
- 点分治。把两类边分别当作 -1 和 1(代码中输入 z == 1 的边记为 1,其余记为 -1)。以同一分治中心为端点、位于不同子树中的两条链能够拼成一条合法路径,当且仅当它们的权值之和为 0,且至少其中一条链满足:其终点的前缀权值在该链上(含分治中心)至少出现两次——重复出现的位置即可作为路径中间的休息点。
- 时间复杂度 \(O(N\log N)\)。
【代码】
// Counts tree paths whose +1/-1 edge weights sum to zero and that contain an
// interior vertex splitting the path into two zero-sum halves, using
// centroid decomposition. Reads n and n-1 edges "x y z" from stdin
// (z == 1 -> weight +1, otherwise -1) and prints the count.
//
// NOTE(review): all traversals are recursive, so recursion depth can reach n
// on path-shaped trees; confirm the target stack size is sufficient.
#include <algorithm>
#include <cstdio>
#include <vector>

constexpr int MAXN = 100005;

// Adjacency entry: neighbour vertex and edge type (true when input z == 1).
struct edge {
    int dest;
    bool type;
};

std::vector<edge> a[MAXN];
int n, root;
long long ans;
// Named `subtree_size` rather than `size`: a global `size` combined with
// `using namespace std;` is ambiguous with std::size (C++17) and fails to
// compile.
int subtree_size[MAXN], weight[MAXN];
// Prefix sums along a chain are stored offset by n, so value v means sum v-n.
// curr[v][0]: 1 if v already occurs on the chain currently being walked
//             (centroid excluded); curr[v][1]: number of further repeats.
// cnt[v][0] / cnt[v][1]: vertices in already-processed child subtrees whose
//             value v was first-seen / a repeat on their own chain.
// Cnt: within the current child subtree, first-seen vertices with value n.
int curr[MAXN * 2][2], cnt[MAXN * 2][2], Cnt;
bool visited[MAXN];

// Computes subtree sizes of the unvisited component containing pos (rooted at
// pos, parent `father`) and moves `root` to any vertex whose largest remaining
// part is smaller than weight[root].
void getroot(int pos, int father, int total) {
    subtree_size[pos] = 1;
    weight[pos] = 0;
    for (const edge& e : a[pos]) {
        if (e.dest == father || visited[e.dest]) continue;
        getroot(e.dest, pos, total);
        subtree_size[pos] += subtree_size[e.dest];
        weight[pos] = std::max(weight[pos], subtree_size[e.dest]);
    }
    weight[pos] = std::max(weight[pos], total - subtree_size[pos]);
    if (weight[pos] < weight[root]) root = pos;
}

// Maps an edge type to its weight: +1 for type true, -1 otherwise.
// (Renamed from `index` to avoid clashing with POSIX index().)
int edge_weight(bool x) { return x ? 1 : -1; }

// Pushes value v onto the current chain; returns true if v was already on it.
bool push_value(int v) {
    if (curr[v][0]) {
        curr[v][1]++;
        return true;
    }
    curr[v][0]++;
    return false;
}

// Undoes the matching push_value(v) when leaving a vertex.
void pop_value(int v) {
    if (curr[v][1]) curr[v][1]--;
    else curr[v][0]--;
}

// Walks one child subtree and adds to `ans` every valid pairing of its
// vertices with vertices from previously processed subtrees (the partner must
// have the complementary value 2n-len). A repeat vertex pairs with any
// partner; a first-seen vertex only pairs with repeat partners.
void getans(int pos, int father, int len) {
    const int need = 2 * n - len;
    if (push_value(len)) {
        ans += cnt[need][1] + cnt[need][0];
    } else {
        if (len == n) Cnt++;
        ans += cnt[need][1];
    }
    for (const edge& e : a[pos])
        if (!visited[e.dest] && e.dest != father)
            getans(e.dest, pos, len + edge_weight(e.type));
    pop_value(len);
}

// Adds the vertices of one child subtree into cnt[][], classified as
// first-seen (0) or repeat (1) on their chain.
void dfs(int pos, int father, int len) {
    cnt[len][push_value(len) ? 1 : 0]++;
    for (const edge& e : a[pos])
        if (!visited[e.dest] && e.dest != father)
            dfs(e.dest, pos, len + edge_weight(e.type));
    pop_value(len);
}

// Resets every cnt[][] entry touched by the subtree rooted at pos.
void clean(int pos, int father, int len) {
    cnt[len][0] = 0;
    cnt[len][1] = 0;
    for (const edge& e : a[pos])
        if (!visited[e.dest] && e.dest != father)
            clean(e.dest, pos, len + edge_weight(e.type));
}

// Number of unordered pairs among x items.
long long cal(long long x) { return x * (x - 1) / 2; }

// Processes centroid pos of a component with `total` vertices, then recurses
// into each remaining child component.
void work(int pos, int total) {
    visited[pos] = true;
    // Re-root the size computation at the centroid so subtree_size[child] is
    // each child's component size for the recursive calls below.
    getroot(pos, 0, total);
    for (const edge& e : a[pos]) {
        if (visited[e.dest]) continue;
        const int start = n + edge_weight(e.type);
        Cnt = 0;
        getans(e.dest, 0, start);
        // Same-subtree first-seen zero pairs are removed here because
        // cal(cnt[n][0]) below counts pairs across all subtrees together.
        ans -= cal(Cnt);
        dfs(e.dest, 0, start);
    }
    // cnt[n][1]: zero-sum chains ending at a vertex whose value repeats on the
    // chain (path from the centroid itself). cal(cnt[n][0]): cross-subtree
    // pairs of first-seen zero-sum vertices, which use the centroid as the
    // interior split vertex.
    ans += cnt[n][1] + cal(cnt[n][0]);
    for (const edge& e : a[pos])
        if (!visited[e.dest]) clean(e.dest, 0, n + edge_weight(e.type));
    for (const edge& e : a[pos]) {
        if (visited[e.dest]) continue;
        const int component = subtree_size[e.dest];
        root = 0;
        getroot(e.dest, 0, component);
        work(root, component);
    }
}

int main() {
    // Reject malformed input instead of reading uninitialised values or
    // indexing outside the fixed-size arrays.
    if (std::scanf("%d", &n) != 1 || n < 1 || n >= MAXN) return 1;
    for (int i = 1; i < n; i++) {
        int x, y, z;
        if (std::scanf("%d%d%d", &x, &y, &z) != 3) return 1;
        if (x < 1 || x > n || y < 1 || y > n) return 1;
        a[x].push_back(edge{y, z == 1});
        a[y].push_back(edge{x, z == 1});
    }
    // Vertex 0 is a sentinel: weight[0] = n is larger than any real weight,
    // so the first real vertex examined always replaces it as `root`.
    subtree_size[0] = weight[0] = n;
    root = 0;
    getroot(1, 0, n);
    work(root, n);
    std::printf("%lld\n", ans);
    return 0;
}