题面
解法
考虑怎样的两条路径会相交
假设某一条路径为
(x,y)
(
x
,
y
)
如果
(x,y)
(
x
,
y
)
这条路径的某一个端点就是这两个点的
lca
l
c
a
,那么另一条路径的两个端点一定是一个在深度较大的子树里,另一个不在深度较小的点的子树里
如果
(x,y)
(
x
,
y
)
这条路径会经过
lca
l
c
a
,那么另一条路径的两个端点一定分别在
x
x
和的子树里
将字数信息变成
dfs
d
f
s
序的区间信息
那么,就变成询问一个点
x
x
在某一个范围内,在某一个范围内的问题了
这个问题可以直接用主席树解决
时间复杂度:
O(nlog n)
O
(
n
log
n
)
代码
#include <bits/stdc++.h>
#define N 100010
using namespace std;
template <typename node> void chkmax(node &x, node y) {x = max(x, y);}
template <typename node> void chkmin(node &x, node y) {x = min(x, y);}
template <typename node> void read(node &x) {
x = 0; int f = 1; char c = getchar();
while (!isdigit(c)) {if (c == '-') f = -1; c = getchar();}
while (isdigit(c)) x = x * 10 + c - '0', c = getchar(); x *= f;
}
struct Edge {
int next, num;
} e[N * 3];
struct SegmentTree {
struct Node {
int lc, rc, cnt;
} t[N * 38];
int tot;
int ins(int k, int l, int r, int x) {
int ret = ++tot; t[ret] = t[k]; t[ret].cnt++;
if (l == r) return ret; int mid = (l + r) >> 1;
if (x <= mid) t[ret].lc = ins(t[k].lc, l, mid, x);
else t[ret].rc = ins(t[k].rc, mid + 1, r, x);
return ret;
}
int query(int k1, int k2, int l, int r, int L, int R) {
if (L <= l && r <= R) return t[k2].cnt - t[k1].cnt;
int mid = (l + r) >> 1;
if (R <= mid) return query(t[k1].lc, t[k2].lc, l, mid, L, R);
if (L > mid) return query(t[k1].rc, t[k2].rc, mid + 1, r, L, R);
return query(t[k1].lc, t[k2].lc, l, mid, L, mid) + query(t[k1].rc, t[k2].rc, mid + 1, r, mid + 1, R);
}
} T;
struct Info {
int x, y;
bool operator < (const Info &b) const {
return x < b.x;
}
} b[N * 2];
struct Chain {
int x, y;
bool operator < (const Chain &a) const {
if (x == a.x) return y < a.y;
return x < a.x;
}
} a[N * 2];
int n, m, cnt, Time, d[N], rt[N], dfn[N], siz[N], f[N][21];
void add(int x, int y) {
e[++cnt] = (Edge) {e[x].next, y};
e[x].next = cnt;
}
void dfs(int x, int fa) {
for (int i = 1; i <= 20; i++)
f[x][i] = f[f[x][i - 1]][i - 1];
dfn[x] = ++Time, siz[x] = 1, d[x] = d[fa] + 1;
for (int p = e[x].next; p; p = e[p].next) {
int k = e[p].num;
if (k == fa) continue;
f[k][0] = x; dfs(k, x);
siz[x] += siz[k];
}
}
int lca(int x, int y) {
if (d[x] < d[y]) swap(x, y);
for (int i = 20; ~i; i--)
if (d[f[x][i]] >= d[y]) x = f[x][i];
if (x == y) return x;
for (int i = 20; ~i; i--)
if (f[x][i] != f[y][i]) x = f[x][i], y = f[y][i];
return f[x][0];
}
int main() {
read(n), read(m); cnt = n;
for (int i = 1; i < n; i++) {
int x, y; read(x), read(y);
add(x, y), add(y, x);
}
dfs(1, 0); int len = 0;
for (int i = 1; i <= m; i++) {
read(a[i].x), read(a[i].y);
int x = a[i].x, y = a[i].y;
b[++len] = (Info) {dfn[x], dfn[y]};
b[++len] = (Info) {dfn[y], dfn[x]};
}
sort(a + 1, a + m + 1);
sort(b + 1, b + len + 1); int j = 1;
for (int i = 1; i <= len; ) {
while (j < b[i].x) rt[j] = rt[j - 1], j++;
rt[j] = rt[j - 1];
while (j == b[i].x && i <= len) rt[j] = T.ins(rt[j], 1, n, b[i].y), i++;
j++;
}
while (j <= n) rt[j] = rt[j - 1], j++;
long long ans = 0;
for (int i = 1; i <= m; ) {
int j = i;
while (a[i].x == a[j].x && a[i].y == a[j].y && j <= m) j++;
int x = a[i].x, y = a[i].y;
int t = lca(x, y);
if (x == t) swap(x, y); ans--;
if (y == t) {
ans += T.query(rt[dfn[x] - 1], rt[dfn[x] + siz[x] - 1], 1, n, 1, n);
int z = x;
for (int k = 20; ~k; k--)
if (d[f[z][k]] > d[y]) z = f[z][k];
ans -= T.query(rt[dfn[x] - 1], rt[dfn[x] + siz[x] - 1], 1, n, dfn[z], dfn[z] + siz[z] - 1);
} else ans += T.query(rt[dfn[x] - 1], rt[dfn[x] + siz[x] - 1], 1, n, dfn[y], dfn[y] + siz[y] - 1);
i = j;
}
long long ansy = 1ll * m * (m - 1) / 2, t = __gcd(ans, ansy);
cout << ans / t << '/' << ansy / t << "\n";
return 0;
}