Solution
-
记 s → t s→t s→t 为包含点 u u u 的一条路径,显然所有的 s → t s→t s→t 能组成一个连通块(因为路径可以拆成 s → u , u → t s→u,u→t s→u,u→t),而这个连通块的边数就是能与 u u u 开展贸易活动的城市个数。
-
记这个连通块为 G ( u ) G(u) G(u),显然 G ( u ) G(u) G(u) 也能看成:连通所有点 s , t s,t s,t 和 u u u 的最小生成树
-
先考虑开 n n n 棵线段树分别维护 G ( u ) G(u) G(u),对于 s → t s→t s→t 路径上每个点 u u u, G ( u ) G(u) G(u) 都要加上 s , t s,t s,t 这两个点,线段树以 d f s dfs dfs 序为下标,节点 x x x 维护以下信息 (对应区间 [ l , r ] [l,r] [l,r]):
f ( x ) : f(x): f(x): G ( u ) G(u) G(u) 中 d f s dfs dfs 序在 [ l , r ] [l,r] [l,r] 中的点(记这些点的集合为 H ( x ) H(x) H(x)),加上根节点,组成的连通块的边数
s ( x ) : s(x): s(x): H H H 中 d f s dfs dfs 序最小的点
t ( x ) : t(x): t(x): H H H 中 d f s dfs dfs 序最大的点 -
记 x x x 的左右儿子分别为 x 2 , x 3 x2,x3 x2,x3,则一般情况下:
f ( x ) = f ( x 2 ) + f ( x 3 ) − d e e p ( l c a ( t ( x 2 ) , t ( x 3 ) ) f(x)=f(x2)+f(x3)-deep(lca(t(x2), t(x3)) f(x)=f(x2)+f(x3)−deep(lca(t(x2),t(x3))
s ( x ) = s ( x 2 ) s(x)=s(x2) s(x)=s(x2)
t ( x ) = t ( x 3 ) t(x)=t(x3) t(x)=t(x3)当然还有一些 H ( x 2 ) H(x2) H(x2) 或 H ( x 3 ) H(x3) H(x3) 为空集的情况需要特判
-
具体实现中,可以把 s → t s→t s→t 做树上差分,即拆成:在 s , t s,t s,t 处出现次数 + 1 +1 +1 , l c a ( s , t ) , f a ( l c a ( s , t ) ) lca(s,t),fa(lca(s,t)) lca(s,t),fa(lca(s,t)) 处出现次数 − 1 -1 −1 (上述出现次数均为点 s , t s,t s,t 的出现次数),因此,线段树的叶子节点还要记录每个点的出现次数
-
然后离线下来, d f s dfs dfs 整棵树一遍,在回溯的时候,把儿子的线段树合并到该点,并执行位于该点的 + 1 , − 1 +1,-1 +1,−1 修改
-
记 r t [ u ] rt[u] rt[u] 为 u u u 对应线段树的根, G ( u ) G(u) G(u) 的边数即 f ( r t [ u ] ) − d e e p ( l c a ( s ( r t [ u ] ) , t ( r t [ u ] ) ) ) f(rt[u])-deep(lca(s(rt[u]),t(rt[u]))) f(rt[u])−deep(lca(s(rt[u]),t(rt[u])))
-
注意没有限制 u < v u<v u<v,即答案要除以 2 2 2
-
使用欧拉序求 l c a lca lca 即可做到 O ( n l o g n ) O(nlogn) O(nlogn) 的时间复杂度
Code
#include <bits/stdc++.h>
using namespace std;
#define pb push_back
#define ll long long
template <class t>
inline void read(t & res)
{
char ch;
while (ch = getchar(), !isdigit(ch));
res = ch ^ 48;
while (ch = getchar(), isdigit(ch))
res = res * 10 + (ch ^ 48);
}
const int e = 2e5 + 5;
vector<int>g[e], h[e];
ll ans;
int n, m, rt[e], logn[e], st[e][18], dep[e], nxt[e * 2], go[e * 2], num, adj[e], fa[e];
int dfn1[e], dfn2[e], pool;
struct node
{
int l, r, cnt, f, s, t;
}c[e * 30];
inline void add(int x, int y)
{
nxt[++num] = adj[x];
adj[x] = num;
go[num] = y;
nxt[++num] = adj[y];
adj[y] = num;
go[num] = x;
}
inline void dfs1(int u, int pa)
{
dep[u] = dep[pa] + 1;
dfn1[u] = ++dfn1[0];
dfn2[u] = ++dfn2[0];
st[dfn1[0]][0] = u;
fa[u] = pa;
for (int i = adj[u]; i; i = nxt[i])
{
int v = go[i];
if (v == pa) continue;
dfs1(v, u);
st[++dfn1[0]][0] = u;
}
}
inline int lca(int x, int y)
{
if (!x || !y) return 0;
if (dfn1[x] > dfn1[y]) swap(x, y);
int l = dfn1[x], r = dfn1[y], k = logn[r - l + 1], u = st[l][k],
v = st[r - (1 << k) + 1][k];
return dep[u] < dep[v] ? u : v;
}
inline void collect(int x)
{
int l = c[x].l, r = c[x].r;
c[l].s ? c[x].s = c[l].s : c[x].s = c[r].s;
c[r].t ? c[x].t = c[r].t : c[x].t = c[l].t;
c[x].f = c[l].f + c[r].f - dep[lca(c[l].t, c[r].s)];
}
inline int merge(int x, int y, int l, int r)
{
if (!x || !y) return x ^ y;
if (l == r)
{
c[x].cnt += c[y].cnt;
c[x].f |= c[y].f;
c[x].s |= c[y].s;
c[x].t |= c[y].t;
return x;
}
int mid = l + r >> 1;
c[x].l = merge(c[x].l, c[y].l, l, mid);
c[x].r = merge(c[x].r, c[y].r, mid + 1, r);
collect(x);
return x;
}
inline void insert(int &x, int l, int r, int pos, int v)
{
if (!x) x = ++pool;
if (l == r)
{
c[x].cnt += v;
if (!c[x].cnt) c[x].f = c[x].s = c[x].t = 0;
else c[x].f = dep[pos], c[x].s = c[x].t = pos;
return;
}
int mid = l + r >> 1;
if (dfn2[pos] <= mid) insert(c[x].l, l, mid, pos, v);
else insert(c[x].r, mid + 1, r, pos, v);
collect(x);
}
inline void init()
{
int i, j;
logn[0] = -1;
for (i = 1; i <= dfn1[0]; i++) logn[i] = logn[i >> 1] + 1;
for (j = 1; (1 << j) <= dfn1[0]; j++)
for (i = 1; i + (1 << j) - 1 <= dfn1[0]; i++)
{
int u = st[i][j - 1], v = st[i + (1 << j - 1)][j - 1];
st[i][j] = (dep[u] < dep[v] ? u : v);
}
}
inline void dfs2(int u, int pa)
{
for (int i = adj[u]; i; i = nxt[i])
{
int v = go[i];
if (v == pa) continue;
dfs2(v, u);
rt[u] = merge(rt[u], rt[v], 1, n);
}
for (auto v : g[u]) insert(rt[u], 1, n, v, 1);
for (auto v : h[u]) insert(rt[u], 1, n, v, -1);
ans += c[rt[u]].f - dep[lca(c[rt[u]].s, c[rt[u]].t)];
}
int main()
{
int i, x, y;
read(n); read(m);
for (i = 1; i < n; i++) read(x), read(y), add(x, y);
dfs1(1, 0);
init();
while (m--)
{
read(x); read(y);
int z = lca(x, y);
g[x].pb(x); g[x].pb(y); g[y].pb(x); g[y].pb(y);
h[z].pb(x); h[z].pb(y); h[fa[z]].pb(x); h[fa[z]].pb(y);
}
dfs2(1, 0);
cout << ans / 2 << endl;
return 0;
}