CF855G. Harry Vs Voldemort
Solution
考虑每增加一条边都会把路径上的边双都连成一个大边双,考虑合并 x x x和 y = f a x y = fa_x y=fax 这两个边双的贡献,分类讨论:
- 选取三个同边双内的点。
- 选取在同一个边双内选两个点,剩下一个在其他边双内。
- 选取三个来自不同边双的点。
第一个的贡献即为
A
s
z
x
3
A_{sz_x}^3
Aszx3
第二个的贡献即为
2
A
s
z
x
2
(
n
−
s
z
x
)
2A_{sz_x}^2(n - sz_x)
2Aszx2(n−szx)
第三个比较麻烦:
- 首先我们去掉在一二两种中重复的部分,这部分可能是 ( x , y , z ) (x,y,z) (x,y,z), ( y , x , z ) (y,x,z) (y,x,z), ( z , x , y ) (z,x,y) (z,x,y), ( z , y , x ) (z,y,x) (z,y,x)。
- 然后我们发现合并 ( x , y ) (x,y) (x,y)之后,就可以从本来 x x x子树中的点 p p p开始走到 y y y再回走到 x x x的另一个子树中的点 q q q,即 ( p , y , q ) (p,y,q) (p,y,q)和 ( q , y , p ) (q,y,p) (q,y,p)都可以选择,这一部分是新多出来的,这部分相当于 x x x的不同子树内的点对两两可达,于是我们再维护一个 h x = ∑ y s z y 2 h_x = \sum_{y}sz_y^2 hx=∑yszy2即可快速统计贡献。(对于 y y y这一部分的贡献同理)
具体统计方法见 C o d e Code Code。
并查集维护边双联通分量即可,时间复杂度 O ( n l g n ) O(nlgn) O(nlgn)。
Code
#include <bits/stdc++.h>
using namespace std;
template<typename T> inline bool upmin(T &x, T y) { return y < x ? x = y, 1 : 0; }
template<typename T> inline bool upmax(T &x, T y) { return x < y ? x = y, 1 : 0; }
#define MP(A,B) make_pair(A,B)
#define PB(A) push_back(A)
#define SIZE(A) ((int)A.size())
#define LEN(A) ((int)A.length())
#define FOR(i,a,b) for(int i=(a);i<(b);++i)
#define fi first
#define se second
#define int ll
typedef long long ll;
typedef unsigned long long ull;
typedef long double lod;
typedef pair<int, int> PR;
typedef vector<int> VI;
const lod eps = 1e-11;
const lod pi = acos(-1);
const int mods = 998244353;
const int oo = 1 << 30;
const ll loo = 1ll << 62;
const int MAXN = 600005;
const int INF = 0x3f3f3f3f;//1061109567
/*--------------------------------------------------------------------*/
inline int read() {
int f = 1, x = 0; char c = getchar();
while (c < '0' || c > '9') { if (c == '-') f = -1; c = getchar(); }
while (c >= '0' && c <= '9') { x = (x << 3) + (x << 1) + (c ^ 48); c = getchar(); }
return x * f;
}
vector<int> e[MAXN];
int fa[MAXN], f[MAXN], dep[MAXN];
ll sz[MAXN], num[MAXN], g[MAXN], h[MAXN], ans = 0, n;
int find(int x) { return f[x] == x ? f[x] : f[x] = find(f[x]); }
void dfs(int x, int father) {
sz[x] = 0, fa[x] = father, dep[x] = dep[father] + 1;
for (auto v : e[x]) if (v != father) dfs(v, x);
for (auto v : e[x]) {
if (v == father) continue;
ans += sz[x] * sz[v] * 2;
g[x] += g[v] + sz[v];
h[x] += sz[v] * sz[v];
sz[x] += sz[v];
}
++ sz[x];
h[x] += (n - sz[x]) * (n - sz[x]);
ans += (n - sz[x]) * (sz[x] - 1) * 2;
}
void merge(int x, int y) {
ans -= num[x] * (num[x] - 1) * (num[x] - 2); //part 1 x
ans -= num[y] * (num[y] - 1) * (num[y] - 2); //part 1 y
ans -= num[x] * (num[x] - 1) * (n - num[x]) * 2; //part 2 x
ans -= num[y] * (num[y] - 1) * (n - num[y]) * 2; //part 2 y
ans -= (sz[x] - num[x]) * num[x] * num[y] * 2 + (n - sz[x] - num[y]) * num[x] * num[y] * 2; //part 3.1
ans += num[y] * ((sz[x] - num[x]) * (sz[x] - num[x]) - (h[x] - (n - sz[x]) * (n - sz[x]))); //part 3.2 x
ans += num[x] * ((n - sz[x] - num[y]) * (n - sz[x] - num[y]) - (h[y] - sz[x] * sz[x])); //part 3.2 y
f[x] = y, num[y] += num[x], h[y] += h[x] - sz[x] * sz[x] - (n - sz[x]) * (n - sz[x]);
ans += num[y] * (num[y] - 1) * (num[y] - 2); //part 1 new
ans += num[y] * (num[y] - 1) * (n - num[y]) * 2; //part 2 new
}
signed main() {
n = read();
for (int i = 1, u, v; i < n ; ++ i) u = read(), v = read(), e[u].PB(v), e[v].PB(u);
for (int i = 1; i <= n ; ++ i) f[i] = i, num[i] = 1;
dfs(1, 0);
printf("%lld\n", ans);
int Case = read();
while (Case --) {
int u = read(), v = read(), U = find(u), V = find(v);
while (U != V) {
if (dep[U] < dep[V]) swap(U, V);
merge(U, find(fa[U]));
U = find(U);
}
printf("%lld\n", ans);
}
return 0;
}