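/*
 * Approach:
 * First compute a DFS (preorder) numbering of tree 1, giving in[u] and out[u]
 * so that u's tree-1 subtree is the index range [in[u], out[u]]. A Fenwick
 * tree (binary indexed tree) maintains an array c[i] over those indices.
 * Then run a DFS over tree 2. For each vertex u:
 *   (1) on entering u:  subNum = \sum_{i=in[u]}^{out[u]} c[i]
 *   (2)                 c[in[u]] += 1
 *   (3) visit all children of u (in tree 2)
 *   (4) the number of vertices common to u's subtree in tree 1 and u's
 *       subtree in tree 2 (u included) is
 *         num = \sum_{i=in[u]}^{out[u]} c[i] - subNum
 *   (5) ans += \binom{num-1}{2}
 */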
#include<stdio.h>
#include<bits/stdc++.h>
// 64-bit integer type used for counts and the final answer.
using ll = long long;

// Lowest set bit of x (x & -x); drives Fenwick-tree index stepping.
template <class T>
constexpr T lowbit(T x) { return x & (-x); }

using namespace std;

// Capacity of every per-vertex array; indices 0..N-1 are valid.
constexpr int N = 100000 + 5;

// Tree-1 DFS interval of each vertex: its subtree occupies [in[u], out[u]].
int in[N], out[N];
// In-degree of each vertex in tree 1 / tree 2 (0 marks a root).
int inDeg1[N], inDeg2[N];

// Fenwick (binary indexed) tree over 1-based positions 1..n with 64-bit sums.
struct BIT{
    ll c[N];  // c[i] accumulates the lowbit(i) positions ending at i
    int n;    // active size set by init()

    // Sets the active size to n and zeroes c[0..n].
    // Cells above n are NOT cleared; sum() clamps so they are never read.
    void init(int n){
        this->n = n;
        fill(c, c + n + 1, 0);
    }

    // Adds val at 1-based position x. Positions outside 1..n are ignored:
    // x <= 0 would otherwise never advance (lowbit(0) == 0) and loop forever.
    void add(int x, ll val){
        if (x <= 0) return;
        for (; x <= n; x += lowbit(x)) c[x] += val;
    }

    // Prefix sum over positions 1..x. x is clamped to n so that stale cells
    // beyond the active size are never read; x <= 0 yields 0.
    ll sum(int x){
        if (x > n) x = n;
        ll total = 0;
        for (; x > 0; x -= lowbit(x)) total += c[x];
        return total;
    }
}bit;
// Adjacency lists of tree 1 and tree 2: G1[u] / G2[u] hold the children of u.
vector<int> G1[N], G2[N];

// Resets all per-test state for vertex slots 0..n: both adjacency lists,
// both in-degree counters, and the Fenwick tree (resized to n).
void init(int n){
    // Slot order is irrelevant; every slot 0..n is reset exactly once.
    for (int slot = n; slot >= 0; --slot) {
        G1[slot].clear();
        G2[slot].clear();
        inDeg2[slot] = 0;
        inDeg1[slot] = 0;
    }
    bit.init(n);
}
// Numbers tree 1 in preorder starting from root u, continuing the counter t:
//   in[x]  = preorder index of x (t is pre-incremented, so u gets t+1),
//   out[x] = largest preorder index inside x's subtree,
// so x's tree-1 subtree is exactly the index range [in[x], out[x]].
// Children are visited in G1 order, giving the same numbering as a recursive
// preorder. An explicit stack replaces recursion so that a path-shaped tree
// (depth ~n) cannot overflow the call stack.
void dfs(int u,int&t){
    // (vertex, index of the next child of that vertex to visit)
    std::vector<std::pair<int, std::size_t>> stk;
    in[u] = ++t;
    stk.emplace_back(u, 0);
    while (!stk.empty()) {
        auto& top = stk.back();
        const int x = top.first;
        if (top.second < G1[x].size()) {
            const int v = G1[x][top.second++];
            in[v] = ++t;
            stk.emplace_back(v, 0);  // may reallocate; `top` is not used after this
        } else {
            out[x] = t;
            stk.pop_back();
        }
    }
}
// Number of unordered pairs among n items: C(n, 2) = n(n-1)/2.
// Returns 0 for n < 2 (the bare formula would wrongly give 1 for n == -1).
// Takes a 64-bit argument so callers passing a long long count are not narrowed.
long long f(long long n){
    if (n < 2) return 0;
    return n * (n - 1) / 2;
}
ll ans=0;
void dfs2(int u){
ll subNum=bit.sum(out[u])-bit.sum(in[u]-1);
bit.add(in[u],1);
for(int v:G2[u]){
dfs2(v);
}
ll num=bit.sum(out[u])-bit.sum(in[u]-1)-subNum;
ans+=f(num-1);
}
// Solves one test case whose trees are already loaded into G1/G2 with
// in-degrees in inDeg1/inDeg2. Every vertex 1..n with in-degree 0 is a root.
// Numbers tree 1 first (one shared preorder counter), then walks tree 2 to
// accumulate the answer, which is returned.
ll slove(int n){
    ans = 0;

    int timer = 0;  // preorder counter shared by all tree-1 roots
    for (int r = 1; r <= n; ++r)
        if (!inDeg1[r]) dfs(r, timer);

    for (int r = 1; r <= n; ++r)
        if (!inDeg2[r]) dfs2(r);

    return ans;
}
int main(){
int n;
while(~scanf("%d",&n)){
init(n);
for(int i=0;i<n-1;++i){
int u,v;
scanf("%d%d",&u,&v);
G1[u].push_back(v);
++inDeg1[v];
}
for(int i=0;i<n-1;++i){
int u,v;
scanf("%d%d",&u,&v);
G2[u].push_back(v);
++inDeg2[v];
}
printf("%lld\n",slove(n));
}
return 0;
}