# CodeChef BTREE — Union on Tree (虚树 + 点分治, i.e. virtual tree + centroid decomposition)

Problem statement on vjudge (vjudge题面传送门): https://cn.vjudge.net/problem/CodeChef-BTREE

CODE:

#include<iostream>
#include<string>
#include<cstring>
#include<cmath>
#include<cstdio>
#include<cstdlib>
#include<stdio.h>
#include<algorithm>
#include<vector>
using namespace std;

// Capacity for vertices of the subdivided tree: main() numbers the original
// vertices 1..n and then adds one midpoint vertex per edge (up to cnt).
const int maxn=100100;
// Binary-lifting levels for up[][], and capacity of the per-vertex
// centroid-ancestor records Mid/Id/D (indexed 1..num[v]).
const int maxl=22;

// Adjacency-list edge stored in the static pool e[].
struct edge
{
	int obj,len;   // target vertex and edge length
	edge *Next;    // next edge in the same vertex's list
} e[maxn<<1];
// Head of every vertex's adjacency list. main() and all traversals use it,
// but its declaration was missing, so the program could not compile.
edge *head[maxn<<1];
int cur=-1,cnt;   // cur: last used slot of e[]; cnt: vertex count of the subdivided tree

// sum[c][0][d]: number of original vertices (index <= n) within distance <= d of
// centroid c inside c's component (prefix sums are taken in Solve);
// sum[c][i][d]: the same, restricted to the i-th subtree hanging off c.
vector < vector <int> > sum[maxn];

// For vertex v, record j (1..num[v]) describes one centroid ancestor:
// Mid = that centroid, Id = which of its subtrees contains v, D = distance to it.
int Mid[maxn][maxl];
int Id[maxn][maxl];
int D[maxn][maxl];
int num[maxn];

// Subtree size and largest hanging-subtree size, used to pick a centroid.
int Size[maxn];
int max_Size[maxn];

// During decomposition: vertex already used as a centroid.
// During queries: vertex is one of the query's key vertices.
bool vis[maxn];
int tree[maxn];   // vertices of the component currently being decomposed
int Ts;           // number of entries in tree[]

int up[maxn][maxl];   // up[v][j]: 2^j-th ancestor of v (0 above the root)
int dep[maxn];        // depth in the subdivided tree (root has depth 1)
int dfn[maxn];        // DFS entry time
int Time=0;

int Fa[maxn];     // parent in the current virtual tree (0 for its root)
int Node[maxn];   // vertices of the current virtual tree, 1..pn
int pn;

int sak[maxn];    // stack used while building the virtual tree
int tail;

int dis[maxn];    // per virtual-tree vertex: remaining coverage radius (doubled units)
int n,q,k;

{
cur++;
e[cur].obj=y;
e[cur].len=z;
}

void Dfs(int node)
{
dfn[node]=++Time;
{
int son=p->obj;
if (son==up[node][0]) continue;

up[son][0]=node;
dep[son]=dep[node]+1;
Dfs(son);
}
}

void Find(int node,int fa)
{
tree[++Ts]=node;
Size[node]=1;
max_Size[node]=0;
{
int son=p->obj;
if ( son==fa || vis[son] ) continue;

Find(son,node);
Size[node]+=Size[son];
max_Size[node]=max(max_Size[node],Size[son]);
}
}

void Work(int node,int fa,int Dep,int id,int mid)
{
num[node]++;
Mid[node][ num[node] ]=mid;
Id[node][ num[node] ]=id;
D[node][ num[node] ]=Dep;

if (node<=n)
{
while (sum[mid][id].size()<=Dep) sum[mid][id].push_back(0);
sum[mid][id][Dep]++;
while (sum[mid][0].size()<=Dep) sum[mid][0].push_back(0);
sum[mid][0][Dep]++;
}

{
int son=p->obj;
if ( son==fa || vis[son] ) continue;
Work(son,node,Dep+1,id,mid);
}
}

void Solve(int node)
{
Ts=0;
Find(node,node);
if (Ts==1)
{
sum[node].push_back( vector <int> () );
sum[node][0].push_back(0);
if (node<=n) sum[node][0][0]++;
return;
}

int root=tree[1];
for (int i=1; i<=Ts; i++)
{
int x=tree[i];
max_Size[x]=max(max_Size[x],Size[node]-Size[x]);
if (max_Size[x]<max_Size[root]) root=x;
}

int x=0;
sum[root].push_back( vector <int> () );
sum[root][0].push_back(0);
if (root<=n) sum[root][0][0]++;
{
int son=p->obj;
if (vis[son]) continue;
x++;
sum[root].push_back( vector <int> () );
Work(son,root,1,x,root);
}

for (int i=0; i<=x; i++)
for (int j=1; j<sum[root][i].size(); j++)
sum[root][i][j]+=sum[root][i][j-1];

vis[root]=true;
{
int son=p->obj;
if (!vis[son]) Solve(son);
}
}

// Sort predicate: orders vertices by their DFS entry time (dfn).
bool Comp(int x,int y)
{
	const int lhs=dfn[x];
	const int rhs=dfn[y];
	return lhs<rhs;
}

