有一个庞大的家族,共n人。已知这n个人的祖辈关系正好形成树形结构(即父亲向儿子连边)。
在另一个未知的平行宇宙,这n人的祖辈关系仍然是树形结构,但他们相互之间的关系却完全不同了,原来的祖先可能变成了后代,后代变成的同辈……
两个人的亲密度定义为在这两个平行宇宙有多少人一直是他们的公共祖先。
整个家族的亲密度定义为任意两个人亲密度的总和。
Input
第一行一个数n(1<=n<=100000)
接下来n-1行每行两个数x,y表示在第一个平行宇宙x是y的父亲。
接下来n-1行每行两个数x,y表示在第二个平行宇宙x是y的父亲。
Output
一个数,表示整个家族的亲密度。
Input示例
5
1 3
3 5
5 4
4 2
1 2
1 3
3 4
1 5
Output示例
6
分析:直接做好像没有办法,那么我们可以把树上的化为线段,用dfs序把前后两颗树化为区间。假设lca为x,那么x区间内的点集合为[l,r]那么我们只要用主席树统计在[l,r]中[l1,r1]出现的次数就可以了,l1r1是x在第二课树内子树的节点集合。
因为多打了一个等于号被坑10分钟。。
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<iostream>
#define fo(i,a,b) for(int i=a;i<=b;i++)
#define fd(i,a,b) for(int i=a;i>=b;i--)
using namespace std;
const int N=1e5+10;
typedef long long ll;
struct node
{
int l,r;
ll v;
}t[N*40];
int root[N],head[N*2],go[N*2],next[N*2];
int eda[N],edb[N],dfa[N],dfb[N],ta[N],tb[N],tot=0,cnt=0,in[N],n;
inline void add(int x,int y)
{
go[++tot]=y;
next[tot]=head[x];
head[x]=tot;
}
inline void clr()
{
memset(in,0,sizeof(in));
memset(head,0,sizeof(head));
tot=0;
cnt=0;
}
inline void dfs(int x,int fa,int flag)
{
if (!flag)dfa[x]=++cnt,ta[x]=cnt;
else dfb[x]=++cnt,tb[cnt]=ta[x];
int i=head[x];
while (i)
{
int v=go[i];
if (v!=fa)dfs(v,x,flag);
i=next[i];
}
if (!flag)eda[x]=cnt;
else edb[x]=cnt;
}
int toa=0;
inline void build(int &x,int l,int r)
{
x=toa++;
t[x].v=0;
if (l==r)return;
int mid=(l+r)/2;
build(t[x].l,l,mid);
build(t[x].r,mid+1,r);
}
inline void ins(int &x,int y,int pos,int l,int r)
{
x=toa++;
if (l==r)
{
t[x].v=t[y].v+1;
return;
}
t[x].l=t[y].l;
t[x].r=t[y].r;
int mid=(l+r)/2;
if (pos<=mid)ins(t[x].l,t[y].l,pos,l,mid);
else ins(t[x].r,t[y].r,pos,mid+1,r);
t[x].v=t[t[x].l].v+t[t[x].r].v;
}
inline ll query(int x,int y,int l1,int r1,int l,int r)
{
if (l>=l1&&r<=r1)return t[y].v-t[x].v;
int mid=(l+r)/2;
ll ans=0;
if (l1<=mid)ans+=query(t[x].l,t[y].l,l1,r1,l,mid);
if (r1>mid)ans+=query(t[x].r,t[y].r,l1,r1,mid+1,r);
return ans;
}
inline void cal(int x)
{
clr();
fo(i,1,n-1)
{
int x,y;
scanf("%d%d",&x,&y);
add(x,y);
add(y,x);
in[y]++;
}
fo(i,1,n)
if (!in[i])
{
dfs(i,-1,x);
break;
}
}
int main()
{
scanf("%d",&n);
toa=0;
cal(0);
cal(1);
cnt=0;
build(root[0],1,n);
fo(i,1,n)ins(root[i],root[i-1],tb[i],1,n);
ll ans=0;
fo(i,1,n)
{
int l=dfb[i],r=edb[i];
ll tmp=query(root[l],root[r],dfa[i],eda[i],1,n);
ans+=1ll*(tmp-1)*tmp/2;
}
printf("%lld\n",ans);
}