[JZOJ6074]【GDOI2019模拟2019.3.20】铁路【数据结构】【树状数组】【树链剖分】【线段树合并】

Description

在这里插入图片描述
在这里插入图片描述

Solution

老年选手表示真的不太码的动。。。

细节太多了,相当难写
为啥XJ和ASDFZ的大佬们动不动一写就是6K,7K 300行啊,这也太能码了

这题感觉还是蛮套路的。

我们首先可以将边的中点也建一个点,现在就只会在点上相交了。

分情况讨论
要么是两个起点深度相同,一起向上走,走到这两个起点的lca处相交

一种是一个起点在子树内,向上走,一个起点在子树外,向下走,在子树的根碰头。

第一种情况,我们可以用线段树合并的方式维护,合并到l=r时就可以直接乘起来计算答案了。

对于第二种情况,我们首先树链剖分,对于一条重链,向上走的路径和向下走的路径可以写成一次函数的形式,斜率为-1或1,横坐标为当前深度

那么我们就是要算斜率为-1的线段和斜率为1的线段的交点个数。

我们可以将坐标系旋转45度,那么就变成平行于x轴和y轴的线段了,采用扫描线+树状数组计算。

注意我们在对于一条路径插入经过的所有重链的时候,我们需要强行将lca设为上行的直线(即斜率为-1)
时间复杂度 O ( n log ⁡ 2 n ) O(n\log ^2n) O(nlog2n)

Code

