作为此次 NOIP 模拟的最后一道题,宫水三叶决定把题意说得简单一点。
给一棵大小为 n n n 的以 r t rt rt 为根的树。
有
m
m
m 组询问,每次询问
l
,
r
,
x
l,r,x
l,r,x,你要回答有多少
l
≤
a
<
b
≤
r
l \le a < b \le r
l≤a<b≤r,满足
a
,
b
a,b
a,b 的最近公共祖先为
x
x
x 。
第一行三个整数
n
,
m
,
r
t
n,m,rt
n,m,rt 。
接下来 n − 1 n-1 n−1 行,每行两个整数 x i , y i x_i,y_i xi,yi ,表示一条边。
接下来 m m m 行,每行三个整数 l i , r i , x i l_i,r_i,x_i li,ri,xi ,表示一组询问。
输出共 m m m 行,第 i i i 行表示第 i i i 个询问的答案。
样例输入 1
10 10 7
4 2
10 4
3 2
6 10
9 2
7 3
1 4
8 2
5 3
8 10 10
2 6 2
3 6 2
4 6 4
3 10 2
8 8 10
3 10 4
2 3 2
2 6 4
1 7 10
样例输出 1
0
2
0
1
7
0
2
0
1
0
样例数据 2
见下发文件。
本题采用捆绑测试。
对于所有测试点,满足 1 ≤ n , m ≤ 2 × 1 0 5 , 1 ≤ r t ≤ n 1\le n,m \le 2\times 10^5,1\le rt \le n 1≤n,m≤2×105,1≤rt≤n 。
子任务编号 | n , m n,m n,m | 分值 |
---|---|---|
1 1 1 | ≤ 200 \le 200 ≤200 | 5 5 5 |
2 2 2 | ≤ 2000 \le 2000 ≤2000 | 20 20 20 |
3 3 3 | ≤ 5 × 1 0 4 \le 5\times 10^4 ≤5×104 | 35 35 35 |
4 4 4 | ≤ 2 × 1 0 5 \le 2\times 10^5 ≤2×105 | 40 40 40 |
提示
本题时间限制为 2S
,请选手注意 IO
用时。
题解:
首先,多组询问区间,容易想到 莫队。
但
O
(
n
n
)
O(n\sqrt n)
O(nn) 的时间复杂度只能支持
O
(
1
)
O(1)
O(1) 修改。
但查询可以
O
(
n
)
O(\sqrt n)
O(n),这启发我们树链剖分。
但普通的重儿子太过平均,不如把子树大小前
O
(
n
)
O(\sqrt n)
O(n) 个全部划为重儿子。
这样就能维护了。
#include<bits/stdc++.h>
#define N 400005
typedef long long ll;
using namespace std;
inline int read(){
int x=0,f=1;char s=getchar();
while(s<'0'||s>'9'){if(s=='-')f=-1;s=getchar();}
while(s>='0'&&s<='9'){x=(x<<3)+(x<<1)+s-'0';s=getchar();}
return x*f;
}
int tot,head[N],ver[N<<1],nex[N<<1];
int dfn[N],dfs_num,sz[N],fa[N],st[N],ed[N];
inline void add(int x,int y){
nex[++tot]=head[x];head[x]=tot;ver[tot]=y;
}
void dfs(int x,int las){
dfn[x]=++dfs_num;
sz[x]=1;st[x]=dfs_num;
for(int i=head[x];i;i=nex[i]){
int y=ver[i];
if(y==las)continue;
fa[y]=x;
dfs(y,x);
sz[x]+=sz[y];
}
ed[x]=dfs_num;
}
int sq,i_he[N],sqm;
priority_queue<pair<int, int> > q;
vector<int> he[N];
ll sum[N],del[N];
struct node{
int l,r,x,id;
}que[N];
bool cmp3(node a,node b){return a.id<b.id;}
bool cmp(node a,node b){if(a.l/sq==b.l/sq)return a.r<b.r;return a.l<b.l;}
int nex_li[N],n;
ll li_del[N],li_sz[N];
void fin_li(int x,int las){
if(!i_he[x])nex_li[x]=x;
else nex_li[x]=nex_li[las];
for(int i=head[x];i;i=nex[i]){
int y=ver[i];
if(y==las)continue;
fin_li(y,x);
}
}
void work(int x,int val){
if(x==0)return ;
x=nex_li[x];
li_del[fa[x]]-=li_sz[x]*li_sz[x];
li_sz[x]+=val;
li_del[fa[x]]+=li_sz[x]*li_sz[x];
work(fa[x],val);
}
ll sum_ou[N],sum_in[N];
struct Node{
int x,id,l,val;
}op[N<<1];
bool cmp2(Node a,Node b){return a.l<b.l;}
vector<ll> p_sz[N];
inline void test(){
for(int i=1;i<=3;++i)cout<<sum_ou[i]<<" ";cout<<endl;
for(int i=1;i<=5;++i)cout<<sum_in[i]<<" ";cout<<endl;
}
inline ll ge(int x){//dfn
if(x==0)return 0;
ll ans=sum_ou[(x-1)/sq]+sum_in[x];
return ans;
}
inline void ch(int x){//i
x=dfn[x];
for(int i=(x-1)/sq+1;i<=(n-1)/sq+1;++i)sum_ou[i]++;
for(int i=x;i<=min(n,((x-1)/sq+1)*sq);++i)sum_in[i]++;
}
int main(){
// freopen("lca2.in","r",stdin);
// freopen("data.in","r",stdin);
// freopen("lca.out","w",stdout);
n=read();
int m=read(),rt=read();
sq=sqrt(n);
for(int i=1;i<n;++i){
int x=read(),y=read();add(x,y);add(y,x);
}
dfs(rt,0);
// cout<<sq<<endl;
for(int i=1;i<=n;++i){
//while(!q.empty())q.pop();
int cnt=0;
for(int j=head[i];j;j=nex[j]){
int y=ver[j];
if(y==fa[i])continue;
if(cnt<sq)q.push(make_pair(-sz[y],y)),++cnt;
else{
if(sz[y]>-(q.top().first))q.pop(),q.push(make_pair(-sz[y],y));
}
}
while(!q.empty()){
int x=q.top().second;
i_he[x]=1;he[i].push_back(x);q.pop();
}
}
for(int i=1;i<=m;++i){
que[i].l=read(),que[i].r=read(),que[i].x=read();que[i].id=i;
}
fin_li(rt,0);
int l=0,r=0;
sort(que+1,que+1+m,cmp);
for(int i=1;i<=m;++i){
int L=que[i].l,R=que[i].r,x=que[i].x;
while(l<L)work(l++,-1);
while(l>L)work(--l,1);
while(r>R)work(r--,-1);
while(r<R)work(++r,1);
//sum[que[i].id]+=li_sz[x];
del[que[i].id]+=li_del[x];
}
for(int i=1;i<=m;++i){
op[i].x=que[i].x,op[i].id=que[i].id,op[i].l=que[i].l-1,op[i].val=-1;
op[i+m].x=que[i].x,op[i+m].id=que[i].id,op[i+m].l=que[i].r,op[i+m].val=1;
}
sort(op+1,op+1+2*m,cmp2);
//cout<<sum[1]<<" "<<del[1]<<endl;
// cout<<st[10]<<" "<<ed[10]<<endl;
for(int i=1,j=1;i<=n;++i){
ch(i);
//cout<<dfn[i]<<":"<<endl;test();
while(j<=2*m&&op[j].l<=i){
//if(op[j].l<i){++j;continue;}
int x=op[j].x,id=op[j].id,val=op[j].val;
//cout<<id<<" "<<x<<" "<<val<<" "<<op[j].l<<endl;
if(op[j].l==i)sum[id]+=val*(ge(ed[x])-ge(st[x]-1));
//cout<<sum[id]<<endl;
//cout<<st[x]<<" "<<ed[x]<<" "<<ge(ed[x])-ge(st[x]-1)<<endl;
for(int k=0;k<he[x].size();++k){
int y=he[x][k];
ll now=ge(ed[y])-ge(st[y]-1);
if(val==-1){if(op[j].l==i)p_sz[id].push_back(-now);else p_sz[id].push_back(0);}
else p_sz[id][k]+=now,del[id]+=p_sz[id][k]*p_sz[id][k];
//sum[id]+=val*now;
}
++j;
}
}
//cout<<sum[1]<<" "<<del[1]<<endl;
sort(que+1,que+1+m,cmp3);
for(int i=1;i<=m;++i){
ll ans=0;
//if(que[i].x>=que[i].l&&que[i].x<=que[i].r)ans=sum[i]*(sum[i]-1);
/*else*/ ans=sum[i]*sum[i];
printf("%lld\n",(ans-del[i])/2);
}
return 0;
}
/*
5 1 1
1 2
1 3
1 4
1 5
1 2 1
*/