[HackerRank University CodeSprint]Counting On a Tree

题目大意

给定一棵 n 个节点的树,每个节点有颜色ci。有 q 个询问,每次给定(x,y,u,v),询问有多少个点对 (i,j) 满足
ij
ipath(x,y)
jpath(u,v)
ci=cj

1n105,1q5×104,1ci109


题目分析

这题一看不可做。
我们来思考一下这题有哪些方法可以做。
先不考虑 ij 的限制。

Algorithm 1

分别处理每一种颜色。枚举颜色,然后计算每个点到根路径上的同色点的个数。这种颜色对于一个询问的贡献就是两条路径点数相乘即可。
令颜色种数为 C ,那么时间复杂度是O(C(n+q))

Algorithm 2

对于单种颜色,令 sum(x) 到根路径上的同色点个数,询问形如

(sum(x)+sum(y)sum(lca(x,y))sum(fa(lca(x,y))))(sum(u)+sum(v)sum(lca(u,v))sum(fa(lca(u,v))))

将询问差分拆成16个。于是每个询问变成两条到根路径同色点对个数。
怎么做呢?将所有询问挂到对应的两个点上,DFS这棵树。每到一个点,先枚举每一个同色的点,那么显然如果一个询问的两个点分别在这两个点的子树内,那这对点就会对询问有贡献。因此我们对这个同色点的子树打一个 +1 标记。然后我们处理挂在这个点上的询问,如果询问的另一个点已经访问过(为了不算重),那么其对答案的贡献就是这个点的标记和(注意符号由差分询问时决定)。然后在退出节点的时候我们就撤销所有这个节点产生的标记。
注意为了处理特殊情况,一个点应该在访问完所有的儿子之后才标记为访问过。
至于处理标记,使用树状数组就可以了。
细节挺多的,请自己思考吧。
A 表示一种颜色最多的出现次数,这种算法的时间复杂度是O(nAlog2n)的。

Algorithm 3

我们发现上面两种方法都过不了,怎么办呢?
我们可以使用阈值均衡两种算法!
设置阈值 B ,如果一种颜色出现次数大于B,那么显然这种颜色种数小于等于 nB ,我们可以使用算法1。否则使用算法二。
时间复杂度是 O(nBn+Bnlog2n) 。阈值瞎选一个就好了,我选的是 n
但是这样我们还有一个问题没有考虑,就是 ij 怎么处理?
可以发现,对于一个询问的两条路径,算重的显然是两条路径交集的长度。那么我们可以瞎分类讨论一波(你讨厌分类讨论可以直接打树剖),从答案中减去就好了。
此题完美解决!

Algorithm n?

貌似这题在hackerrank上的标解是树链剖分(捂脸)。
貌似国外有人用树上莫队过了?怎么做,不会~


代码实现

好久没有打过这么复杂的代码了……

#include <algorithm>
#include <iostream>
#include <cstdio>
#include <cctype>
#include <cmath>

using namespace std;

typedef long long LL;

int read()
{
    int x=0,f=1;
    char ch=getchar();
    while (!isdigit(ch)) f=ch=='-'?-1:f,ch=getchar();
    while (isdigit(ch)) x=x*10+ch-'0',ch=getchar();
    return x*f;
}

int buf[30];

void write(LL x)
{
    if (x<0) putchar('-'),x=-x;
    for (;x;x/=10) buf[++buf[0]]=x%10;
    if (!buf[0]) buf[++buf[0]]=0;
    for (;buf[0];) putchar('0'+buf[buf[0]--]);
}

const int N=100050;
const int Q=50050;
const int M=N<<1;
const int EL=N<<1;
const int LGEL=18;
const int T=Q<<1;

struct query
{
    int qid,x,sign;
}qs[T<<4];

struct node
{
    int pid,col;
}ns[N];

bool operator<(node x,node y){return x.col<y.col;}

int last[N],fa[N],pos[N],ptr[N],deep[N],c[N],v[N],qlst[N],DFN[N],size[N];
int n,tot,el,lgel,idx,thr,qtot,ctp,q;
int euler[EL],LOG[EL];
bool mark[N],vis[N];
int tov[M],nxt[M];
int rmq[EL][LGEL];
int qnxt[T<<4];
int qy[Q][4];
LL ans[Q];

int lowbit(int x){return x&-x;}

struct Fenwick_tree
{
    int num[N];

    int query(int x)
    {
        int ret=0;
        for (;x;x-=lowbit(x)) ret+=num[x];
        return ret;
    }

    void modify(int x,int delta){for (;x<=n;x+=lowbit(x)) num[x]+=delta;}
}t;

void insert(int x,int y){tov[++tot]=y,nxt[tot]=last[x],last[x]=tot;}

void hang(int x,int y,int id,int sign){qs[++qtot].qid=id,qs[qtot].x=y,qs[qtot].sign=sign,qnxt[qtot]=qlst[x],qlst[x]=qtot;}

void build(int x,int y,int id,int sign)
{
    hang(x,y,id,sign);
    if (x!=y) hang(y,x,id,sign);
}

void dfs(int x)
{
    size[rmq[pos[euler[++el]=x]=el][0]=x]=1,DFN[x]=++idx;
    for (int i=last[x],y;i;i=nxt[i])
        if ((y=tov[i])!=fa[x])
            fa[y]=x,deep[y]=deep[x]+1,dfs(y),euler[++el]=x,rmq[el][0]=x,size[x]+=size[y];
}

