SPOJ COT2 树上的莫队算法,树上区间查询

题意:n个节点形成的一棵树。每个节点有一个值。m次查询,求出(u,v)路径上出现了多少个不同的数。

树上的莫队算法,同样将树分成siz=sqrt(n)块,然后离线操作。先对树dfs一遍,每当子树节点个数num>=siz,就将这num个分成一块。读取所有的查询按左端点所在块排序。

重点在于怎么进行区间转移,对路径的lca特殊处理,参考博客http://blog.csdn.net/kuribohg/article/details/41458639  


用倍增法求lca单次要用logn复杂度,要跑3200ms。有个地方可以优化,就是知道了所有的查询,也就是事先知道了转移路径,可以用离线的方法求O(n)求出所有需要用到的lca,这个写起来比较麻烦,不过可以优化到1800ns。代码写的比较挫。。。。

logn求lca:3200+ms

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <cctype>
#include <string>
#include <vector>
#include <map>
#include <set>
#include <vector>
#include <queue>
#include <stack>
#include <algorithm>
using namespace std;
const int maxn=4e4+10;
const int maxm=1e5+10;

int n,m, siz;
vector<int> g[maxn];
int a[maxn], b[maxn], ans[maxm];
int tot[maxn], in[maxn];
int fa[maxn][20], dep[maxn];

struct Query
{
    int l, r, id;
    int st,ed;
    bool operator <(const Query& a) const
    {
        return st!=a.st? st<a.st: ed<a.ed; //先按左端点所在块先后排序,其次考虑又右端点所在块
    }
};

Query q[maxm];
int tag, bel[maxn];
int st[maxn], top;
int dfs(int u, int par, int d, int &cnt)
{
    dep[u]=d; fa[u][0]=par;
    int num=0;
    for(int i=0; i<g[u].size(); i++){
        int v=g[u][i];
        if(v!=par){
            num+=dfs(v, u, d+1, cnt);
            if(num>=siz){ //子树大小>=sqrt(n),分成一块
                for(int i=0; i<num; i++)
                    bel[st[--top]]=tag;
                tag++;
                num=0;
            }
        }
    }

    st[top++]=u;//记录子树遍历的点
    return num+1;
}

void init()
{
    for(int i=0; i<=n; i++) g[i].clear();
    memset(tot, 0, sizeof(tot));
    memset(in, 0, sizeof(in));

    siz=sqrt(n);

    for(int i=1;i<=n; i++) scanf("%d",&a[i]), b[i]=a[i];
    sort(b+1, b+n+1);
    for(int i=1; i<=n; i++)
        a[i]=lower_bound(b+1, b+n+1, a[i])-b;

    for(int i=0; i<n-1; i++){
        int u,v;
        scanf("%d%d", &u, &v);
        g[u].push_back(v);
        g[v].push_back(u);
    }

    int cnt=0; tag=top=0;
    int num=dfs(1, -1, 0, cnt);
    for(int i=0; i<num; i++)
        bel[st[--top]]=tag; //最后剩下的数也分成一块

    for(int i=1; i<20; i++){
        for(int u=1; u<=n; u++)
            if(fa[u][i-1]==-1)
                fa[u][i]=-1;
            else fa[u][i]=fa[fa[u][i-1]][i-1];
    }

    for(int i=0; i<m; i++){
        scanf("%d%d", &q[i].l, &q[i].r);
        if(bel[q[i].l]>bel[q[i].r])
            swap(q[i].l, q[i].r);
        q[i].id=i;
        q[i].st=bel[q[i].l];
        q[i].ed=bel[q[i].r];
    }
    sort(q, q+m);
}

int lca(int u, int v)
{
    if(dep[u]>dep[v]) swap(u, v);
    for(int i=0; i<20; i++)
        if((dep[v]-dep[u])>>i&1)
            v=fa[v][i];

    if(u==v) return u;
    for(int i=19; i>=0; i--){
        if(fa[u][i]!=fa[v][i]){
            u=fa[u][i];
            v=fa[v][i];
        }
    }
    return fa[u][0];
}

