题意:
给你一个长度为
n
n
n的序列
a
i
a_i
ai,
1
≤
i
≤
n
1\leq i\leq n
1≤i≤n,和
q
q
q组询问,每组询问读入
l
1
,
r
1
,
l
2
,
r
2
l_1,r_1,l_2,r_2
l1,r1,l2,r2,需输出
∑
x
=
0
∞
get
(
l
1
,
r
1
,
x
)
⋅
get
(
l
2
,
r
2
,
x
)
\sum\limits_{x=0}^\infty \text{get}(l_1,r_1,x)\cdot \text{get}(l_2,r_2,x)
x=0∑∞get(l1,r1,x)⋅get(l2,r2,x)
get
(
l
,
r
,
x
)
\text{get}(l,r,x)
get(l,r,x)表示计算区间
[
l
,
r
]
[l,r]
[l,r]中,数字
x
x
x出现了多少次。
题解:
题目要求的东西我们似乎并不能直接求。那么我们就要想办法转化一下,我们先尝试化一下式子。我们不难想到,
g
e
t
(
l
,
r
,
x
)
=
g
e
t
(
1
,
r
,
x
)
−
g
e
t
(
1
,
l
−
1
,
x
)
get(l,r,x)=get(1,r,x)-get(1,l-1,x)
get(l,r,x)=get(1,r,x)−get(1,l−1,x)。我们设
g
e
t
(
1
,
r
,
x
)
get(1,r,x)
get(1,r,x)为
g
(
r
,
x
)
g(r,x)
g(r,x),那么我们带入原式就是
∑
x
=
0
∞
(
g
(
r
1
,
x
)
−
g
(
l
1
−
1
,
x
)
)
⋅
(
g
(
r
2
,
x
)
−
g
(
l
2
−
1
,
x
)
)
\sum\limits_{x=0}^\infty (g(r_1,x)-g(l_1-1,x))\cdot (g(r_2,x)-g(l_2-1,x))
x=0∑∞(g(r1,x)−g(l1−1,x))⋅(g(r2,x)−g(l2−1,x))
=
∑
x
=
0
∞
g
(
r
1
,
x
)
∗
g
(
r
2
,
x
)
−
g
(
r
1
,
x
)
∗
g
(
l
2
−
1
,
x
)
−
g
(
l
1
−
1
,
x
)
∗
g
(
r
2
−
1
,
x
)
+
g
(
l
1
−
1
,
x
)
∗
g
(
l
2
−
1
,
x
)
=\sum\limits_{x=0}^\infty g(r_1,x)*g(r_2,x)-g(r_1,x)*g(l_2-1,x)-g(l_1-1,x)*g(r_2-1,x)+g(l_1-1,x)*g(l_2-1,x)
=x=0∑∞g(r1,x)∗g(r2,x)−g(r1,x)∗g(l2−1,x)−g(l1−1,x)∗g(r2−1,x)+g(l1−1,x)∗g(l2−1,x) 虽然原式不能算,但是这个前缀的式子是可以用莫队来维护的。具体的,我们把一个询问拆成这四项,分别计算出答案并且乘上前面对应的系数。那么我们来讲一下怎么用莫队维护
∑
x
=
0
∞
g
(
l
,
x
)
∗
g
(
r
,
x
)
\sum\limits_{x=0}^\infty g(l,x)*g(r,x)
x=0∑∞g(l,x)∗g(r,x)。我们的方法是,用两个数组分别记录每个数在当前两个区间里各出现了多少次。加进来或者减去一个数的时候的变化量就是另一个数组有多少个相同的元素数。具体写法可以看看代码,思路上就是这样。
复杂度是 O ( n n ) O(n\sqrt{n}) O(nn)的。
代码:
#include <bits/stdc++.h>
using namespace std;
int n,m,a[50010],cnt,sz,pos[50010],l=0,r=0;
long long res,cntl[50010],cntr[50010],ans[50010];
struct node
{
int l,r,opt,id;
}q[300010];
inline int read()
{
int x=0;
char s=getchar();
while(s>'9'||s<'0')
s=getchar();
while(s>='0'&&s<='9')
{
x=x*10+s-'0';
s=getchar();
}
return x;
}
inline int cmp(node x,node y)
{
if(pos[x.l]!=pos[y.l])
return pos[x.l]<pos[y.l];
return x.r<y.r;
}
int main()
{
n=read();
sz=sqrt(n);
for(int i=1;i<=n;++i)
{
a[i]=read();
pos[i]=(i-1)/sz+1;
}
m=read();
for(int i=1;i<=m;++i)
{
int l1=read(),r1=read(),l2=read(),r2=read();
q[++cnt].l=r1;
q[cnt].r=r2;
q[cnt].id=i;
q[cnt].opt=1;
if(q[cnt].l>q[cnt].r)
swap(q[cnt].l,q[cnt].r);
if(l2-1)
{
q[++cnt].l=r1;
q[cnt].r=l2-1;
q[cnt].id=i;
q[cnt].opt=-1;
if(q[cnt].l>q[cnt].r)
swap(q[cnt].l,q[cnt].r);
}
if(l1-1)
{
q[++cnt].l=l1-1;
q[cnt].r=r2;
q[cnt].id=i;
q[cnt].opt=-1;
if(q[cnt].l>q[cnt].r)
swap(q[cnt].l,q[cnt].r);
}
if(l1-1&&l2-1)
{
q[++cnt].l=l1-1;
q[cnt].r=l2-1;
q[cnt].id=i;
q[cnt].opt=1;
if(q[cnt].l>q[cnt].r)
swap(q[cnt].l,q[cnt].r);
}
}
sort(q+1,q+cnt+1,cmp);
for(int i=1;i<=cnt;++i)
{
while(l<q[i].l)
{
++l;
res+=cntr[a[l]];
cntl[a[l]]++;
}
while(l>q[i].l)
{
res-=cntr[a[l]];
cntl[a[l]]--;
--l;
}
while(r<q[i].r)
{
++r;
res+=cntl[a[r]];
cntr[a[r]]++;
}
while(r>q[i].r)
{
res-=cntl[a[r]];
cntr[a[r]]--;
--r;
}
ans[q[i].id]+=res*q[i].opt;
}
for(int i=1;i<=m;++i)
printf("%lld\n",ans[i]);
return 0;
}