2018.11.05【校内模拟】树(长链剖分)(复杂度多一个log(过不了)的点分治)...

传送门


解析:

OJ又双叒叕卡栈空间。。。

思路:

首先看到距离不超过kkk很容易想到点分治。

每次处理出子树中所有点到分治中心的距离,然后处理一个二进制前缀和,双指针扫一遍数列,就可以愉快的做完这道题?

点分治复杂度O(nlognlog∣A∣)O(nlognlog|A|)O(nlognlogA),卡在2e82e82e8的极限,本来递归算法常数就大,卡不动了。。。

但是愿意卡常数的话考场上还是有85pts85pts85pts
就算懒得卡,也有70pts70pts70pts可以拿。

考虑一个所谓复杂度不稳定的算法:长链剖分。

维护每条长链的二进制后缀和,合并两条长链可以做到O(Len log∣A∣)O(Len\text{ }log|A|)O(Len logA),并且每条长链只会被合并到其他长链一次,所以总的复杂度O(nlog∣A∣)O(nlog|A|)O(nlogA)

考虑怎么在合并的同时统计答案。

已经维护了后缀和了,那么可以对较短链上的每个点都在较长链上找一下答案,其实每个点只会被询问一次所以这里的复杂度仍然是O(nlog∣A∣)O(nlog|A|)O(nlogA)

就可以愉快的水过这道题了。


代码(点分治):

#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define re register
#define gc getchar
#define pc putchar
#define cs const

inline int getint(){
	re int num;
	re char c;
	while(!isdigit(c=gc()));num=c^48;
	while(isdigit(c=gc()))num=(num<<1)+(num<<3)+(c^48);
	return num;
}

cs int N=500005,logN=20;
int last[N],nxt[N<<1],to[N<<1],ecnt;
inline void addedge(cs int &u,cs int &v){
	nxt[++ecnt]=last[u],last[u]=ecnt,to[ecnt]=v;
	nxt[++ecnt]=last[v],last[v]=ecnt,to[ecnt]=u;
}

int n,L;
int val[N],siz[N];
bool ban[N];

int total,mxsiz,G;
inline void find_G(cs int &u,cs int &fa){
	siz[u]=1;re int mx=1;
	for(int re e=last[u],v=to[e];e;v=to[e=nxt[e]]){
		if(ban[v]||v==fa)continue;
		find_G(v,u);siz[u]+=siz[v];
		mx=max(mx,siz[v]);
	}
	mx=max(mx,total-siz[u]);
	if(mx<=mxsiz)mxsiz=mx,G=u;
}

pair<int,int> dist[N];
int tail;
inline void dfs(cs int &u,cs int &fa,cs int &dis){
	if(dis>L)return ;
	dist[++tail]=make_pair(dis,val[u]);
	for(int re e=last[u],v=to[e];e;v=to[e=nxt[e]]){
		if(v==fa||ban[v])continue;
		dfs(v,u,dis+1);
	}
}

int sum[23],tot;
inline void add(cs int &val){
	for(int re i=logN;~i;--i){
		sum[i]+=(val>>i)&1;
	}++tot;
}

inline ll query(cs int &val){
	if(tot==0)return 0;
	ll res=0;
	for(int re i=logN;~i;--i){
		if(val&(1<<i))res+=(tot-sum[i])*1ll<<i;
		else res+=sum[i]*1ll<<i;
	}
	return res;
}

inline ll calc(cs int &u,cs int &dis){
	tail=0;
	dfs(u,u,dis);
	sort(dist+1,dist+tail+1);
	ll ans=0;
	memset(sum,0,sizeof sum);tot=0;
	for(int re r=tail,l=0;r>l;--r){
		while(l+1<r&&dist[l+1].first+dist[r].first<=L){
			ans+=query(dist[l+1].second);
			add(dist[l+1].second);
			++l;
		}
		ans+=query(dist[r].second);
	}
	return ans;
}

ll ans;
inline void solve(int u){
	ans+=calc(u,0);
	ban[u]=true;
	for(int re e=last[u],v=to[e];e;v=to[e=nxt[e]]){
		if(ban[v])continue;
		ans-=calc(v,1);
		mxsiz=total=siz[v];
		find_G(v,u);
		solve(G);
	}
}