void pre_rmq()
{
    lgel=trunc(log(el)/log(2));
    for (int j=1;j<=lgel;j++)
        for (int i=1;i+(1<<j)-1<=el;i++)
            rmq[i][j]=deep[rmq[i][j-1]]<deep[rmq[i+(1<<j-1)][j-1]]?rmq[i][j-1]:rmq[i+(1<<j-1)][j-1];

    LOG[1]=0;
    for (int i=2;i<=el;i++) LOG[i]=LOG[i-1]+(1<<LOG[i-1]+1==i);
}

int get_rmq(int l,int r)
{
    int lgr=LOG[r-l+1];
    return deep[rmq[l][lgr]]<deep[rmq[r-(1<<lgr)+1][lgr]]?rmq[l][lgr]:rmq[r-(1<<lgr)+1][lgr];
}

int lca(int x,int y)
{
    if ((x=pos[x])>(y=pos[y])) swap(x,y);
    return get_rmq(x,y);
}

void sum(int x)
{
    for (int i=last[x],y;i;i=nxt[i])
        if ((y=tov[i])!=fa[x]) v[y]+=v[x],sum(y);
}

int getsum(int x,int y)
{
    int z=lca(x,y);
    return v[x]+v[y]-v[z]-v[fa[z]];
}

void pre_color()
{
    sort(ns+1,ns+1+n);
    ns[0].col=0,ctp=0;
    for (int i=1;i<=n;i++)
    {
        if (ns[i].col!=ns[i-1].col) ptr[++ctp]=i;
        c[ns[i].pid]=ctp;
    }
    ptr[ctp+1]=n+1;
    for (int i=1;i<=ctp;i++)
        if (ptr[i+1]-ptr[i]<=thr) mark[i]=1;
        else
        {
            for (int j=1;j<=n;j++) v[j]=0;
            for (int j=ptr[i];j<ptr[i+1];j++) v[ns[j].pid]++;
            sum(1);
            for (int j=1;j<=q;j++) ans[j]+=1ll*getsum(qy[j][0],qy[j][1])*getsum(qy[j][2],qy[j][3]);
        }
}

void solve(int x)
{
    if (mark[c[x]]) for (int i=ptr[c[x]],y;i<ptr[c[x]+1];i++) t.modify(DFN[y=ns[i].pid],1),t.modify(DFN[y]+size[y],-1);
    for (int i=last[x],y;i;i=nxt[i])
        if ((y=tov[i])!=fa[x]) solve(y);
    vis[x]=1;
    for (int i=qlst[x];i;i=qnxt[i])
        if (vis[qs[i].x]) ans[qs[i].qid]+=t.query(DFN[qs[i].x])*qs[i].sign;
    if (mark[c[x]]) for (int i=ptr[c[x]],y;i<ptr[c[x]+1];i++) t.modify(DFN[y=ns[i].pid],-1),t.modify(DFN[y]+size[y],1);
}

void process()
{
    for (int i=1;i<=q;i++)
    {
        int x=qy[i][0],y=qy[i][1],z=lca(x,y),u=qy[i][2],v=qy[i][3],w=lca(u,v),tmp=0;
        if (DFN[w]<DFN[z]&&DFN[z]<=DFN[w]+size[w]-1)
        {
            if (lca(u,z)==z) tmp=max(deep[lca(x,u)]-deep[z],deep[lca(y,u)]-deep[z])+1;
            else if (lca(v,z)==z) tmp=max(deep[lca(x,v)]-deep[z],deep[lca(y,v)]-deep[z])+1;
        }
        else
            if (DFN[z]<=DFN[w]&&DFN[w]<=DFN[z]+size[z]-1)
                if (w==z)
                {
                    int a=lca(x,u),b=lca(x,v);
                    tmp=deep[a]-deep[z]+deep[b]-deep[z];
                    a=lca(y,u),b=lca(y,v);
                    tmp+=deep[a]-deep[z]+deep[b]-deep[z];
                    tmp++;
                }
                else
                {
                    int a=lca(x,u),b=lca(x,v);
                    if (a==w||b==w) tmp=max(deep[a]-deep[b],deep[b]-deep[a])+1;
                    a=lca(y,u),b=lca(y,v);
                    if (a==w||b==w) tmp=max(deep[a]-deep[b],deep[b]-deep[a])+1;
                }
        ans[i]-=tmp;
    }
}

int main()
{
    freopen("count.in","r",stdin),freopen("count.out","w",stdout);
    n=read(),q=read(),thr=trunc(sqrt(n));
    for (int i=1;i<=n;i++) ns[i].col=c[i]=read(),ns[i].pid=i;
    for (int i=1,x,y;i<n;i++)
    {
        x=read(),y=read();
        insert(x,y),insert(y,x);
    }
    fa[1]=0,deep[1]=1,dfs(1),pre_rmq();
    for (int i=1;i<=q;i++)
    {
        for (int j=0;j<4;j++) qy[i][j]=read();
        int x=qy[i][0],y=qy[i][1],z=lca(x,y),u=qy[i][2],v=qy[i][3],w=lca(u,v);
        build(x,u,i,1),build(x,w,i,-1),build(x,v,i,1);
        if (fa[w]) build(x,fa[w],i,-1);
        if (fa[z])
        {
            build(fa[z],u,i,-1),build(fa[z],w,i,1),build(fa[z],v,i,-1);
            if (fa[w]) build(fa[z],fa[w],i,1);
        }
        build(z,u,i,-1),build(z,w,i,1),build(z,v,i,-1);
        if (fa[w]) build(z,fa[w],i,1);
        build(y,u,i,1),build(y,w,i,-1),build(y,v,i,1);
        if (fa[w]) build(y,fa[w],i,-1);
    }
    pre_color(),solve(1),process();
    for (int i=1;i<=q;i++) write(ans[i]),putchar('\n');
    fclose(stdin),fclose(stdout);
    return 0;
}
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值