// Lowest common ancestor of x and y by binary lifting over up[][],
// where up[v][j] is the 2^j-th ancestor of v and 0 above the root.
int Lca(int x,int y)
{
	int a=x, b=y;
	if (dep[a]<dep[b]) swap(a,b);

	// raise the deeper vertex to the depth of the shallower one
	for (int lv=maxl-1; lv>=0; --lv)
	{
		const int anc=up[a][lv];
		if (dep[anc]>=dep[b]) a=anc;
	}
	if (a==b) return a;

	// raise both while their ancestors differ; they end just below the LCA
	for (int lv=maxl-1; lv>=0; --lv)
	{
		if (up[a][lv]==up[b][lv]) continue;
		a=up[a][lv];
		b=up[b][lv];
	}
	return up[a][0];
}

void Build()
{
sort(Node+1,Node+k+1,Comp);
tail=1;
sak[tail]=Node[1];
pn=k;
for (int i=2; i<=k; i++)
{
int x=Node[i],p=Lca(x,sak[tail]);
int Last=0;
while (dep[ sak[tail] ]>dep[p]) Last=sak[tail--];
if (sak[tail]!=p) Fa[p]=sak[tail],Node[++pn]=sak[++tail]=p;
if (Last) Fa[Last]=p;
Fa[x]=p;
sak[++tail]=x;
}
Fa[ sak[1] ]=0;
for (int i=1; i<=pn; i++)
{
int x=Node[i];
}
}

void Calc1(int node)
{
{
int son=p->obj;
Calc1(son);
dis[node]=max(dis[node],dis[son]-p->len);
}
}

void Calc2(int node)
{
{
int son=p->obj;
dis[son]=max(dis[son],dis[node]-p->len);
Calc2(son);
}
}

{
int tu=sum[node][0].size();
int x=min(r,tu-1);
int ans=0;
if (x>=0) ans+=sum[node][0][x];
for (int i=1; i<=num[node]; i++)
{
int mid=Mid[node][i];
int id=Id[node][i];
int d=D[node][i];
tu=sum[mid][0].size();
x=min(r-d,tu-1);
if (x>=0) ans+=sum[mid][0][x];
tu=sum[mid][id].size();
x=min(r-d,tu-1);
if (x>=0) ans-=sum[mid][id][x];
}
return ans;
}

// For a virtual-tree edge x (ancestor) -> y (descendant):
// - returns -1 if dis[x]+dis[y] is less than their depth difference;
// - otherwise returns the shallowest vertex z on the path from y up to x at
//   which y's remaining radius, dis[y]-(dep[y]-dep[z]), is at least x's,
//   dis[x]-(dep[z]-dep[x]).
// The predicate only gets easier going deeper, so greedy binary lifting from y
// finds that shallowest vertex.
int Jump(int x,int y)
{
	const int gap=dep[y]-dep[x];
	if (dis[x]+dis[y]<gap) return -1;

	int z=y;
	for (int j=maxl-1; j>=0; --j)
	{
		const int w=up[z][j];
		if (dep[w]<dep[x]) continue;   // never rise above x
		const int fromBelow=dis[y]-(dep[y]-dep[w]);
		const int fromAbove=dis[x]-(dep[w]-dep[x]);
		if (fromBelow>=fromAbove) z=w;
	}
	return z;
}

int main()
{
freopen("tree.in","r",stdin);
freopen("tree.out","w",stdout);

scanf("%d",&n);
for (int i=1; i<(n<<1); i++) head[i]=NULL;
cnt=n;
for (int i=1; i<n; i++)
{
int x,y;
scanf("%d%d",&x,&y);
cnt++;
}

dep[1]=1;
Dfs(1);

for (int j=1; j<maxl; j++)
for (int i=1; i<=cnt; i++)
up[i][j]=up[ up[i][j-1] ][j-1];

Solve(1);
for (int i=1; i<=cnt; i++) vis[i]=false,head[i]=NULL;

//for (int i=1; i<=cnt; i++) printf("%d\n",sum[i].size());

scanf("%d",&q);
while (q--)
{
cur=-1; //!!!
scanf("%d",&k);
for (int i=1; i<=k; i++)
scanf("%d",&Node[i]),
scanf("%d",&dis[ Node[i] ]),
vis[ Node[i] ]=true,
dis[ Node[i] ]<<=1;

Build();
for (int i=1; i<=pn; i++)
if (!vis[ Node[i] ]) dis[ Node[i] ]=-1;

int root=sak[1];
Calc1(root);
Calc2(root);

int ans=0;
for (int i=1; i<=pn; i++)
{
int x=Node[i],p=Fa[x];
if ( dis[x]>=0 && p && dis[p]>=0 )
{
int mid=Jump(p,x);
}
}
printf("%d\n",ans);

for (int i=1; i<=pn; i++) vis[ Node[i] ]=false,head[ Node[i] ]=NULL;
}

return 0;
}

©️2019 CSDN 皮肤主题: 技术黑板 设计师: CSDN官方博客