distances sum
这道题要考虑两颗树,因为不可能同时考虑两颗树,不妨先处理好一颗树,在搜索第二颗树的时候顺便更新答案
——————————————————————————————————————
具
体
地
,
我
们
先
d
f
s
一
一
棵
树
,
记
录
来
到
这
个
点
和
离
开
这
个
点
的
d
f
s
序
(
这
样
方
便
计
算
子
树
的
区
间
)
,
然
后
用
来
到
这
的
d
f
s
序
建
立
树
状
数
组
(
当
然
也
可
以
用
线
段
树
)
,
这
样
,
我
们
就
可
以
通
过
树
状
数
组
知
道
一
棵
树
的
子
树
中
有
多
少
个
点
具体地,我们先dfs一一棵树,记录来到这个点和离开这个点的dfs序(这样方便计算子树的区间),然后用来到这的dfs序建立树状数组(当然也可以用线段树),这样,我们就可以通过树状数组知道一棵树的子树中有多少个点
具体地,我们先dfs一一棵树,记录来到这个点和离开这个点的dfs序(这样方便计算子树的区间),然后用来到这的dfs序建立树状数组(当然也可以用线段树),这样,我们就可以通过树状数组知道一棵树的子树中有多少个点
接
下
来
,
我
们
d
f
s
另
一
颗
树
,
找
到
一
个
节
点
s
,
先
计
算
目
前
s
中
有
多
少
个
是
第
一
颗
树
中
s
的
后
代
,
记
为
s
u
m
1
,
因
为
在
扫
描
到
s
前
已
经
找
过
这
些
点
,
所
以
这
些
点
一
定
不
可
能
是
这
棵
树
上
s
的
后
代
。
然
后
我
们
先
把
s
点
加
进
树
状
数
组
,
然
后
再
d
f
s
接下来,我们dfs另一颗树,找到一个节点s,先计算目前s中有多少个是第一颗树中s的后代,记为sum1,因为在扫描到s前已经找过这些点,所以这些点一定不可能是这棵树上s的后代。然后我们先把s点加进树状数组,然后再dfs
接下来,我们dfs另一颗树,找到一个节点s,先计算目前s中有多少个是第一颗树中s的后代,记为sum1,因为在扫描到s前已经找过这些点,所以这些点一定不可能是这棵树上s的后代。然后我们先把s点加进树状数组,然后再dfs
离
开
s
点
时
我
们
也
要
再
记
录
s
中
有
多
少
个
第
一
颗
树
中
s
的
后
代
,
记
为
s
u
m
2
,
然
后
两
个
s
u
m
3
=
s
u
m
1
−
s
u
m
2
−
1
即
为
同
时
再
s
子
树
的
点
数
(
即
再
进
s
之
后
出
s
点
之
前
的
差
,
还
要
减
一
个
1
是
因
为
多
将
s
加
入
了
树
状
数
组
,
而
s
已
经
算
过
贡
献
了
)
,
然
后
a
n
s
+
=
(
s
u
m
3
−
1
)
∗
s
u
m
3
就
是
答
案
离开s点时我们也要再记录s中有多少个第一颗树中s的后代,记为sum2,然后两个sum3=sum1-sum2-1即为同时再s子树的点数(即再进s之后出s点之前的差,还要减一个1是因为多将s加入了树状数组,而s已经算过贡献了),然后ans+=(sum3-1)*sum3就是答案
离开s点时我们也要再记录s中有多少个第一颗树中s的后代,记为sum2,然后两个sum3=sum1−sum2−1即为同时再s子树的点数(即再进s之后出s点之前的差,还要减一个1是因为多将s加入了树状数组,而s已经算过贡献了),然后ans+=(sum3−1)∗sum3就是答案
#include<iostream>
#include<cstdio>
#include<cstring>
using namespace std;
typedef long long ll;
const int N=100010;
int in[N],cnt=0,n,head[N];
struct edge{
int v,link;
}q[N<<1];
void put(int x,int y){
q[++cnt].v=y;
q[cnt].link=head[x];
head[x]=cnt;
}
int lowbit(int x){
return (x&(-x));
}
int tree[N],start[N],endd[N],tnt=0;
void dfs1(int s,int fa){
start[s]=++tnt;
for(int i=head[s];i;i=q[i].link){
int v=q[i].v;
if(v==fa) continue;
dfs1(v,s);
}
endd[s]=tnt;
}
ll ansout=0;
ll cul(int id){
ll ans=0;
for(int i=id;i;i-=lowbit(i)) ans+=tree[i];
return ans;
}
void add(int id){
for(int i=id;i<=n;i+=lowbit(i)) tree[i]++;
}
void dfs2(int s,int fa){
ll sum1=cul(endd[s])-cul(start[s]-1);
add(start[s]);
for(int i=head[s];i;i=q[i].link){
int v=q[i].v;
if(v==fa) continue;
dfs2(v,s);
}
ll sum2=cul(endd[s])-cul(start[s]-1);
ll sum3=sum2-sum1-1;
ansout+=(sum3)*(sum3-1)/2;
}
int main(){
scanf("%d",&n);
memset(in,0,sizeof(in));
for(int i=1;i<n;i++){
int u,v;
scanf("%d%d",&u,&v);
put(u,v),put(v,u);
in[v]++;
}
int root;
for(root=1;in[root];root++);
dfs1(root,0);
memset(in,0,sizeof(in));
memset(q,0,sizeof(q));
memset(head,0,sizeof(head));
cnt=0;
for(int i=1;i<n;i++){
int u,v;
scanf("%d%d",&u,&v);
put(u,v),put(v,u);
in[v]++;
}
for(root=1;in[root];root++);
dfs2(root,0);
printf("%lld",ansout);
}