signed main(){
	int size=1<<25;
    __asm__ ("movq %0,%%rsp\n"::"r"((char*)malloc(size)+size));

	n=getint();
	L=getint();
	for(int re i=1;i<=n;++i){
		val[i]=getint();
	}
	for(int re i=1;i<n;++i){
		int u=getint(),v=getint();
		addedge(u,v);
	}
	total=mxsiz=n;
	find_G(1,1);
	solve(G);
	cout<<ans;
	exit(0);
}

代码(长链剖分):

#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define re register
#define gc getchar
#define pc putchar
#define cs const

inline int getint(){
	re int num;
	re char c;
	while(!isdigit(c=gc()));num=c^48;
	while(isdigit(c=gc()))num=(num<<1)+(num<<3)+(c^48);
	return num;
}

cs int N=500005;
int last[N],nxt[N<<1],to[N<<1],ecnt;
inline void addedge(int u,int v){
	nxt[++ecnt]=last[u],last[u]=ecnt,to[ecnt]=v;
	nxt[++ecnt]=last[v],last[v]=ecnt,to[ecnt]=u; 
}

struct node{
	int cnt[20][2];
	node(){memset(cnt,0,sizeof cnt);}
	node operator=(cs int &a){
		memset(cnt,0,sizeof cnt);
		for(int re i=0;i<20;++i)cnt[i][(a>>i)&1]=1;
		return *this;
	}
	
	node operator+(cs node &a)cs{
		node tmp;
		for(int re i=0;i<20;++i){
			tmp.cnt[i][0]=cnt[i][0]+a.cnt[i][0];
			tmp.cnt[i][1]=cnt[i][1]+a.cnt[i][1];
		}
		return tmp;
	}
	
	node operator+=(cs node &a){
		*this=*this+a;
		return *this;
	}
	
	node operator-(cs node &a){
		node tmp;
		for(int re i=0;i<20;++i){
			tmp.cnt[i][0]=cnt[i][0]-a.cnt[i][0];
			tmp.cnt[i][1]=cnt[i][1]-a.cnt[i][1];
		}
		return tmp;
	}
	
	ll operator*(cs node &a)cs{
		ll res=0;
		for(int re i=0;i<20;++i)
		res+=(1ll*cnt[i][0]*a.cnt[i][1]+1ll*cnt[i][1]*a.cnt[i][0])<<i;
		return res;
	}
}val[N],b[N],*f[N];
int now=1,n,L;
int dep[N],son[N];

inline void dfs1(int u,int fa){
	dep[u]=1;
	for(int re e=last[u],v=to[e];e;v=to[e=nxt[e]]){
		if(v==fa)continue;
		dfs1(v,u);
		if(dep[v]+1>dep[u]){
			son[u]=v;
			dep[u]=dep[v]+1;
		}
	}
}

ll ans;

void init(int u){f[u]=b+now;now+=dep[u];}
node get(int u,int l){return (l<0)?f[u][0]:(l>=dep[u]?b[0]:f[u][l]);}
node calc(int u,int l,int r){return get(u,l)-get(u,r+1);}

void merge(int u,int v){
	for(int re i=0;i<dep[v];++i)
	ans+=calc(v,i,i)*calc(u,0,L-i-1);
	for(int re i=0;i<dep[v];++i)
	f[u][i+1]+=f[v][i];
	f[u][0]+=f[v][0];
}

void dfs2(int u,int fa){
	if(son[u]){
		f[son[u]]=f[u]+1;
		dfs2(son[u],u);
		f[u][0]=f[son[u]][0]+val[u];
		ans+=val[u]*calc(u,0,L);
	}
	else f[u][0]=val[u];
	for(int re e=last[u],v=to[e];e;v=to[e=nxt[e]]){
		if(v==fa||v==son[u])continue;
		init(v);
		dfs2(v,u);
		merge(u,v);
	}
}

signed main(){
	int size=1<<27;
    __asm__ ("movq %0,%%rsp\n"::"r"((char*)malloc(size)+size));
	n=getint();
	L=getint();
	for(int re i=1;i<=n;++i){
		val[i]=getint();
	}
	for(int re i=1;i<n;++i){
		int u=getint(),
		v=getint();
		addedge(u,v);
	}
	dfs1(1,0);
	init(1);
	dfs2(1,0);
	cout<<ans;
	exit(0);
}

转载于:https://www.cnblogs.com/zxyoi/p/10047118.html

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值