void solve()
{
    int res=0;
    int cu=1, cv=1;

    for(int i=0; i<m; i++){
        int nu=q[i].l, nv=q[i].r;
        int par=lca(cu, nu);
        while(cu!=par){
            if(in[cu]){
                if(--tot[a[cu]]==0)
                    res--;
            }
            else if(++tot[a[cu]]==1)
                res++;
            in[cu]^=1;
            cu=fa[cu][0];
        }

        cu=nu;
        while(cu!=par){
            if(in[cu]){
                if(--tot[a[cu]]==0)
                    res--;
            }
            else if(++tot[a[cu]]==1)
                res++;
            in[cu]^=1;
            cu=fa[cu][0];
        }
        cu=nu;


        par=lca(cv, nv);
        while(cv!=par){
            if(in[cv]){
                if(--tot[a[cv]]==0)
                    res--;
            }
            else if(++tot[a[cv]]==1)
                res++;
            in[cv]^=1;
            cv=fa[cv][0];
        }

        cv=nv;
        while(cv!=par){
            if(in[cv]){
                if(--tot[a[cv]]==0)
                    res--;
            }
            else if(++tot[a[cv]]==1)
                res++;
            in[cv]^=1;
            cv=fa[cv][0];
        }
        cv=nv;

        par=lca(cu, cv);
        ans[q[i].id]=res+(!tot[a[par]]);
    }
}

int main()
{
    while(scanf("%d%d", &n, &m)==2){
        init();
        solve();
        for(int i=0; i<m; i++)
            printf("%d\n", ans[i]);
    }
    return 0;
}


离线查询lca:1800+ms

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <cctype>
#include <string>
#include <vector>
#include <map>
#include <set>
#include <vector>
#include <queue>
#include <stack>
#include <algorithm>
using namespace std;
#pragma comment(linker, "/STACK:1024000000,1024000000")
typedef pair<int,int> P;
#define fir first
#define sec second
const int maxn=4e4+10;
const int maxm=1e5+10;

int n,m, siz;
vector<int> g[maxn];
int first[maxn],ltot=0, nxt[6*maxm];
P lq[6*maxm];//所有需要查询的lca,lq[i].first保存v,second保存查询的id

int a[maxn], b[maxn], ans[maxm];
int tot[maxn], in[maxn], fa1[maxn];
int fa[maxn], lca[3*maxm], col[maxn];
int bel[maxn],st[maxn],top=0;
struct Query
{
    int l, r, id;
    int st,ed;
    bool operator <(const Query& a) const
    {
        return st!=a.st? st<a.st: ed<a.ed;
    }
};

Query q[maxm];
int tag;
int dfs(int u, int par, int &cnt)//分块
{
    fa1[u]=par;
    int num=0;
    for(int i=0; i<g[u].size(); i++){
        int v=g[u][i];
        if(v!=par)
            num+=dfs(v, u, cnt);
        if(num>=siz){
            for(int i=0; i<num; i++)
                bel[st[--top]]=tag;
            tag++;
            num=0;
        }

    }
    st[top++]=u;
    return num+1;
}

int find(int u)
{
    return fa[u]==u?u:(fa[u]=find(fa[u]));
}

int unite(int x, int y)
{
    x=fa[x];
    y=fa[y];
    fa[y]=x;
}


void dfs2(int u, int par)//离线查询所有lca
{
    col[u]=1;
    for(int i=first[u]; i!=-1; i=nxt[i]){
        int v=lq[i].fir, id=lq[i].sec;
        if(!col[v]) continue;
        else if(col[v]==1){
            lca[id]=v;
        }
        else{
            lca[id]=find(v);
        }
    }

    for(int i=0; i<g[u].size(); i++){
        int v=g[u][i];
        if(v!=par)
            dfs2(v, u);
    }
    col[u]=2;
    unite(par, u);
}

void add(int u, int v, int id)//查询m<=1e5,数比较多所以用前向星实现优化
{
    lq[ltot]=P(v,id);
    nxt[ltot]=first[u];
    first[u]=ltot++;
}

