CodeChef Union on Tree (虚树+点分治)

vjudge题面传送门:https://cn.vjudge.net/problem/CodeChef-BTREE


题目分析:sro wjmzbmr

这是道码农神题。首先考虑简化版的问题:如果给出一个点x,再给出一个距离d,如何求出距离x不超过d的点的个数?这可以用点分治解决。先用点分治预处理出每个连通块的所有点到其分治中心mid的深度数组f。f[mid][son][dep]=num表示分治中心mid的儿子son的子树中,深度为dep的点有num个。然后对f[mid][son]做前缀和,并令f[mid][0][x]为f[mid][son][x]的总和。查询x的答案的时候,先查以x为中心的连通块的答案,再逐个往上一级分治中心查与x不在同一个子树的点的答案,这个可以调用f数组差分求得。注意查询时深度要减去x与当前分治中心的距离。很明显f的总大小是 nlog(n) n log ⁡ ( n ) 的,可以开一个二维vector存下。这样时间复杂度为预处理 O(nlog(n)) O ( n log ⁡ ( n ) ) ,单次询问 O(log(n)) O ( log ⁡ ( n ) )

本题中每次询问给出的是一个点集,并且限制了点集总大小,很容易想到建虚树。建出虚树后可以在上面做两次DFS,算出虚树上每个点能延伸多远。但DFS完后直接对每个点进行查询会算重。如何减去重复部分呢?

考虑虚树上一条边(u,v),其中u是v的父亲。如果u和v的覆盖部分有交,可以在原树的链(u,v)上二分mid,使得v跨过mid比u跨过mid延伸得远(或一样远)。由于每一条边的边权都是1,所以要么u和v跨过mid后都能往外延伸len;要么u能延伸len,而v能延伸len+1。对于第一种情况,直接将答案减去mid往外延伸len后的结果即可。第二种情况比较难处理,但我们发现len和len+1只相差1。对此,一开始建树的时候,可以在每条边上增设一个虚点,并令其权重为0,使其不影响答案,然后每次把读入的r乘以2。这样第二种情况,mid就会出现在某个虚点处。

一开始我和tututu讨论的时候,没有注意到边权为1,但貌似也能做。如果边权不为1,就要将vector改成set,时间复杂度要多一个 log(n) log ⁡ ( n ) 。并且由于不能用单纯地建虚点的方法,需要用主席树维护某个点子树中深度小于等于某个值的点的个数,然后用一些容斥+分类讨论解决。

一开始计划2h码完,结果用了将近3h,主要是虚树那个部分写得比较慢。最后因为建虚树之前没有重置内存池的指针RE了一发(我已经试过好多次因为这个问题RE了QAQ)。


CODE:

#include<iostream>
#include<string>
#include<cstring>
#include<cmath>
#include<cstdio>
#include<cstdlib>
#include<stdio.h>
#include<algorithm>
#include<vector>
using namespace std;

const int maxn=100100;
const int maxl=22;

struct edge
{
    int obj,len;
    edge *Next;
} e[maxn<<1];
edge *head[maxn];
int cur=-1,cnt;

vector < vector <int> > sum[maxn];

int Mid[maxn][maxl];
int Id[maxn][maxl];
int D[maxn][maxl];
int num[maxn];

int Size[maxn];
int max_Size[maxn];

bool vis[maxn];
int tree[maxn];
int Ts;

int up[maxn][maxl];
int dep[maxn];
int dfn[maxn];
int Time=0;

int Fa[maxn];
int Node[maxn];
int pn;

int sak[maxn];
int tail;

int dis[maxn];
int n,q,k;

void Add(int x,int y,int z)
{
    cur++;
    e[cur].obj=y;
    e[cur].len=z;
    e[cur].Next=head[x];
    head[x]=e+cur;
}

void Dfs(int node)
{
    dfn[node]=++Time;
    for (edge *p=head[node]; p; p=p->Next)
    {
        int son=p->obj;
        if (son==up[node][0]) continue;

        up[son][0]=node;
        dep[son]=dep[node]+1;
        Dfs(son);
    }
}

void Find(int node,int fa)
{
    tree[++Ts]=node;
    Size[node]=1;
    max_Size[node]=0;
    for (edge *p=head[node]; p; p=p->Next)
    {
        int son=p->obj;
        if ( son==fa || vis[son] ) continue;

        Find(son,node);
        Size[node]+=Size[son];
        max_Size[node]=max(max_Size[node],Size[son]);
    }
}

void Work(int node,int fa,int Dep,int id,int mid)
{
    num[node]++;
    Mid[node][ num[node] ]=mid;
    Id[node][ num[node] ]=id;
    D[node][ num[node] ]=Dep;

    if (node<=n)
    {
        while (sum[mid][id].size()<=Dep) sum[mid][id].push_back(0);
        sum[mid][id][Dep]++;
        while (sum[mid][0].size()<=Dep) sum[mid][0].push_back(0);
        sum[mid][0][Dep]++;
    }

    for (edge *p=head[node]; p; p=p->Next)
    {
        int son=p->obj;
        if ( son==fa || vis[son] ) continue;
        Work(son,node,Dep+1,id,mid);
    }
}

