本文同步发布于个人博客
题意简述
有
n
n
n个点每个点有一个颜色
c
o
l
i
col_i
coli,并且对于每个点给出范围
l
i
l_i
li和
r
x
r_x
rx现在要知道对于每个点,集合
{
(
i
,
x
,
j
)
∣
i
<
x
<
j
,
l
x
≤
c
o
l
i
=
c
o
l
j
≤
r
x
}
\{(i,x,j)|i<x<j,l_x\leq col_i=col_j\leq r_x\}
{(i,x,j)∣i<x<j,lx≤coli=colj≤rx}有多少个元素
换句话说,就是要求每一个点两边有多少对颜色相同的点,并且这一对点的颜色要在
l
x
l_x
lx到
r
x
r_x
rx之间
(
1
≤
n
≤
5
⋅
1
0
5
,
1
≤
c
o
l
i
,
l
i
,
r
i
≤
5
⋅
1
0
5
)
(1\leq n\leq 5\cdot 10^5,1\leq col_i,l_i,r_i\leq 5\cdot 10^5)
(1≤n≤5⋅105,1≤coli,li,ri≤5⋅105)
题解
这道题的官方题解描述的比较简洁,所以我在理解上也出现了许多困扰,所以写一篇还算比较详细的解释吧
首先就是我们该如何维护,一开始我所想的是一开始先把所有的点插进去,然后扫描一遍,一开始都是右边集合,然后每次加一个数进左边集合,从右边集合删除,然后查询左边集合的颜色个数之后和右边相乘,但是这样只能去处理点两边有多少对它们的颜色在
l
i
l_i
li和
r
i
r_i
ri之间,因为你要每个颜色单独处理你就得开
c
o
l
col
col个树状数组,无论时间还是空间都很不现实,所以需要改变维护贡献的思路,下面就是这道题目的核心
换一种计算贡献方式,乘法变加法
举个例子,在这个点左边有两个颜色为
1
1
1,右边有三个颜色为
1
1
1的,那么贡献就是
2
⋅
3
2\cdot 3
2⋅3,但我们也可以这么去维护,左边的两个点分别对应了右边的3个点,即把贡献变成
3
+
3
3+3
3+3,我们考虑一下如何动态更新这种维护
首先每一次遇到一个点,那么这个颜色右边的集合就要减
1
1
1,假设我们现在左边集合已经有了
k
k
k个数,右边集合本来有
x
x
x个数,本来它们维护的贡献是
k
⋅
x
k\cdot x
k⋅x,现在右边少了一个点,贡献应该变成
k
⋅
(
x
−
1
)
k\cdot (x-1)
k⋅(x−1)即
k
⋅
x
−
k
k\cdot x -k
k⋅x−k,就是右边每去掉一个点左边原来老的贡献值就要减去左边集合点的个数
右边少掉一个数,同样的,左边就会多出一个数,那么左边集合就会多出一个点的贡献,即加上当前右边集合的个数
现在我们来考虑一下完整的流程,因为统计答案的那个点不被计算进贡献,算的是它左边和右边的贡献,所以每次先从右边集合删除一个数,然后统计当前这个数的答案,之后下一轮这个点就属于左边集合了,所以我们在计算答案后就把这个点更新进左集合,这个每个颜色我们其实只要去操纵一个点的增删就好了,然后我们把颜色整体用树状数组维护,每次查询只要用前缀和相减即
q
u
e
r
y
(
r
i
)
−
q
u
e
r
y
(
l
i
−
1
)
query(r_i)-query(l_i-1)
query(ri)−query(li−1)即可
最后过一下核心的代码
一开始读入数据,然后统计一下每个颜色有的数的个数
for(int i=1;i<=n;i++)
{
read(a[i]),read(l[i]),read(r[i]);
sum[a[i]]++;
}
然后开始扫描,每一次扫描到的 i i i就是当前要输出的答案,它的左边是左集合,右边是右集合,首先这个点单独拿出来,那么右集合就少了一个数,我们上面的sum数组就是动态维护右边集合的大小,然后右边减小本来左边的答案也要减少
sum[a[i]]--;
add(a[i],-cnt[a[i]]);
cnt如上文所述是当前左集合的大小,我们减去这些贡献,然后就可以输出答案
cout<<query(r[i])-query(l[i]-1)<<" ";
在进入下个循环统计下个点答案时把这个点加入左集合,首先左集合点个数cnt+1,然后这个颜色的贡献要加一个右集合的个数
cnt[a[i]]++;
add(a[i],sum[a[i]]);
这样就是一个完整的流程,如果有不清楚的地方可以对照代码模拟一遍样例,边画图边看文字解释应该就能理解怎么更新的了,下面是完整代码(去头文件)
AC代码
const int maxn=500050;
int n;
ll a[maxn],l[maxn],r[maxn],c[maxn];
ll sum[maxn],cnt[maxn];
ll lowbit(ll x){return x&-x;}
void add(int x,int p)
{
while(x<=n)
{
c[x]+=p;
x+=lowbit(x);
}
}
ll query(int x)
{
ll res=0;
while(x)
{
res+=c[x];
x-=lowbit(x);
}
return res;
}
int main()
{
cin>>n;
for(int i=1;i<=n;i++)
{
read(a[i]),read(l[i]),read(r[i]);
sum[a[i]]++;
}
for(int i=1;i<=n;i++)
{
sum[a[i]]--;
add(a[i],-cnt[a[i]]);
cout<<query(r[i])-query(l[i]-1)<<" ";
cnt[a[i]]++;
add(a[i],sum[a[i]]);
}
cout<<endl;
return 0;
}