void init()
{
    for(int i=0; i<=n; i++) g[i].clear();
    memset(tot, 0, sizeof(tot));
    memset(in, 0, sizeof(in));

    siz=sqrt(n);

    for(int i=1;i<=n; i++) scanf("%d", a+i), b[i]=a[i];
    sort(b+1, b+n+1);
    for(int i=1; i<=n; i++)
        a[i]=lower_bound(b+1, b+n+1, a[i])-b;

    for(int i=0; i<n-1; i++){
        int u,v;
        scanf("%d%d", &u, &v);
        g[u].push_back(v);
        g[v].push_back(u);
    }


    int cnt=0; top=0; tag=0;
    int num=dfs(1, -1, cnt);
    for(int i=0; i<num; i++)
        bel[st[--top]]=tag;

    for(int i=0; i<m; i++){
        scanf("%d%d", &q[i].l, &q[i].r);
        if(bel[q[i].l]>bel[q[i].r])
            swap(q[i].l, q[i].r);
        q[i].id=i;
        q[i].st=bel[q[i].l];
        q[i].ed=bel[q[i].r];
    }
    sort(q, q+m);

    cnt=0; ltot=0;
    memset(first, -1, sizeof(first));
    add(1, q[0].l, cnt);
    add(q[0].l, 1, cnt++);
    add(1, q[0].r, cnt);
    add(q[0].r, 1, cnt++);
    add(q[0].r, q[0].l, cnt);
    add(q[0].l, q[0].r, cnt++);
    //add(q[0].r, q[0].l, cnt++);


    for(int i=0; i<m-1; i++){
    add(q[i].l, q[i+1].l, cnt);//第i个查询左端点向第i+1个左端点转移,所以需要它们之间的lca
    add(q[i+1].l, q[i].l, cnt++);
    add(q[i].r, q[i+1].r, cnt);//第i个查询右端点向第i+1个右端点转移
    add(q[i+1].r, q[i].r, cnt++);
    add(q[i+1].r, q[i+1].l, cnt);//左端点和右端点的lca
    add(q[i+1].l, q[i+1].r,cnt++);
    }
    for(int i=0; i<=n; i++) fa[i]=i;
    memset(col, 0, sizeof(col));
    dfs2(1, 0);
}



void solve()
{
    int res=0;
    int cu=1, cv=1;

    for(int i=0; i<m; i++){
        int nu=q[i].l, nv=q[i].r;
        //cout<<lca[i*3]<<' '<<lca[i*3+1]<<' '<<lca[i*3+2]<<endl;
        int par=lca[i*3];
        while(cu!=par){
            if(in[cu]){
                if(--tot[a[cu]]==0)
                    res--;
            }
            else if(++tot[a[cu]]==1)
                res++;
            in[cu]^=1;
            cu=fa1[cu];
        }

        cu=nu;
        while(cu!=par){
            if(in[cu]){
                if(--tot[a[cu]]==0)
                    res--;
            }
            else if(++tot[a[cu]]==1)
                res++;
            in[cu]^=1;
            cu=fa1[cu];
        }
        cu=nu;


        par=lca[i*3+1];
        while(cv!=par){
            if(in[cv]){
                if(--tot[a[cv]]==0)
                    res--;
            }
            else if(++tot[a[cv]]==1)
                res++;
            in[cv]^=1;
            cv=fa1[cv];
        }

        cv=nv;
        while(cv!=par){
            if(in[cv]){
                if(--tot[a[cv]]==0)
                    res--;
            }
            else if(++tot[a[cv]]==1)
                res++;
            in[cv]^=1;
            cv=fa1[cv];
        }
        cv=nv;

        par=lca[i*3+2];
        ans[q[i].id]=res+(!tot[a[par]]);
    }
}

int main()
{
    while(cin>>n>>m){
        init();
        solve();
        for(int i=0; i<m; i++)
            printf("%d\n", ans[i]);
    }
    return 0;
}


  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值