Night的数据结构杂谈-虚树

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/Night2002/article/details/79573571

在某些时候,我们需要维护树上选一些点所得到的东西。
这些东西要满足这样一个性质:未选的点可以通过某种方式删除而不影响最终的结果。
最典型的就是求被选出的节点在原树上的距离之和。
既然我们知道未选的点可以删掉,那么我们就想办法建一棵树,使得树上的未选点尽量少。
这棵树就叫虚树。
那么要怎么建立一棵虚树呢?
首先我们在原树上跑一遍 dfs,并得出树上节点的 dfs 序,记为 dfn
(顺便树链剖分维护一下 lca
然后我们按照这个 dfn 从小到大把节点插入虚树。
维护一个栈,它表示在当前的这棵虚树上,以最后一个插入的点为终点的 dfs 链。
设最后插入的点为 x(就是栈顶的点),当前要加入的点为 y。我们想把 y 插入到我们已经构建的虚树上去。
求出 lca(x,y),记为 Lca。有两种情况:
1. xy 分立在 Lca 的两棵子树下。
2. Lcax
/*为什么 Lca 不可能是 y 呢?
因为如果 Lcay,说明 dfn(Lca)=dfn(y)<dfn(x),而我们是按照dfs序号从小到大选点的,于是 dfn(x)<dfn(y),矛盾。*/
那么对于第二种情况,显然只需要把 y 连到 x 上面就行了。
对于第一种情况呢,有 dfn(y)>dfn(x)>dfn(Lca),那么这说明什么呢?
这说明我们已经把 Lca 所在的子树中,x 所在的子树全部遍历完了。
/*为什么遍历完了呢?
如果没有遍历完,那么肯定有一个未加入的点 k,满足 dfn(k)<dfn(y),我们按照 dfs 序号递增顺序遍历的话,肯定会把 k 加进来了才到 y。*/
这样,我们就直接构建 Lca 为根的,x 所在的那个子树。
由于我们的栈维护的是当前的 dfs 链,所以显然我们可以在退栈的时候连边,那么考虑一下不是退栈需要连的边。
x 所在的子树如果还有其它部分,它一定在之前就构建好了(所有退栈的点都已经被正确地连入树中了),就剩那条 dfs 链。
那么要如何正确地连 xLca 的边呢?
设栈顶的节点为 a,栈顶第二个节点为 b
重复以下操作:
如果 dfn(b)>dfn(Lca),可以直接连边 ba,然后退栈。
如果 dfn(b)==dfn(Lca),说明 b 就是 Lca,直接连边 Lcaa,此时子树已经构建完毕。
如果 dfn(b)<dfn(Lca),说明 Lcaab 夹在中间,此时连边 lcab,退一次栈,再把 Lca 入栈。
这样就连完了,接下来把 x 入栈即可。
好像很复杂对吧,我们观察一下这样连边的本质是什么。
/* 你快观察呀.jpg */
然后我们会发现,这么讨论太复杂了,我们直接利用它们在原树中的深度关系来连边即可。
首先我们得到一个点,还是记为 x
然后我们重复以下操作:
得到栈顶和 xlca,还是记为 Lca
并且把栈顶记为 a,栈顶第二个节点(如果有)记为 b
如果 depLca<depb,则连边 Lcab
否则如果depLca<depa,则连边 Lcaa
否则跳出。
接着如果 Lca 还不是栈顶的话,则把 Lca 入栈,然后把 x 入栈即可。
这样写起来简单些。
接下来我们看道例题:
计蒜客 青云的机房组网方案
给出一棵树,每个点有点权,边权均为 1
求所有点权互质的点对的距离和。
注意到本题可以转化为求 树上所有的点对距离之和 所有不互质的点对距离之和。
我们可以利用容斥原理来计算不互质的点对之和,这个步骤可以写个线性筛质数预处理。
然而对于每一个因数,我们在树上取的点并不多,且多次取的总和是 O(n) 级别的,又发现对于两个顶点,它们之间的距离不会因为这两点之间路径上点的多少而改变。因此就可以考虑对于每一次询问(一个因数相当于一个询问),我们根据原树的信息重新建一棵树,让这棵树里面尽量少包含未选择的节点。(于是这棵树就是虚树)然后在这棵虚树上跑一个树形 dp 就行了。
代码如下:

#include <bits/stdc++.h>
#define R register
#define LL long long
#define Max(__a,__b) (__a<__b?__b:__a)
#define Min(__a,__b) (__a<__b?__a:__b)
using namespace std;
template<class TT>inline void read(R TT &x){
    x=0;R bool f=false;R char c=getchar();
    for(;c<48||c>57;c=getchar())f|=(c=='-');
    for(;c>47&&c<58;c=getchar())x=(x<<1)+(x<<3)+(c^48);
    (f)&&(x=-x);
}
template<class orzyrt>inline orzyrt Abs(R orzyrt x){
    if(x<0)return -x;
    else return x;
}
int n;
namespace non_baoli{
    #define N 100010
    int mul[N];
    char com[N];
    int pri[N];
    inline void get_prime(R int cnt=0){
        mul[1]=1;
        for(R int i=2;i<100001;++i){
            if(!com[i])pri[cnt++]=i,mul[i]=-1;
            for(R int j=0;j<cnt&&i*pri[j]<100001;++j){
                com[i*pri[j]]=1;
                if(i%pri[j]==0){
                    mul[i*pri[j]]=0;
                    break;
                }else mul[i*pri[j]]=-mul[i];
            }
        }
    }
    struct Edge{
        int to;
        Edge *next;
    }E[N<<1],E1[N<<1],*head[N],*head1[N],*e=E,*st=E1;
    inline void add(R int u,R int v){
        *e=(Edge){v,head[u]};head[u]=e++;
    }
    inline void add1(R int u,R int v){
        *st=(Edge){v,head1[u]};head1[u]=st++;
    }
    int fa[N],son[N],dep[N],siz[N],top[N],dfn[N],dfs_clo;
    void dfs1(R int u,R int f){
        dfn[u]=++dfs_clo;
        siz[u]=1;fa[u]=f;
        dep[u]=dep[f]+1;
        R int v;
        for(R Edge *i=head[u];i;i=i->next){
            if((v=i->to)==f)continue;
            dfs1(v,u);
            siz[u]+=siz[v];
            if(siz[son[u]]<siz[v])son[u]=v;
        }
    }
    void dfs2(R int u,R int tp){
        top[u]=tp;
        if(son[u])dfs2(son[u],tp);
        for(R Edge*i=head[u];i;i=i->next){
            R int v=i->to;
            if(v!=fa[u]&&v!=son[u])dfs2(v,v);
        }
    }
    inline int lca(R int a,R int b){
        while(top[a]!=top[b]){
            dep[top[a]]>dep[top[b]]?
                a=fa[top[a]]:b=fa[top[b]];
        }
        return dep[a]<dep[b]?a:b;
    }
    int a[N],Top,cnt,sz[N],stk[N];
    LL val;
    void dfs3(R int u,R int f){
        for(R Edge *i=head[u];i;i=i->next){
            if(i->to!=f){
                dfs3(i->to,u);
                sz[u]+=sz[i->to];
            }
        }
        if(f)val+=1ll*Abs(dep[u]-dep[f])*sz[u]*(cnt-sz[u]);
    }
    void del(R int u,R int f){
        for(R Edge *i=head[u];i;i=i->next){
            if(i->to!=f)del(i->to,u);
        }
        sz[u]=0;
        head[u]=0;
    }
    inline bool cmp(R int a,R int b){return dfn[a]<dfn[b];}
    inline LL solve(R int num){
        cnt=Top=0;
        for(R int u=num;u<100001;u+=num){
            for(R Edge *i=head1[u];i;i=i->next){
                a[cnt++]=i->to;
            }
        }
        if(cnt<=1)return 0;
        e=E;
        val=0;
        sort(a,a+cnt,cmp);
        for(R int i=0;i<cnt;++i)sz[a[i]]=1;
        for(R int i=0,Lca,now;i<cnt;++i){
            Lca=0;now=a[i];
            while(Top>0){
                Lca=lca(now,stk[Top]);
                if(Top>1&&dep[Lca]<dep[stk[Top-1]]){
                    R int u=stk[Top],v=stk[Top-1];
                    add(u,v);add(v,u);
                    Top--;
                }else if(dep[Lca]<dep[stk[Top]]){
                    R int u=Lca,v=stk[Top];
                    add(u,v);add(v,u);
                    Top--;
                    break;
                }else break;
            }
            if(stk[Top]!=Lca)stk[++Top]=Lca;
            stk[++Top]=now;
        }
        while(Top>1){
            R int u=stk[Top],v=stk[Top-1];
            add(u,v);add(v,u);
            Top--;
        }
        dfs3(a[0],0);
        del(a[0],0);
        return val*mul[num];
    }
    inline void work(){
        get_prime();
        for(R int i=1,x;i<=n;++i){
            read(x);
            add1(x,i);
        }
        for(R int i=1,u,v;i<n;++i){
            read(u);read(v);
            add(u,v);add(v,u);
        }
        dfs_clo=0;
        dfs1(1,0);
        dfs2(1,1);
        R LL ans=0;
        memset(head,0,sizeof head);
        for(R int i=1;i<100001;++i){
            if(mul[i])ans+=solve(i);
        }
        printf("%lld\n",ans);
    }
}

int main(){
    read(n);
    non_baoli::work();
    return 0;
}
展开阅读全文

没有更多推荐了,返回首页