void Solve(int node)
{
    Ts=0;
    Find(node,node);
    if (Ts==1)
    {
        sum[node].push_back( vector <int> () );
        sum[node][0].push_back(0);
        if (node<=n) sum[node][0][0]++;
        return;
    }

    int root=tree[1];
    for (int i=1; i<=Ts; i++)
    {
        int x=tree[i];
        max_Size[x]=max(max_Size[x],Size[node]-Size[x]);
        if (max_Size[x]<max_Size[root]) root=x;
    }

    int x=0;
    sum[root].push_back( vector <int> () );
    sum[root][0].push_back(0);
    if (root<=n) sum[root][0][0]++;
    for (edge *p=head[root]; p; p=p->Next)
    {
        int son=p->obj;
        if (vis[son]) continue;
        x++;
        sum[root].push_back( vector <int> () );
        Work(son,root,1,x,root);
    }

    for (int i=0; i<=x; i++)
        for (int j=1; j<sum[root][i].size(); j++)
            sum[root][i][j]+=sum[root][i][j-1];

    vis[root]=true;
    for (edge *p=head[root]; p; p=p->Next)
    {
        int son=p->obj;
        if (!vis[son]) Solve(son);
    }
}

bool Comp(int x,int y)
{
    return (dfn[x]<dfn[y]);
}

int Lca(int x,int y)
{
    if (dep[x]<dep[y]) swap(x,y);
    for (int j=maxl-1; j>=0; j--)
        if (dep[ up[x][j] ]>=dep[y]) x=up[x][j];
    if (x==y) return x;
    for (int j=maxl-1; j>=0; j--)
        if (up[x][j]!=up[y][j]) x=up[x][j],y=up[y][j];
    return up[x][0];
}

void Build()
{
    sort(Node+1,Node+k+1,Comp);
    tail=1;
    sak[tail]=Node[1];
    pn=k;
    for (int i=2; i<=k; i++)
    {
        int x=Node[i],p=Lca(x,sak[tail]);
        int Last=0;
        while (dep[ sak[tail] ]>dep[p]) Last=sak[tail--];
        if (sak[tail]!=p) Fa[p]=sak[tail],Node[++pn]=sak[++tail]=p;
        if (Last) Fa[Last]=p;
        Fa[x]=p;
        sak[++tail]=x;
    }
    Fa[ sak[1] ]=0;
    for (int i=1; i<=pn; i++)
    {
        int x=Node[i];
        if (Fa[x]) Add(Fa[x],x,-dep[ Fa[x] ]+dep[x]);
    }
}

void Calc1(int node)
{
    for (edge *p=head[node]; p; p=p->Next)
    {
        int son=p->obj;
        Calc1(son);
        dis[node]=max(dis[node],dis[son]-p->len);
    }
}

void Calc2(int node)
{
    for (edge *p=head[node]; p; p=p->Next)
    {
        int son=p->obj;
        dis[son]=max(dis[son],dis[node]-p->len);
        Calc2(son);
    }
}

int Ask(int node,int r)
{
    int tu=sum[node][0].size();
    int x=min(r,tu-1);
    int ans=0;
    if (x>=0) ans+=sum[node][0][x];
    for (int i=1; i<=num[node]; i++)
    {
        int mid=Mid[node][i];
        int id=Id[node][i];
        int d=D[node][i];
        tu=sum[mid][0].size();
        x=min(r-d,tu-1);
        if (x>=0) ans+=sum[mid][0][x];
        tu=sum[mid][id].size();
        x=min(r-d,tu-1);
        if (x>=0) ans-=sum[mid][id][x];
    }
    return ans;
}

int Jump(int x,int y)
{
    if (dis[x]+dis[y]<dep[y]-dep[x]) return -1;
    int z=y;
    for (int j=maxl-1; j>=0; j--)
    {
        int w=up[z][j];
        if ( dep[w]>=dep[x] &&
             dis[y]-(dep[y]-dep[w])>=dis[x]-(dep[w]-dep[x]) ) z=w;
    }
    return z;
}

int main()
{
    freopen("tree.in","r",stdin);
    freopen("tree.out","w",stdout);

    scanf("%d",&n);
    for (int i=1; i<(n<<1); i++) head[i]=NULL;
    cnt=n;
    for (int i=1; i<n; i++)
    {
        int x,y;
        scanf("%d%d",&x,&y);
        cnt++;
        Add(x,cnt,1);
        Add(cnt,x,1);
        Add(y,cnt,1);
        Add(cnt,y,1);
    }

    dep[1]=1;
    Dfs(1);

    for (int j=1; j<maxl; j++)
        for (int i=1; i<=cnt; i++)
            up[i][j]=up[ up[i][j-1] ][j-1];

    Solve(1);
    for (int i=1; i<=cnt; i++) vis[i]=false,head[i]=NULL;

    //for (int i=1; i<=cnt; i++) printf("%d\n",sum[i].size());

    scanf("%d",&q);
    while (q--)
    {
        cur=-1; //!!!
        scanf("%d",&k);
        for (int i=1; i<=k; i++)
            scanf("%d",&Node[i]),
            scanf("%d",&dis[ Node[i] ]),
            vis[ Node[i] ]=true,
            dis[ Node[i] ]<<=1;

        Build();
        for (int i=1; i<=pn; i++)
            if (!vis[ Node[i] ]) dis[ Node[i] ]=-1;

        int root=sak[1];
        Calc1(root);
        Calc2(root);

        int ans=0;
        for (int i=1; i<=pn; i++)
        {
            int x=Node[i],p=Fa[x];
            if (dis[x]>=0) ans+=Ask(x,dis[x]);
            if ( dis[x]>=0 && p && dis[p]>=0 )
            {
                int mid=Jump(p,x);
                if (mid!=-1) ans-=Ask(mid,dis[x]-(dep[x]-dep[mid]));
            }
        }
        printf("%d\n",ans);

        for (int i=1; i<=pn; i++) vis[ Node[i] ]=false,head[ Node[i] ]=NULL;
    }

    return 0;
}
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值