测试地址:Regions
题目大意: 给定一棵
n
n
n个点的有根树,每个点有颜色,
q
q
q个询问,每次询问给出两个颜色
a
,
b
a,b
a,b,表示询问树中有多少对点
(
u
,
v
)
(u,v)
(u,v)使得
u
u
u颜色是
a
a
a,
v
v
v颜色是
b
b
b,且
u
u
u是
v
v
v的祖先。
做法: 本题需要用到分块+vector。
好题。传统数据结构貌似做不了这道题,于是想到分块。
很容易想到两种暴力:
1.将
a
a
a相同的询问一起处理,对所有点统计合法的祖先数目计算。
2.将
b
b
b相同的询问一起处理,对所有点统计合法的子孙数目计算。
上面两种算法的总时间复杂度都是
O
(
r
n
)
O(rn)
O(rn)的,无法通过此题。但我们想到一个常用的分块思路:设定阈值分类讨论。于是在这里,我们对
b
b
b这种颜色的出现次数
B
B
B进行讨论。
当
B
>
n
B>\sqrt n
B>n时,满足这种条件的
b
b
b只有
n
\sqrt n
n种,于是采用上面第二种暴力,时间复杂度
O
(
n
n
)
O(n\sqrt n)
O(nn)。
当
B
≤
n
B\le \sqrt n
B≤n时,每个询问涉及到的
b
b
b颜色的点最多有
n
\sqrt n
n个,所以我们可以在每种颜色上挂一个vector(在每个点上挂空间会爆炸),存储
b
b
b为这种颜色的询问,然后一次DFS,DFS的同时维护从根到当前点的路径上各颜色点出现的次数,当走到一个
B
≤
n
B\le \sqrt n
B≤n的点时,对这种颜色所影响的询问进行更新。我们发现这差不多就是第一种暴力的思路,只不过我们同时处理了所有的
a
a
a。因为有
q
q
q个询问,每个询问被
n
\sqrt n
n个点影响,所以时间复杂度为
O
(
q
n
)
O(q\sqrt n)
O(qn)。
于是我们就解决了这一题,时间复杂度为
O
(
(
n
+
q
)
n
)
O((n+q)\sqrt n)
O((n+q)n)。
以下是本人代码:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
int n,R,q,r[200010],first[200010]={0},tot=0,blocksiz;
int opx[200010],opy[200010];
ll ans[200010]={0},cnt[200010]={0},sum[25010]={0};
vector<int> pos[25010],posq[25010];
struct edge
{
int v,next;
}e[200010];
void insert(int a,int b)
{
e[++tot].v=b;
e[tot].next=first[a];
first[a]=tot;
}
void dfs1(int v)
{
int siz=posq[r[v]].size();
if (pos[r[v]].size()<=blocksiz)
{
for(int i=0;i<siz;i++)
ans[posq[r[v]][i]]+=cnt[opx[posq[r[v]][i]]];
}
cnt[r[v]]++;
for(int i=first[v];i;i=e[i].next)
dfs1(e[i].v);
cnt[r[v]]--;
}
void dfs2(int v)
{
ll now=cnt[v];
for(int i=first[v];i;i=e[i].next)
{
dfs2(e[i].v);
cnt[v]+=cnt[e[i].v];
}
sum[r[v]]+=cnt[v]-now;
}
int main()
{
scanf("%d%d%d",&n,&R,&q);
scanf("%d",&r[1]);
pos[r[1]].push_back(1);
for(int i=2;i<=n;i++)
{
int fa;
scanf("%d%d",&fa,&r[i]);
pos[r[i]].push_back(i);
insert(fa,i);
}
blocksiz=(int)sqrt(n+1);
for(int i=1;i<=q;i++)
{
scanf("%d%d",&opx[i],&opy[i]);
posq[opy[i]].push_back(i);
}
dfs1(1);
for(int i=1;i<=R;i++)
if (pos[i].size()>blocksiz)
{
memset(cnt,0,sizeof(cnt));
memset(sum,0,sizeof(sum));
int siz=pos[i].size();
for(int j=0;j<siz;j++)
cnt[pos[i][j]]++;
dfs2(1);
siz=posq[i].size();
for(int j=0;j<siz;j++)
ans[posq[i][j]]=sum[opx[posq[i][j]]];
}
for(int i=1;i<=q;i++)
printf("%lld\n",ans[i]);
return 0;
}