#include <bits/stdc++.h>
#define fo(i,a,b) for(int i=a;i<=b;++i)
#define fod(i,a,b) for(int i=a;i>=b;--i)
#define N 400005
#define LL long long
using namespace std;
int n,m,m1,fs[N],n2,nt[2*N],dt[2*N],f[N][20],a[N][4];
int dep[N],dfn[N],sz[N],ti,t[10*N][2],n1,rt[N],son[N],top[N];
LL sm[10*N],ans;
int c[N];
struct node
{
	int x,y,h;
	friend bool operator <(node x,node y)
	{
		return x.h<y.h;
	}
}d[N];
int lowbit(int x)
{	
	return x&(-x);
}
void put(int x,int v)
{
	while(x<=2*n) c[x]+=v,x+=lowbit(x);
}
int get(int x)
{
	int s=0;
	while(x) s+=c[x],x-=lowbit(x);
	return s; 
}
vector<int> pt[N],pt2[N];
vector<node> ts[N][2];
void link(int x,int y)
{
	nt[++m1]=fs[x];
	dt[fs[x]=m1]=y;
}
void dfs(int k,int fa)
{
	f[k][0]=fa;
	sz[k]=1;
	dep[k]=dep[fa]+1;
	for(int i=fs[k];i;i=nt[i])
	{
		int p=dt[i];
		if(p!=fa) dfs(p,k),sz[k]+=sz[p],son[k]=(sz[p]>sz[son[k]])?p:son[k];
	}
}
void nwp(int &x)
{
	if(!x) x=++n1;
}
void ins(int k,int l,int r,int x,int v)
{
	if(l==r) sm[k]+=v;
	else
	{
		int mid=(l+r)>>1;
		if(x<=mid) nwp(t[k][0]),ins(t[k][0],l,mid,x,v);
		else nwp(t[k][1]),ins(t[k][1],mid+1,r,x,v);
		sm[k]=sm[t[k][0]]+sm[t[k][1]];
	}
}
void merge(int &k,int x,int l,int r)
{
	if(!k) {k=x;return;}
	if(!x||!sm[x]) return;
	if(l==r) ans+=sm[k]*sm[x],sm[k]+=sm[x];
	else
	{
		int mid=(l+r)>>1;
		merge(t[k][0],t[x][0],l,mid);
		merge(t[k][1],t[x][1],mid+1,r);
		sm[k]=sm[t[k][0]]+sm[t[k][1]];
	}
}
void make(int k,int fa)
{
	dfn[++dfn[0]]=k;
	if(son[k]) top[son[k]]=top[k],make(son[k],k),rt[k]=rt[son[k]];
	for(int i=fs[k];i;i=nt[i])
	{
		int p=dt[i];
		if(p!=son[k]&&p!=fa) 
		{
			top[p]=p;
			make(p,k);
			merge(rt[k],rt[p],1,n);
		}
	}
	if(!rt[k]) rt[k]=++n1;
	int l=pt[k].size();
	ans=(ans+(LL)l*(LL)(l-1)/2);
	ins(rt[k],1,n,dep[k],l);
	int r=pt2[k].size();
	fo(j,0,r-1) ins(rt[k],1,n,dep[pt2[k][j]],-1);
}
int lca(int x,int y)
{
	if(dep[x]>dep[y]) swap(x,y);
	for(int j=dep[y]-dep[x],c=0;j;j>>=1,c++) if(j&1) y=f[y][c];
	for(int j=18;x!=y;)
	{
		while(j&&f[x][j]==f[y][j]) j--;
		x=f[x][j],y=f[y][j];
	}
	return x;
}
bool in(int y,int x)
{
	return (dfn[x]<=dfn[y]&&dfn[y]<dfn[x]+sz[x]);
}
void push(int i)
{
	int x=a[i][0],y=a[i][1],p=a[i][2],l=dep[a[i][0]],r=a[i][3]-dep[a[i][1]];
	while(top[x]!=top[y])
	{
		if(dep[top[x]]>dep[top[y]])
			ts[top[x]][0].push_back((node){l-2*dep[x],l-2*dep[top[x]],l}),x=f[top[x]][0];
		else
			ts[top[y]][1].push_back((node){r+2*dep[top[y]],r+2*dep[y],r}),y=f[top[y]][0];
	}
	if(dep[x]>dep[y]) 
		ts[top[x]][0].push_back((node){l-2*dep[x],l-2*dep[y],l});
	else
	{
		ts[top[x]][0].push_back((node){l-2*dep[x],l-2*dep[x],l});
		if(y!=x) ts[top[y]][1].push_back((node){r+2*(dep[x]+1),r+2*dep[y],r});
	}
}
int main()
{
	cin>>n;
	n2=n;
	fo(i,1,n-1)
	{
		int x,y;
		scanf("%d%d",&x,&y);
		n2++;
		link(x,n2),link(n2,x);
		link(n2,y),link(y,n2);
	}
	n=n2;
	dfs(1,0);
	fo(j,1,18) fo(i,1,n) f[i][j]=f[f[i][j-1]][j-1];
	cin>>m;
	fo(i,1,m) 
	{
		scanf("%d%d",&a[i][0],&a[i][1]);
		a[i][2]=lca(a[i][0],a[i][1]);
		a[i][3]=dep[a[i][0]]+dep[a[i][1]]-2*dep[a[i][2]];
		pt[a[i][0]].push_back(a[i][0]);
		pt2[a[i][2]].push_back(a[i][0]);
	}
	ans=0;
	top[1]=1;
	make(1,0);
	fo(i,1,m) 
		push(i);
	fo(i,1,n)
	{
		if(top[i]==i)
		{
			int l1=ts[i][0].size(),l2=ts[i][1].size(),le=0;
			fo(j,0,l2-1)
			{
				node w=ts[i][1][j];
				d[++le]=(node){w.h,1,w.x};
				d[++le]=(node){w.h,-1,w.y+1};
			}
			sort(d+1,d+le+1);
			sort(ts[i][0].begin(),ts[i][0].end());
			int y=1;
			fo(p1,0,l1-1)
			{
				int p=ts[i][0][p1].h;
				while(y<=le&&d[y].h<=p) put(d[y].x+n,d[y].y),y++;
				ans+=get(min(2*n,ts[i][0][p1].y+n))-get(max(0,ts[i][0][p1].x-1+n)); 
			}
			for(;y<=le;y++) put(d[y].x+n,d[y].y);
		}
	}
	printf("%lld\n",ans);
}
  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值