vjudge题面传送门:https://cn.vjudge.net/problem/CodeChef-BTREE
题目分析:sro wjmzbmr
这是道码农神题。首先考虑简化版的问题:如果给出一个点x,再给出一个距离d,如何求出距离x不超过d的点的个数?这可以用点分治解决。先用点分治预处理出每个连通块的所有点到其分治中心mid的深度数组f。f[mid][son][dep]=num表示分治中心mid的儿子son的子树中,深度为dep的点有num个。然后对f[mid][son]做前缀和,并令f[mid][0][x]为f[mid][son][x]的总和。查询x的答案的时候,先查以x为中心的连通块的答案,再逐个往上一级分治中心查与x不在同一个子树的点的答案,这个可以调用f数组差分求得。注意查询时深度要减去x与当前分治中心的距离。很明显f的总大小是 nlog(n) n log ( n ) 的,可以开一个二维vector存下。这样时间复杂度为预处理 O(nlog(n)) O ( n log ( n ) ) ,单次询问 O(log(n)) O ( log ( n ) ) 。
本题中每次询问给出的是一个点集,并且限制了点集总大小,很容易想到建虚树。建出虚树后可以在上面做两次DFS,算出虚树上每个点能延伸多远。但DFS完后直接对每个点进行查询会算重。如何减去重复部分呢?
考虑虚树上一条边(u,v),其中u是v的父亲。如果u和v的覆盖部分有交,可以在原树的链(u,v)上二分mid,使得v跨过mid比u跨过mid延伸得远(或一样远)。由于每一条边的边权都是1,所以要么u和v跨过mid后都能往外延伸len;要么u能延伸len,而v能延伸len+1。对于第一种情况,直接将答案减去mid往外延伸len后的结果即可。第二种情况比较难处理,但我们发现len和len+1只相差1。对此,一开始建树的时候,可以在每条边上增设一个虚点,并令其权重为0,使其不影响答案,然后每次把读入的r乘以2。这样第二种情况,mid就会出现在某个虚点处。
一开始我和tututu讨论的时候,没有注意到边权为1,但貌似也能做。如果边权不为1,就要将vector改成set,时间复杂度要多一个 log(n) log ( n ) 。并且由于不能用单纯地建虚点的方法,需要用主席树维护某个点子树中深度小于等于某个值的点的个数,然后用一些容斥+分类讨论解决。
一开始计划2h码完,结果用了将近3h,主要是虚树那个部分写得比较慢。最后因为建虚树之前没有重置内存池的指针RE了一发(我已经试过好多次因为这个问题RE了QAQ)。
CODE:
#include<iostream>
#include<string>
#include<cstring>
#include<cmath>
#include<cstdio>
#include<cstdlib>
#include<stdio.h>
#include<algorithm>
#include<vector>
using namespace std;
const int maxn=100100;
const int maxl=22;
struct edge
{
int obj,len;
edge *Next;
} e[maxn<<1];
edge *head[maxn];
int cur=-1,cnt;
vector < vector <int> > sum[maxn];
int Mid[maxn][maxl];
int Id[maxn][maxl];
int D[maxn][maxl];
int num[maxn];
int Size[maxn];
int max_Size[maxn];
bool vis[maxn];
int tree[maxn];
int Ts;
int up[maxn][maxl];
int dep[maxn];
int dfn[maxn];
int Time=0;
int Fa[maxn];
int Node[maxn];
int pn;
int sak[maxn];
int tail;
int dis[maxn];
int n,q,k;
void Add(int x,int y,int z)
{
cur++;
e[cur].obj=y;
e[cur].len=z;
e[cur].Next=head[x];
head[x]=e+cur;
}
void Dfs(int node)
{
dfn[node]=++Time;
for (edge *p=head[node]; p; p=p->Next)
{
int son=p->obj;
if (son==up[node][0]) continue;
up[son][0]=node;
dep[son]=dep[node]+1;
Dfs(son);
}
}
void Find(int node,int fa)
{
tree[++Ts]=node;
Size[node]=1;
max_Size[node]=0;
for (edge *p=head[node]; p; p=p->Next)
{
int son=p->obj;
if ( son==fa || vis[son] ) continue;
Find(son,node);
Size[node]+=Size[son];
max_Size[node]=max(max_Size[node],Size[son]);
}
}
void Work(int node,int fa,int Dep,int id,int mid)
{
num[node]++;
Mid[node][ num[node] ]=mid;
Id[node][ num[node] ]=id;
D[node][ num[node] ]=Dep;
if (node<=n)
{
while (sum[mid][id].size()<=Dep) sum[mid][id].push_back(0);
sum[mid][id][Dep]++;
while (sum[mid][0].size()<=Dep) sum[mid][0].push_back(0);
sum[mid][0][Dep]++;
}
for (edge *p=head[node]; p; p=p->Next)
{
int son=p->obj;
if ( son==fa || vis[son] ) continue;
Work(son,node,Dep+1,id,mid);
}
}
void Solve(int node)
{
Ts=0;
Find(node,node);
if (Ts==1)
{
sum[node].push_back( vector <int> () );
sum[node][0].push_back(0);
if (node<=n) sum[node][0][0]++;
return;
}
int root=tree[1];
for (int i=1; i<=Ts; i++)
{
int x=tree[i];
max_Size[x]=max(max_Size[x],Size[node]-Size[x]);
if (max_Size[x]<max_Size[root]) root=x;
}
int x=0;
sum[root].push_back( vector <int> () );
sum[root][0].push_back(0);
if (root<=n) sum[root][0][0]++;
for (edge *p=head[root]; p; p=p->Next)
{
int son=p->obj;
if (vis[son]) continue;
x++;
sum[root].push_back( vector <int> () );
Work(son,root,1,x,root);
}
for (int i=0; i<=x; i++)
for (int j=1; j<sum[root][i].size(); j++)
sum[root][i][j]+=sum[root][i][j-1];
vis[root]=true;
for (edge *p=head[root]; p; p=p->Next)
{
int son=p->obj;
if (!vis[son]) Solve(son);
}
}
bool Comp(int x,int y)
{
return (dfn[x]<dfn[y]);
}
int Lca(int x,int y)
{
if (dep[x]<dep[y]) swap(x,y);
for (int j=maxl-1; j>=0; j--)
if (dep[ up[x][j] ]>=dep[y]) x=up[x][j];
if (x==y) return x;
for (int j=maxl-1; j>=0; j--)
if (up[x][j]!=up[y][j]) x=up[x][j],y=up[y][j];
return up[x][0];
}
void Build()
{
sort(Node+1,Node+k+1,Comp);
tail=1;
sak[tail]=Node[1];
pn=k;
for (int i=2; i<=k; i++)
{
int x=Node[i],p=Lca(x,sak[tail]);
int Last=0;
while (dep[ sak[tail] ]>dep[p]) Last=sak[tail--];
if (sak[tail]!=p) Fa[p]=sak[tail],Node[++pn]=sak[++tail]=p;
if (Last) Fa[Last]=p;
Fa[x]=p;
sak[++tail]=x;
}
Fa[ sak[1] ]=0;
for (int i=1; i<=pn; i++)
{
int x=Node[i];
if (Fa[x]) Add(Fa[x],x,-dep[ Fa[x] ]+dep[x]);
}
}
void Calc1(int node)
{
for (edge *p=head[node]; p; p=p->Next)
{
int son=p->obj;
Calc1(son);
dis[node]=max(dis[node],dis[son]-p->len);
}
}
void Calc2(int node)
{
for (edge *p=head[node]; p; p=p->Next)
{
int son=p->obj;
dis[son]=max(dis[son],dis[node]-p->len);
Calc2(son);
}
}
int Ask(int node,int r)
{
int tu=sum[node][0].size();
int x=min(r,tu-1);
int ans=0;
if (x>=0) ans+=sum[node][0][x];
for (int i=1; i<=num[node]; i++)
{
int mid=Mid[node][i];
int id=Id[node][i];
int d=D[node][i];
tu=sum[mid][0].size();
x=min(r-d,tu-1);
if (x>=0) ans+=sum[mid][0][x];
tu=sum[mid][id].size();
x=min(r-d,tu-1);
if (x>=0) ans-=sum[mid][id][x];
}
return ans;
}
int Jump(int x,int y)
{
if (dis[x]+dis[y]<dep[y]-dep[x]) return -1;
int z=y;
for (int j=maxl-1; j>=0; j--)
{
int w=up[z][j];
if ( dep[w]>=dep[x] &&
dis[y]-(dep[y]-dep[w])>=dis[x]-(dep[w]-dep[x]) ) z=w;
}
return z;
}
int main()
{
freopen("tree.in","r",stdin);
freopen("tree.out","w",stdout);
scanf("%d",&n);
for (int i=1; i<(n<<1); i++) head[i]=NULL;
cnt=n;
for (int i=1; i<n; i++)
{
int x,y;
scanf("%d%d",&x,&y);
cnt++;
Add(x,cnt,1);
Add(cnt,x,1);
Add(y,cnt,1);
Add(cnt,y,1);
}
dep[1]=1;
Dfs(1);
for (int j=1; j<maxl; j++)
for (int i=1; i<=cnt; i++)
up[i][j]=up[ up[i][j-1] ][j-1];
Solve(1);
for (int i=1; i<=cnt; i++) vis[i]=false,head[i]=NULL;
//for (int i=1; i<=cnt; i++) printf("%d\n",sum[i].size());
scanf("%d",&q);
while (q--)
{
cur=-1; //!!!
scanf("%d",&k);
for (int i=1; i<=k; i++)
scanf("%d",&Node[i]),
scanf("%d",&dis[ Node[i] ]),
vis[ Node[i] ]=true,
dis[ Node[i] ]<<=1;
Build();
for (int i=1; i<=pn; i++)
if (!vis[ Node[i] ]) dis[ Node[i] ]=-1;
int root=sak[1];
Calc1(root);
Calc2(root);
int ans=0;
for (int i=1; i<=pn; i++)
{
int x=Node[i],p=Fa[x];
if (dis[x]>=0) ans+=Ask(x,dis[x]);
if ( dis[x]>=0 && p && dis[p]>=0 )
{
int mid=Jump(p,x);
if (mid!=-1) ans-=Ask(mid,dis[x]-(dep[x]-dep[mid]));
}
}
printf("%d\n",ans);
for (int i=1; i<=pn; i++) vis[ Node[i] ]=false,head[ Node[i] ]=NULL;
}
return 0;
}