Description
Solution
老年选手表示真的不太码的动。。。
细节太多了,相当难写
为啥XJ和ASDFZ的大佬们动不动一写就是6K,7K 300行啊,这也太能码了
这题感觉还是蛮套路的。
我们首先可以将边的中点也建一个点,现在就只会在点上相交了。
分情况讨论
要么是两个起点深度相同,一起向上走,走到这两个起点的lca处相交
一种是一个起点在子树内,向上走,一个起点在子树外,向下走,在子树的根碰头。
第一种情况,我们可以用线段树合并的方式维护,合并到l=r时就可以直接乘起来计算答案了。
对于第二种情况,我们首先树链剖分,对于一条重链,向上走的路径和向下走的路径可以写成一次函数的形式,斜率为-1或1,横坐标为当前深度
那么我们就是要算斜率为-1的线段和斜率为1的线段的交点个数。
我们可以将坐标系旋转45度,那么就变成平行于x轴和y轴的线段了,采用扫描线+树状数组计算。
注意我们在对于一条路径插入经过的所有重链的时候,我们需要强行将lca设为上行的直线(即斜率为-1)
时间复杂度
O
(
n
log
2
n
)
O(n\log ^2n)
O(nlog2n)
Code
#include <bits/stdc++.h>
#define fo(i,a,b) for(int i=a;i<=b;++i)
#define fod(i,a,b) for(int i=a;i>=b;--i)
#define N 400005
#define LL long long
using namespace std;
int n,m,m1,fs[N],n2,nt[2*N],dt[2*N],f[N][20],a[N][4];
int dep[N],dfn[N],sz[N],ti,t[10*N][2],n1,rt[N],son[N],top[N];
LL sm[10*N],ans;
int c[N];
struct node
{
int x,y,h;
friend bool operator <(node x,node y)
{
return x.h<y.h;
}
}d[N];
int lowbit(int x)
{
return x&(-x);
}
void put(int x,int v)
{
while(x<=2*n) c[x]+=v,x+=lowbit(x);
}
int get(int x)
{
int s=0;
while(x) s+=c[x],x-=lowbit(x);
return s;
}
vector<int> pt[N],pt2[N];
vector<node> ts[N][2];
void link(int x,int y)
{
nt[++m1]=fs[x];
dt[fs[x]=m1]=y;
}
void dfs(int k,int fa)
{
f[k][0]=fa;
sz[k]=1;
dep[k]=dep[fa]+1;
for(int i=fs[k];i;i=nt[i])
{
int p=dt[i];
if(p!=fa) dfs(p,k),sz[k]+=sz[p],son[k]=(sz[p]>sz[son[k]])?p:son[k];
}
}
void nwp(int &x)
{
if(!x) x=++n1;
}
void ins(int k,int l,int r,int x,int v)
{
if(l==r) sm[k]+=v;
else
{
int mid=(l+r)>>1;
if(x<=mid) nwp(t[k][0]),ins(t[k][0],l,mid,x,v);
else nwp(t[k][1]),ins(t[k][1],mid+1,r,x,v);
sm[k]=sm[t[k][0]]+sm[t[k][1]];
}
}
void merge(int &k,int x,int l,int r)
{
if(!k) {k=x;return;}
if(!x||!sm[x]) return;
if(l==r) ans+=sm[k]*sm[x],sm[k]+=sm[x];
else
{
int mid=(l+r)>>1;
merge(t[k][0],t[x][0],l,mid);
merge(t[k][1],t[x][1],mid+1,r);
sm[k]=sm[t[k][0]]+sm[t[k][1]];
}
}
void make(int k,int fa)
{
dfn[++dfn[0]]=k;
if(son[k]) top[son[k]]=top[k],make(son[k],k),rt[k]=rt[son[k]];
for(int i=fs[k];i;i=nt[i])
{
int p=dt[i];
if(p!=son[k]&&p!=fa)
{
top[p]=p;
make(p,k);
merge(rt[k],rt[p],1,n);
}
}
if(!rt[k]) rt[k]=++n1;
int l=pt[k].size();
ans=(ans+(LL)l*(LL)(l-1)/2);
ins(rt[k],1,n,dep[k],l);
int r=pt2[k].size();
fo(j,0,r-1) ins(rt[k],1,n,dep[pt2[k][j]],-1);
}
int lca(int x,int y)
{
if(dep[x]>dep[y]) swap(x,y);
for(int j=dep[y]-dep[x],c=0;j;j>>=1,c++) if(j&1) y=f[y][c];
for(int j=18;x!=y;)
{
while(j&&f[x][j]==f[y][j]) j--;
x=f[x][j],y=f[y][j];
}
return x;
}
bool in(int y,int x)
{
return (dfn[x]<=dfn[y]&&dfn[y]<dfn[x]+sz[x]);
}
void push(int i)
{
int x=a[i][0],y=a[i][1],p=a[i][2],l=dep[a[i][0]],r=a[i][3]-dep[a[i][1]];
while(top[x]!=top[y])
{
if(dep[top[x]]>dep[top[y]])
ts[top[x]][0].push_back((node){l-2*dep[x],l-2*dep[top[x]],l}),x=f[top[x]][0];
else
ts[top[y]][1].push_back((node){r+2*dep[top[y]],r+2*dep[y],r}),y=f[top[y]][0];
}
if(dep[x]>dep[y])
ts[top[x]][0].push_back((node){l-2*dep[x],l-2*dep[y],l});
else
{
ts[top[x]][0].push_back((node){l-2*dep[x],l-2*dep[x],l});
if(y!=x) ts[top[y]][1].push_back((node){r+2*(dep[x]+1),r+2*dep[y],r});
}
}
int main()
{
cin>>n;
n2=n;
fo(i,1,n-1)
{
int x,y;
scanf("%d%d",&x,&y);
n2++;
link(x,n2),link(n2,x);
link(n2,y),link(y,n2);
}
n=n2;
dfs(1,0);
fo(j,1,18) fo(i,1,n) f[i][j]=f[f[i][j-1]][j-1];
cin>>m;
fo(i,1,m)
{
scanf("%d%d",&a[i][0],&a[i][1]);
a[i][2]=lca(a[i][0],a[i][1]);
a[i][3]=dep[a[i][0]]+dep[a[i][1]]-2*dep[a[i][2]];
pt[a[i][0]].push_back(a[i][0]);
pt2[a[i][2]].push_back(a[i][0]);
}
ans=0;
top[1]=1;
make(1,0);
fo(i,1,m)
push(i);
fo(i,1,n)
{
if(top[i]==i)
{
int l1=ts[i][0].size(),l2=ts[i][1].size(),le=0;
fo(j,0,l2-1)
{
node w=ts[i][1][j];
d[++le]=(node){w.h,1,w.x};
d[++le]=(node){w.h,-1,w.y+1};
}
sort(d+1,d+le+1);
sort(ts[i][0].begin(),ts[i][0].end());
int y=1;
fo(p1,0,l1-1)
{
int p=ts[i][0][p1].h;
while(y<=le&&d[y].h<=p) put(d[y].x+n,d[y].y),y++;
ans+=get(min(2*n,ts[i][0][p1].y+n))-get(max(0,ts[i][0][p1].x-1+n));
}
for(;y<=le;y++) put(d[y].x+n,d[y].y);
}
}
printf("%lld\n",ans);
}