双重祖先
Problem Description
给定两棵有根树,两棵树均有n个节点,且根均为1号点。
问有多少对(u,v)满足:在给定的两棵树中u均为v的祖先。
Input
本题中每个测试点仅包含一组测试数据
第一行一个正整数n,表示树的规模
接下来的n-1行,每行两个整数x, y,表示在第一棵有根树中x点与y点之间有一条连边。
再接下来的n-1行,每行两个整数x, y,表示在第二棵有根树中x点与y点之间有一条连边。
Output
输出数据仅一行,表示满足题意的(u, v)点对数
Sample Input
4
1 2
2 3
3 4
1 2
2 3
2 4
Sample Output
5
题解:
先对一棵树重链剖分,重新编号,
i
d
[
u
]
id[u]
id[u]代表点新的编号,
s
z
[
u
]
sz[u]
sz[u]表示以u为根的子树的结点数量,则其儿子节点编号为
[
i
d
[
u
]
,
i
d
[
u
]
+
s
z
[
u
]
−
1
]
[ id[u], id[u]+sz[u]-1 ]
[id[u],id[u]+sz[u]−1]。
对另一棵树DFS,设当前节点为u,则将
i
d
[
u
]
id[u]
id[u]更新为1,然后访问u每个子节点,在子节点递归结束时,将子节点存储点信息的线段树与当前u对应的线段树启发式合并。所有子节点遍历完成后,访问当前线段树,
a
n
s
+
=
q
u
e
r
y
(
i
d
[
u
]
,
i
d
[
u
]
+
s
z
[
u
]
−
1
)
ans+=query(id[u], id[u]+sz[u]-1)
ans+=query(id[u],id[u]+sz[u]−1)。
#include<stdio.h>
#include<iostream>
#include<cstdlib>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<map>
#include<vector>
#include<queue>
#include<iterator>
#define dbg(x) cout<<#x<<" = "<<x<<endl;
#define INF 0x3f3f3f3f
#define eps 1e-7
using namespace std;
typedef long long LL;
typedef pair<int, int> P;
const int maxn = 100100;
const int mod = 1000000007;
struct node{
int l, r, sum;
}p[maxn*100];
int n, cnt, tot, rt[maxn], id[maxn], sz[maxn];
LL ans;
vector<int> g1[maxn], g2[maxn];
void init(int n);
int merge(int u, int v);
void dfs1(int u, int fa);
void dfs2(int u, int fa);
void bin(int l, int r, int k, int u);
void Update(int l, int r, int x, int k);
int query(int l, int r, int al, int ar, int k);
int main()
{
int i, j, k;
cnt = 1;
scanf("%d", &n);
for(i=1;i<n;i++){
scanf("%d %d", &j, &k);
g1[j].push_back(k);
g1[k].push_back(j);
}
for(i=1;i<n;i++){
scanf("%d %d", &j, &k);
g2[j].push_back(k);
g2[k].push_back(j);
}
dfs1(1, 0);
init(n);
dfs2(1, 0);
printf("%lld\n", ans);
return 0;
}
void init(int n)
{
for(int i=0;i<=n;i++)
p[i].l = p[i].r = p[i].sum = 0, rt[i] = i;
tot = n+4;
ans = 0;
}
void dfs1(int u, int fa)
{
sz[u] = 1;
id[u] = cnt++;
for(int i=0;i<g1[u].size();i++)
if(g1[u][i] != fa){
dfs1(g1[u][i], u);
sz[u] += sz[g1[u][i]];
}
}
void dfs2(int u, int fa)
{
Update(1, n, id[u], rt[u]);
for(int i=0;i<g2[u].size();i++)
if(g2[u][i] != fa){
dfs2(g2[u][i], u);
rt[u] = merge(rt[u], rt[g2[u][i]]);
}
ans += query(1, n, id[u], id[u]+sz[u]-1, rt[u])-1;
}
int merge(int u, int v)
{
if(p[u].sum<p[v].sum)swap(u, v);
bin(1, n, v, u);
return u;
}
void bin(int l, int r, int k, int u)
{
if(l == r){
Update(1, n, l, u);
return;
}
int mid = (l+r)/2;
if(p[p[k].l].sum)bin(l, mid, p[k].l, u);
if(p[p[k].r].sum)bin(mid+1, r, p[k].r, u);
}
void Update(int l, int r, int x, int k)
{
p[k].sum++;
if(l == r)return;
int mid = (l+r)/2;
if(x <= mid){
if(p[k].l == 0){
p[tot].l = p[tot].r = p[tot].sum = 0;
p[k].l = tot++;
}
Update(l, mid, x, p[k].l);
}else{
if(p[k].r == 0){
p[tot].l = p[tot].r = p[tot].sum = 0;
p[k].r = tot++;
}
Update(mid+1, r, x, p[k].r);
}
}
int query(int l, int r, int al, int ar, int k)
{
if(l == al && r == ar)return p[k].sum;
int mid = (l+r)/2;
if(ar <= mid)return query(l, mid, al, ar, p[k].l);
else if(al > mid)return query(mid+1, r, al, ar, p[k].r);
else return query(l, mid, al, mid, p[k].l)+query(mid+1, r, mid+1, ar, p[k].r);
}