题目
n(n<=5e4)个点的树,树上共有k(k<=10)种颜色,
每个点一种颜色ai(1<=ai<=10),
统计这样的(u,v),满足u到v的路径上的点的颜色的并集恰为k种颜色,
注意(v,u)和(u,v)是不同的答案
思路来源
https://blog.csdn.net/albertluf/article/details/81388596
题解
做点分治,直接对u搜到的点统计贡献,减掉非跨过u的全部位于子树v内的贡献即可
求得的答案也不用减掉自己,也不用除以2,就是最终的答案
因为一条到u的链也可以构成答案,
复杂度貌似O(n*logn+n*k*2^k),感觉很玄学的样子,也不知道是怎么过去的…
后记,和梁神学了个按位分块,就是二进制位的后半段下放到子集里,
前半段统计的时候统计其超集的答案,这样能统计到所有的超集
复杂度大概O(n*logn*2^(k/2))的样子,跑得飞快
代码1
#include<iostream>
#include<cstring>
#include<cstdio>
#include<algorithm>
using namespace std;
typedef long long ll;
const int N=5e4+10,M=(1<<10)+5;
int head[N],cnt;
struct edge{int v,nex;}e[2*N];
void add(int u,int v){e[++cnt]=edge{v,head[u]};head[u]=cnt;}
bool vis[N];
int n,k,a[N],r,u,v;
int siz,f[N],sz[N],rt;
ll res,q[N],d[N],now[M],w;
void init(int n){
cnt=0;
for(int i=1;i<=n;++i){
vis[i]=head[i]=0;
}
}
//找下一次的重心rt
void getrt(int u,int fa,bool op){
f[u]=0;sz[u]=1;
for(int i=head[u];i;i=e[i].nex){
int v=e[i].v;
if(v==fa||vis[v])continue;
getrt(v,u,op);
f[u]=max(f[u],sz[v]);
sz[u]+=sz[v];
}
if(op){
f[u]=max(f[u],siz-sz[u]);
if(f[u]<f[rt])rt=u;
}
}
//计算重心u到子树内每个点的距离
void getdis(int u,int fa){
q[++r]=d[u];
for(int i=head[u];i;i=e[i].nex){
int v=e[i].v;
if(v==fa||vis[v])continue;
d[v]=d[u]|(1<<a[v]);
getdis(v,u);
}
}
//计算以u为根的子树的答案
ll cal(int u,int col){
r=0;d[u]=col|(1<<a[u]);
getdis(u,0);
ll ans=0;
int up=(1<<k)-1;
for(int j=0;j<=up;++j){
now[j]=0;
}
for(int i=1;i<=r;++i){
now[q[i]]++;
}
for(int i=0;i<k;++i){
for(int j=0;j<=up;++j){
if(!(j>>i&1)){
now[j]+=now[j|(1<<i)];
}
}
}
for(int j=1;j<=r;++j){
ans+=now[up^q[j]];
}
return ans;
}
void dfs(int u){
//每次用在u的子树里任取减去在v的子树里的答案
//每次只计算 必经过u的答案
res+=cal(u,0);
vis[u]=1;
for(int i=head[u];i;i=e[i].nex){
int v=e[i].v;
if(vis[v])continue;
res-=cal(v,1<<a[u]);
getrt(v,u,0);//获得正确的sz[v]
siz=sz[v];rt=0;
getrt(v,u,1);
dfs(rt);
}
}
int main(){
while(~scanf("%d%d",&n,&k)){
init(n);
for(int i=1;i<=n;++i){
scanf("%d",&a[i]);
a[i]--;
}
for(int i=1;i<n;++i){
scanf("%d%d",&u,&v);
add(u,v);add(v,u);
}
res=0;
f[0]=siz=n;rt=0;
getrt(1,0,1),dfs(rt);
printf("%lld\n",res);
}
return 0;
}
代码2
#include<iostream>
#include<cstring>
#include<cstdio>
#include<algorithm>
using namespace std;
typedef long long ll;
const int N=5e4+10,M=(1<<10)+5;
int head[N],cnt;
struct edge{int v,nex;}e[2*N];
void add(int u,int v){e[++cnt]=edge{v,head[u]};head[u]=cnt;}
bool vis[N];
int n,k,a[N],r,u,v;
int siz,f[N],sz[N],rt;
ll res,q[N],d[N],now[M],w;
void init(int n){
cnt=0;
for(int i=1;i<=n;++i){
vis[i]=head[i]=0;
}
}
//找下一次的重心rt
void getrt(int u,int fa,bool op){
f[u]=0;sz[u]=1;
for(int i=head[u];i;i=e[i].nex){
int v=e[i].v;
if(v==fa||vis[v])continue;
getrt(v,u,op);
f[u]=max(f[u],sz[v]);
sz[u]+=sz[v];
}
if(op){
f[u]=max(f[u],siz-sz[u]);
if(f[u]<f[rt])rt=u;
}
}
//计算重心u到子树内每个点的距离
void getdis(int u,int fa){
q[++r]=d[u];
for(int i=head[u];i;i=e[i].nex){
int v=e[i].v;
if(v==fa||vis[v])continue;
d[v]=d[u]|(1<<a[v]);
getdis(v,u);
}
}
//计算以u为根的子树的答案
ll cal(int u,int col){
r=0;d[u]=col|(1<<a[u]);
getdis(u,0);
ll ans=0;
int up=(1<<k)-1,suf=(1<<(k/2))-1,pre=up^suf,ppre=pre>>(k/2);
for(int j=0;j<=up;++j){
now[j]=0;
}
for(int i=1;i<=r;++i){//后半段下放到子集上
int fro=q[i]&pre,beh=q[i]&suf;
for(int j=0;j<=suf;++j){
if((beh&j)==j){
now[fro|j]++;
}
}
}
for(int i=1;i<=r;++i){//前半段统计超集的答案
q[i]=(up^q[i]);//补集的超集
int fro=(q[i]&pre)>>(k/2),beh=q[i]&suf;
for(int j=0;j<=ppre;++j){
if((fro|j)==j){
ans+=now[(j<<(k/2))|beh];
}
}
}
return ans;
}
void dfs(int u){
//每次用在u的子树里任取减去在v的子树里的答案
//每次只计算 必经过u的答案
res+=cal(u,0);
vis[u]=1;
for(int i=head[u];i;i=e[i].nex){
int v=e[i].v;
if(vis[v])continue;
res-=cal(v,1<<a[u]);
getrt(v,u,0);//获得正确的sz[v]
siz=sz[v];rt=0;
getrt(v,u,1);
dfs(rt);
}
}
int main(){
while(~scanf("%d%d",&n,&k)){
init(n);
for(int i=1;i<=n;++i){
scanf("%d",&a[i]);
a[i]--;
}
for(int i=1;i<n;++i){
scanf("%d%d",&u,&v);
add(u,v);add(v,u);
}
res=0;
f[0]=siz=n;rt=0;
getrt(1,0,1),dfs(rt);
printf("%lld\n",res);
}
return 0;
}