传送门
解析:
OJ又双叒叕卡栈空间。。。
思路:
首先看到距离不超过kkk很容易想到点分治。
每次处理出子树中所有点到分治中心的距离,然后处理一个二进制前缀和,双指针扫一遍数列,就可以愉快的做完这道题?
点分治复杂度O(nlognlog∣A∣)O(nlognlog|A|)O(nlognlog∣A∣),卡在2e82e82e8的极限,本来递归算法常数就大,卡不动了。。。
但是愿意卡常数的话考场上还是有85pts85pts85pts的
就算懒得卡,也有70pts70pts70pts可以拿。
考虑一个所谓复杂度不稳定的算法:长链剖分。
维护每条长链的二进制后缀和,合并两条长链可以做到O(Len log∣A∣)O(Len\text{ }log|A|)O(Len log∣A∣),并且每条长链只会被合并到其他长链一次,所以总的复杂度O(nlog∣A∣)O(nlog|A|)O(nlog∣A∣)。
考虑怎么在合并的同时统计答案。
已经维护了后缀和了,那么可以对较短链上的每个点都在较长链上找一下答案,其实每个点只会被询问一次所以这里的复杂度仍然是O(nlog∣A∣)O(nlog|A|)O(nlog∣A∣)。
就可以愉快的水过这道题了。
代码(点分治):
#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define re register
#define gc getchar
#define pc putchar
#define cs const
inline int getint(){
re int num;
re char c;
while(!isdigit(c=gc()));num=c^48;
while(isdigit(c=gc()))num=(num<<1)+(num<<3)+(c^48);
return num;
}
cs int N=500005,logN=20;
int last[N],nxt[N<<1],to[N<<1],ecnt;
inline void addedge(cs int &u,cs int &v){
nxt[++ecnt]=last[u],last[u]=ecnt,to[ecnt]=v;
nxt[++ecnt]=last[v],last[v]=ecnt,to[ecnt]=u;
}
int n,L;
int val[N],siz[N];
bool ban[N];
int total,mxsiz,G;
inline void find_G(cs int &u,cs int &fa){
siz[u]=1;re int mx=1;
for(int re e=last[u],v=to[e];e;v=to[e=nxt[e]]){
if(ban[v]||v==fa)continue;
find_G(v,u);siz[u]+=siz[v];
mx=max(mx,siz[v]);
}
mx=max(mx,total-siz[u]);
if(mx<=mxsiz)mxsiz=mx,G=u;
}
pair<int,int> dist[N];
int tail;
inline void dfs(cs int &u,cs int &fa,cs int &dis){
if(dis>L)return ;
dist[++tail]=make_pair(dis,val[u]);
for(int re e=last[u],v=to[e];e;v=to[e=nxt[e]]){
if(v==fa||ban[v])continue;
dfs(v,u,dis+1);
}
}
int sum[23],tot;
inline void add(cs int &val){
for(int re i=logN;~i;--i){
sum[i]+=(val>>i)&1;
}++tot;
}
inline ll query(cs int &val){
if(tot==0)return 0;
ll res=0;
for(int re i=logN;~i;--i){
if(val&(1<<i))res+=(tot-sum[i])*1ll<<i;
else res+=sum[i]*1ll<<i;
}
return res;
}
inline ll calc(cs int &u,cs int &dis){
tail=0;
dfs(u,u,dis);
sort(dist+1,dist+tail+1);
ll ans=0;
memset(sum,0,sizeof sum);tot=0;
for(int re r=tail,l=0;r>l;--r){
while(l+1<r&&dist[l+1].first+dist[r].first<=L){
ans+=query(dist[l+1].second);
add(dist[l+1].second);
++l;
}
ans+=query(dist[r].second);
}
return ans;
}
ll ans;
inline void solve(int u){
ans+=calc(u,0);
ban[u]=true;
for(int re e=last[u],v=to[e];e;v=to[e=nxt[e]]){
if(ban[v])continue;
ans-=calc(v,1);
mxsiz=total=siz[v];
find_G(v,u);
solve(G);
}
}
signed main(){
int size=1<<25;
__asm__ ("movq %0,%%rsp\n"::"r"((char*)malloc(size)+size));
n=getint();
L=getint();
for(int re i=1;i<=n;++i){
val[i]=getint();
}
for(int re i=1;i<n;++i){
int u=getint(),v=getint();
addedge(u,v);
}
total=mxsiz=n;
find_G(1,1);
solve(G);
cout<<ans;
exit(0);
}
代码(长链剖分):
#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define re register
#define gc getchar
#define pc putchar
#define cs const
inline int getint(){
re int num;
re char c;
while(!isdigit(c=gc()));num=c^48;
while(isdigit(c=gc()))num=(num<<1)+(num<<3)+(c^48);
return num;
}
cs int N=500005;
int last[N],nxt[N<<1],to[N<<1],ecnt;
inline void addedge(int u,int v){
nxt[++ecnt]=last[u],last[u]=ecnt,to[ecnt]=v;
nxt[++ecnt]=last[v],last[v]=ecnt,to[ecnt]=u;
}
struct node{
int cnt[20][2];
node(){memset(cnt,0,sizeof cnt);}
node operator=(cs int &a){
memset(cnt,0,sizeof cnt);
for(int re i=0;i<20;++i)cnt[i][(a>>i)&1]=1;
return *this;
}
node operator+(cs node &a)cs{
node tmp;
for(int re i=0;i<20;++i){
tmp.cnt[i][0]=cnt[i][0]+a.cnt[i][0];
tmp.cnt[i][1]=cnt[i][1]+a.cnt[i][1];
}
return tmp;
}
node operator+=(cs node &a){
*this=*this+a;
return *this;
}
node operator-(cs node &a){
node tmp;
for(int re i=0;i<20;++i){
tmp.cnt[i][0]=cnt[i][0]-a.cnt[i][0];
tmp.cnt[i][1]=cnt[i][1]-a.cnt[i][1];
}
return tmp;
}
ll operator*(cs node &a)cs{
ll res=0;
for(int re i=0;i<20;++i)
res+=(1ll*cnt[i][0]*a.cnt[i][1]+1ll*cnt[i][1]*a.cnt[i][0])<<i;
return res;
}
}val[N],b[N],*f[N];
int now=1,n,L;
int dep[N],son[N];
inline void dfs1(int u,int fa){
dep[u]=1;
for(int re e=last[u],v=to[e];e;v=to[e=nxt[e]]){
if(v==fa)continue;
dfs1(v,u);
if(dep[v]+1>dep[u]){
son[u]=v;
dep[u]=dep[v]+1;
}
}
}
ll ans;
void init(int u){f[u]=b+now;now+=dep[u];}
node get(int u,int l){return (l<0)?f[u][0]:(l>=dep[u]?b[0]:f[u][l]);}
node calc(int u,int l,int r){return get(u,l)-get(u,r+1);}
void merge(int u,int v){
for(int re i=0;i<dep[v];++i)
ans+=calc(v,i,i)*calc(u,0,L-i-1);
for(int re i=0;i<dep[v];++i)
f[u][i+1]+=f[v][i];
f[u][0]+=f[v][0];
}
void dfs2(int u,int fa){
if(son[u]){
f[son[u]]=f[u]+1;
dfs2(son[u],u);
f[u][0]=f[son[u]][0]+val[u];
ans+=val[u]*calc(u,0,L);
}
else f[u][0]=val[u];
for(int re e=last[u],v=to[e];e;v=to[e=nxt[e]]){
if(v==fa||v==son[u])continue;
init(v);
dfs2(v,u);
merge(u,v);
}
}
signed main(){
int size=1<<27;
__asm__ ("movq %0,%%rsp\n"::"r"((char*)malloc(size)+size));
n=getint();
L=getint();
for(int re i=1;i<=n;++i){
val[i]=getint();
}
for(int re i=1;i<n;++i){
int u=getint(),
v=getint();
addedge(u,v);
}
dfs1(1,0);
init(1);
dfs2(1,0);
cout<<ans;
exit(0);
}