题目链接如下:
先用dfs确定每个节点的序号编号,并且可以获得每个节点可以包括的子树节点区间范围,再用线段树建立一棵树。
在第一次建立的时候我们记录每个节点的深度,然后再进行一次dfs,这次dfs用来更新以不同节点为根时,距离的维护,利用子树距离减1,非子树距离加1的方法进行更新距离和,最大距离和最小距离。同时不断建立树来打上懒标记,这里打标记的顺序也是dfs的,意味着我们query的时候要进行标记的计算,因为子树并没有更新根节点打上的懒标记。
最后利用补集的方式求出good node 到其余节点的距离。
参考:
hdu 5756(主席树)_top628LJ的博客-CSDN博客
代码如下所示:
#include <iostream>
#include <algorithm>
#include <vector>
#include <cstring>
#include <cstdlib>
using namespace std;
#define inf 0x3f3f3f3f
#define MP make_pair
#define mid ((l+r)>>1)
#define pii pair<int,int>
const int MAXN = 100020;
const int MAXM = 500020;
int n,q;
// fst记录一个节点连接的第一条边 vv记录连接到的节点 nxt记录下一条边索引
// fst大小和节点数目一致
int fst[MAXN], vv[MAXM], nxt[MAXM], e;
// st记录节点dfs序的左端点 ed代表右端点 dep表示节点深度 lab表示左端点对应的节点号
int st[MAXN],ed[MAXN],dep[MAXN],lab[MAXN],dc;
// sum mx mi代表和 最大值 最小值 add表示懒标记 ch代表子树位置
int sum[MAXN*50],mx[MAXN*50],mi[MAXN*50],add[MAXN*50],ch[MAXN*50][2];
int rt[MAXN], tot;
pii a[MAXN*5];
void init(){
memset(fst, -1, sizeof(fst));
e=0;
}
void adde(int u,int v){
vv[e]=v;
nxt[e]=fst[u];
fst[u]=e++;
}
void dfs(int u, int p, int d){
dep[u]=d;
st[u]=++dc;
lab[dc]=u;
for (int i = fst[u]; ~i ; i = nxt[i]) {
int v = vv[i];
if (v == p) continue;
dfs(v,u,d+1);
}
ed[u]=dc;
}
void push_up(int l,int r,int rt){
int l_rt=ch[rt][0], r_rt=ch[rt][1];
sum[rt]=sum[l_rt]+sum[r_rt]+(r-l+1)*add[rt];
mx[rt]=max(mx[l_rt], mx[r_rt])+add[rt];
mi[rt]=min(mi[l_rt], mi[r_rt])+add[rt];
}
int build(int l,int r){
int k=++tot;
if (l==r){
sum[k]=mx[k]=mi[k]=dep[lab[l]]-1;
ch[k][0]=ch[k][1]=add[k]=0;
return k;
}
sum[k]=mx[k]=mi[k]=add[k]=0;
ch[k][0]=build(l,mid);
ch[k][1]=build(mid+1, r);
push_up(l,r,k);
return k;
}
int update(int ul, int ur, int val,int l,int r,int rt){
int k=++tot;
sum[k]=sum[rt], mx[k]=mx[rt], mi[k]=mi[rt], add[k]=add[rt];
ch[k][0]=ch[rt][0], ch[k][1]=ch[rt][1];
if (ul==l && ur==r){
sum[k]+=(r-l+1)*val;
mx[k]+=val;
mi[k]+=val;
add[k]+=val;
return k;
}
if (ur<=mid) {
ch[k][0]=update(ul,ur,val,l,mid,ch[rt][0]);
} else if (ul>mid) {
ch[k][1]=update(ul,ur,val,mid+1,r,ch[rt][1]);
} else {
ch[k][0]=update(ul,mid,val,l,mid,ch[rt][0]);
ch[k][1]=update(mid+1,ur,val,mid+1,r,ch[rt][1]);
}
push_up(l,r,k);
return k;
}
void dfs1(int u,int p){
for (int i = fst[u]; ~i ; i=nxt[i]) {
int v=vv[i];
if (v==p) continue;
rt[v]= update(st[v],ed[v],-1,1,n,rt[u]);
if (st[v]>1){
rt[v]= update(1,st[v]-1,1,1,n,rt[v]);
}
if (ed[v]<n){
rt[v]= update(ed[v]+1,n,1,1,n,rt[v]);
}
dfs1(v,u);
}
}
int query(int ul,int ur,int t,int l,int r,int rt){
if (ul==l && ur==r){
if(t == 1) return sum[rt];
else if(t == 2) return mi[rt];
else return mx[rt];
}
if (ur<=mid){
int ret = query(ul,ur,t,l,mid,ch[rt][0]);
if (t==1) ret+=(ur-ul+1)*add[rt];
else ret+=add[rt];
return ret;
}else if (ul>mid) {
int ret = query(ul,ur,t,mid+1,r,ch[rt][1]);
if (t==1) ret+=(ur-ul+1)*add[rt];
else ret+=add[rt];
return ret;
}else {
int ret1 = query(ul,mid,t,l,mid,ch[rt][0]);
int ret2 = query(mid+1,ur,t,mid+1,r,ch[rt][1]);
if (t==1) return ret1+ret2+(ur-ul+1)*add[rt];
else if (t==2) return min(ret1, ret2)+add[rt];
else return max(ret1,ret2)+add[rt];
}
}
int main(){
int u,v;
while (~scanf("%d %d", &n, &q)){
init();
for (int i = 1; i < n; ++i) {
scanf("%d %d", &u, &v);
adde(u,v);
adde(v,u);
}
dc=0;
dfs(1,-1,1);
tot=0;
rt[1]= build(1,n);
dfs1(1,-1);
int K,P,T,ans=0,x;
while (q--){
scanf("%d %d %d",&K,&P,&T);
// cout << "K, P, T:" << K << "," << P << "," << T <<endl;
P=(P+ans)%n+1;
bool ok= false;
for (int i = 1; i <= K; ++i) {
scanf("%d",&x);
a[i]=MP(st[x],ed[x]);
// cout<<"i:"<<i<<",x:"<<x<<",[st,ed]:"<<"["<<st[x]<<","<<ed[x]<<"]"<<endl;
if (x==1){
ok=true;
}
}
if (ok){
puts("-1"); ans = 0;
continue;
}
sort(a+1,a+1+K);
a[++K]=MP(n+1,n+1);
if (T==1) ans=0;
else if (T==2) ans=inf;
else ans=-inf;
int pre=1;
for (int i = 1; i <= K; ++i) {
if (pre<a[i].first){
int res = query(pre,a[i].first-1, T,1,n,rt[P]);
// cout<<"In for i:"<<i<<",[st,ed]:"<<"["<<a[i].first<<","<<a[i].second<<"]"<<endl;
// cout << "res:"<<res<<endl;
if (T==1) {
ans+= res;
}else if (T==2) {
ans= min(ans,res);
}else {
ans= max(ans,res);
}
}
pre=max(pre,a[i].second+1);
}
printf("%d\n", ans);
}
}
return 0;
}