题意
有一棵
n
n
n 个点的树,有
m
m
m 条链
(
u
i
,
v
i
)
(u_i,v_i)
(ui,vi),问有多少无序点对
(
i
,
j
)
(i,j)
(i,j),满足第
i
i
i 条链和第
j
j
j 条链只有一个交点。
其中,
n
,
m
≤
3
e
5
n,m\leq 3e5
n,m≤3e5。
分析
参考自:这个大佬的博客,图也是盗他的
两条链相交只有一个交点,我们可以发现,这个交点一定是至少其中一条链的
l
c
a
lca
lca。
于是我们分类讨论一下:
①:交点是两条链的
l
c
a
lca
lca。
②:交点是其中一条链的
l
c
a
lca
lca,不是另一条的
l
c
a
lca
lca。
我们把这两种情况的答案加起来,就是最终的答案。
如图,左边的为第一种情况,右边的为第二种情况。
我们先来看第一种情况怎么求。现在考虑
l
c
a
lca
lca 相同的情况。设
a
a
a 表示
u
u
u 在
l
c
a
lca
lca 的哪棵子树中,
b
b
b 表示
v
v
v 在
l
c
a
lca
lca 的哪棵子树中。假设
l
c
a
lca
lca 是
u
u
u 的
k
k
k 级祖先,那么
a
a
a 为
u
u
u 的
k
−
1
k-1
k−1 级祖先。如果
(
u
,
v
)
(u,v)
(u,v) 是一条直的链的话,也就是
u
u
u 不存在
−
1
-1
−1 级祖先,我们令
a
=
−
1
a = -1
a=−1。为了方便,我们设
a
<
b
a<b
a<b,如果
a
≥
b
a\ge b
a≥b 的话就交换
u
,
v
u,v
u,v 和
a
,
b
a,b
a,b。
假设
a
,
b
a,b
a,b 都存在,也就是
(
u
,
v
)
(u,v)
(u,v) 不是一条直的链的话,那么两条链
(
u
1
,
v
1
)
(u_1,v_1)
(u1,v1) 和
(
u
2
,
v
2
)
(u_2,v_2)
(u2,v2) 只有一个交点,说明
a
1
,
b
1
,
a
2
,
b
2
a_1,b_1,a_2,b_2
a1,b1,a2,b2 是四个不同的数。如果
a
a
a 是
−
1
-1
−1 的话,我们就可以不管这个
a
a
a,
b
b
b 也同理。
于是我们对相同
l
c
a
lca
lca 的链进行操作。假设当前枚举了
k
k
k 条链,我们用容斥原理求与第
k
+
1
k+1
k+1 条链只有一个交点的对数。设
c
n
t
i
cnt_i
cnti 为
i
i
i 的出现次数,那么方案数就为
k
−
c
n
t
a
k
+
1
−
c
n
t
b
k
+
1
+
c
n
t
p
a
i
r
(
a
k
+
1
,
b
k
+
1
)
k-cnt_{a_{k+1}}-cnt_{b_{k+1}}+cnt_{pair(a_{k+1},b_{k+1})}
k−cntak+1−cntbk+1+cntpair(ak+1,bk+1) 。
然后考虑第二种情况。我们在排序时,先按
l
c
a
lca
lca 的深度从小到大排序,深度相同时按
l
c
a
lca
lca 从小到大排序,这样我们就保证了
l
c
a
lca
lca 相同的那些链是连续的一段,而且整体的深度是从小到大的。我们看看第二种情况的特点,就是有一条弯的链
x
x
x,还有另一条链
y
y
y 穿过这条链的
l
c
a
lca
lca 并经过了
f
a
l
c
a
fa_{lca}
falca。我们发现,
y
y
y 的
l
c
a
lca
lca 的深度一定小于
x
x
x 的,而且
y
y
y 一定只有一个端点在链
x
x
x 的
l
c
a
lca
lca 的子树中,且不在
x
a
x_a
xa 和
x
b
x_b
xb 的子树中。因此,我们用树状数组维护一下
d
f
s
dfs
dfs 序即可。就是用总的减去不合法的,总的是
l
c
a
lca
lca 子树中的点的个数, 如果
x
a
x_a
xa 存在,那就减去在
x
a
x_a
xa 子树中的点,对于
x
b
x_b
xb 同理。
第一种情况可以用两个
m
a
p
map
map 解决,第二种情况用一个树状数组解决。
总的复杂度是
O
(
m
l
o
g
n
+
m
l
o
g
m
)
O(mlogn+mlogm)
O(mlogn+mlogm) 的。
代码如下
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 3e5 + 5;
#define pii pair<int, int>
vector<int> E[N];
int n, dep[N], st[N][20], fa[N], L[N], R[N], tr[N], dft;
struct node{
int x, y, a, b, lca;
}q[N];
int cmp(node A, node B){
if(dep[A.lca] != dep[B.lca]) return dep[A.lca] < dep[B.lca];
return A.lca < B.lca;
}
void dfs(int a, int pre){
L[a] = ++dft;
for(int b: E[a]){
if(b != pre){
st[b][0] = fa[b] = a;
dep[b] = dep[a] + 1;
dfs(b, a);
}
}
R[a] = dft;
}
void build_st(){
for(int i = 1; i < 20; i++)
for(int j = 1; j <= n; j++) st[j][i] = st[st[j][i - 1]][i - 1];
}
int Lca(int a, int b){
if(dep[b] > dep[a]) swap(a, b);
int di = dep[a] - dep[b];
for(int i = 19; i >= 0; i--) if(di >> i & 1) a = st[a][i];
if(a == b) return a;
for(int i = 19; i >= 0; i--) if(st[a][i] != st[b][i]) a = st[a][i], b = st[b][i];
return st[a][0];
}
int up(int a, int x){
if(x <= 0) return a;
for(int i = 19; i >= 0; i--) if(x >> i & 1) a = st[a][i];
return a;
}
void add(int i, int x){
for(; i <= n; i += i & -i) tr[i] += x;
}
int query(int i){
int s = 0;
for(; i > 0; i -= i & -i) s += tr[i];
return s;
}
int main(){
ios::sync_with_stdio(false);
cin.tie(0), cout.tie(0);
cin >> n;
for(int i = 1, a, b; i < n; i++){
cin >> a >> b;
E[a].push_back(b);
E[b].push_back(a);
}
dfs(1, 0);
build_st();
int m;
cin >> m;
for(int i = 1, x, y; i <= m; i++){
cin >> x >> y;
int lca = Lca(x, y);
int a = up(x, dep[x] - dep[lca] - 1), b = up(y, dep[y] - dep[lca] - 1);
if(a == lca) a = -1;
if(b == lca) b = -1;
if(a > b) swap(a, b), swap(x, y);
q[i] = {x, y, a, b, lca};
}
sort(q + 1, q + m + 1, cmp);
LL ans1 = 0, ans2 = 0;
for(int i = 1, j = 1; i <= m; i++, j = i){
map<pii, int> f;
map<int, int> cnt;
while(q[j + 1].lca == q[i].lca) j++;
int lca = q[i].lca;
for(int k = i, x, y, a, b; k <= j; k++){
x = q[k].x, y = q[k].y, a = q[k].a, b = q[k].b;
ans1 += k - i - cnt[a] - cnt[b] + f[make_pair(a, b)];
if(a != -1) cnt[a]++;
if(b != -1) cnt[b]++;
if(a != -1 && b != -1) f[make_pair(a, b)]++;
ans2 += query(R[lca]) - query(L[lca] - 1);
if(a != -1) ans2 -= query(R[a]) - query(L[a] - 1);
if(b != -1) ans2 -= query(R[b]) - query(L[b] - 1);
}
for(int k = i; k <= j; k++) add(L[q[k].x], 1), add(L[q[k].y], 1);
i = j;
}
cout << ans1 + ans2 << '\n';
return 0;
}