题意:
给出一个树,共n个节点。
有m条互不相同的树上路径。
现在让你随机选择2条路径,问两条路径存在包含关系的概率(输出最简分数)。
n,m<=100000
做法:
假如我们把所有路径存下来,对于x,y的一条路,在x处打一个y标记,在y处打一个x标记,那么我们查询的时候,只要查询x,y两侧分别有多少标记就可以了,因为x,y两侧的点形成的路径肯定包含x-y。
再说得明白一点,就是假如你对每个点都开一棵线段树,每次假如x-y有一条路径,就在x的线段树上y的位子+1,在y的线段树上x的位子+1。假如现在有一个询问x-y,形象地看成x在左y在右,你就要询问x左侧所有点的线段树中y右侧所有位子的值的和。
然后我们考虑三种情况:
设x, y的lca为z。
1.假如x!=z且y!=z
x左侧相当于是x的子树,y右侧也相当于y的子树。
2.假如x,y中有一个=z
x左侧相当于是x的子树,y右侧相当于除去y在x路径上的儿子的子树外的部分。
3.假如x=y=z
那相当于只有一个点
x左侧相当于是x的子树,y右侧除去x子树的部分。
于是我们发现都涉及到了子树,而子树在dfs序上是连续的一段,所以我们跑出dfs序,用主席树/可持久化线段树维护。
最后注意代码细节即可。
代码
/*************************************************************
Problem: bzoj 3772 精神污染
User: fengyuan
Language: C++
Result: Accepted
Time: 4764 ms
Memory: 63228 kb
Submit_Time: 2017-12-23 11:47:03
*************************************************************/
#include<cstdio>
#include<cstring>
#include<iostream>
#include<cmath>
#include<algorithm>
#include<vector>
#include<cstdlib>
#define mid ((l+r)/2)
#define pb push_back
using namespace std;
typedef long long LL;
const int N = 100010, M = 4000010;
int n, m, cnt, clk, tot;
int head[N], depth[N], f[N][19], in[N], out[N], L[M], R[M], rt[N], sum[M];
struct Edge{ int to, nex; }e[N<<1];
struct Questions{ int x, y; }que[N];
vector<int> vec[N];
inline void add(int x, int y)
{
e[++ cnt].to = y;
e[cnt].nex = head[x];
head[x] = cnt;
}
inline void dfs(int u, int last, int s)
{
depth[u] = s; f[u][0] = last; in[u] = ++ clk;
for(int i = head[u]; i; i = e[i].nex)
if(e[i].to != last) dfs(e[i].to, u, s+1);
out[u] = clk;
}
inline int LCA(int x, int y)
{
if(depth[x] < depth[y]) swap(x, y);
int tmp = depth[x] - depth[y];
for(int i = 17; i >= 0; i --)
if((tmp>>i)&1) x = f[x][i];
if(x == y) return x;
for(int i = 17; i >= 0; i --)
if(f[x][i] != f[y][i]) x = f[x][i], y = f[y][i];
return f[x][0];
}
inline void build(int &rt, int l, int r)
{
rt = ++ tot; sum[rt] = 0;
if(l == r) return;
build(L[rt], l, mid); build(R[rt], mid+1, r);
}
inline void update(int pre, int &rt, int l, int r, int x)
{
if(!rt || rt == pre) rt = ++ tot, sum[rt] = sum[pre];
sum[rt] ++;
if(l == r) return;
if(!L[rt]) L[rt] = L[pre];
if(!R[rt]) R[rt] = R[pre];
if(x <= mid) update(L[pre], L[rt], l, mid, x);
else update(R[pre], R[rt], mid+1, r, x);
}
inline int query(int u, int v, int l, int r, int x, int y)
{
if(l == x && r == y) return sum[v]-sum[u];
if(y <= mid) return query(L[u], L[v], l, mid, x, y);
else if(x > mid) return query(R[u], R[v], mid+1, r, x, y);
else return query(L[u], L[v], l, mid, x, mid) + query(R[u], R[v], mid+1, r, mid+1, y);
}
inline bool isAncestor(int x, int y){ return in[x] <= in[y] && out[x] >= out[y]; }
int main()
{
scanf("%d%d", &n, &m);
for(int i = 1; i < n; i ++) {
int x, y; scanf("%d%d", &x, &y);
add(x, y); add(y, x);
}
dfs(1, 0, 0);
for(int j = 1; j <= 17; j ++)
for(int i = 1; i <= n; i ++) f[i][j] = f[f[i][j-1]][j-1];
for(int i = 1; i <= m; i ++) {
int x, y; scanf("%d%d", &x, &y);
if(in[x] > in[y]) swap(x, y);
vec[in[x]].pb(in[y]);
vec[in[y]].pb(in[x]);
que[i].x = x; que[i].y = y;
}
build(rt[0], 1, n);
for(int i = 1; i <= n; i ++) {
rt[i] = rt[i-1];
for(int j = 0; j < vec[i].size(); j ++)
update(rt[i-1], rt[i], 1, n, vec[i][j]);
}
LL p = 0, q = 1LL*m*(m-1)/2;
for(int i = 1; i <= m; i ++) {
int x = que[i].x, y = que[i].y, z = LCA(x, y);
if(x != z && y != z) {
p += query(rt[in[x]-1], rt[out[x]], 1, n, in[y], out[y]);
p --;
} else if(x != z || y != z){
if(y != z) swap(x, y);
int t = x;
for(int j = 17; j >= 0; j --)
if(f[t][j] && !isAncestor(f[t][j], y)) t = f[t][j];
p += query(rt[in[x]-1], rt[out[x]], 1, n, 1, n);
p -= query(rt[in[x]-1], rt[out[x]], 1, n, in[t], out[t]);
p --;
} else {
p += query(rt[in[x]-1], rt[out[x]], 1, n, 1, n);
p -= query(rt[in[x]-1], rt[out[x]], 1, n, in[x], out[x]);
}
}
LL g = __gcd(p, q);
p /= g; q /= g;
if(!p) puts("0"); else printf("%lld/%lld\n", p, q);
return 0;
}