题面
题意
给出一个长度为n的序列a,将它的所有子区间的gcd计算出来并存入数组b(长度为 n ∗ ( n + 1 ) / 2 n*(n+1)/2 n∗(n+1)/2),并将b区间排序,然后将数组b的所有子区间的和计算出来存入数组c,问数组c的中位数是多少。
做法
首先因为数组a中的所有数都小于等于100000,因此gcd的数字种类也在这个范围内,对数组a可以用倍增求出每个数字作为gcd的次数。
然后考虑二分答案,可以只要统计出数组b中小于等于mid的区间和的个数即可,若数组b的长度很小,则可以直接尺取。可是现在数组b长度很大,但是一共只有至多100000种数字,因此仍然可以尺取。
首先考虑种数字内部的贡献,这部分比较容易计算。
然后考虑左端点为l,右端点为r的贡献,若
∑
i
=
l
r
c
n
t
[
i
]
∗
i
<
=
m
i
d
\sum_{i=l}^{r}cnt[i]*i<=mid
∑i=lrcnt[i]∗i<=mid,则贡献显然为
c
n
t
[
l
]
∗
c
n
t
[
r
]
cnt[l]*cnt[r]
cnt[l]∗cnt[r],用前缀和维护一下即可,比较难处理的是
∑
i
=
l
r
c
n
t
[
i
]
∗
i
>
m
i
d
\sum_{i=l}^{r}cnt[i]*i>mid
∑i=lrcnt[i]∗i>mid的情况。
这种情况的答案可以看作是
∑
0
<
a
<
=
c
n
t
[
l
]
,
0
<
b
<
=
c
n
t
[
r
]
[
a
∗
l
+
b
∗
r
<
=
m
i
d
−
∑
i
=
l
+
1
r
−
1
c
n
t
[
i
]
∗
i
]
\sum_{0<a<=cnt[l],0<b<=cnt[r]}{[a*l+b*r<=mid-\sum_{i=l+1}^{r-1}cnt[i]*i]}
∑0<a<=cnt[l],0<b<=cnt[r][a∗l+b∗r<=mid−∑i=l+1r−1cnt[i]∗i]
右边是一个常数,这里计作c,令
S
(
a
,
b
,
c
)
S(a,b,c)
S(a,b,c)为
a
x
+
b
y
<
=
c
ax+by<=c
ax+by<=c的非负整数解的数量。
则这部分贡献经过容斥之后可以看作是
S
(
l
,
r
,
c
−
l
−
r
)
−
S
(
l
,
r
,
c
−
l
∗
(
c
n
t
[
l
]
+
1
)
−
r
)
−
S
(
l
,
r
,
c
−
l
−
r
∗
(
c
n
t
[
r
]
+
1
)
)
+
S
(
l
,
r
,
c
−
l
∗
(
c
n
t
[
l
]
+
1
)
−
r
∗
(
c
n
t
[
r
]
+
1
)
)
S(l,r,c-l-r)-S(l,r,c-l*(cnt[l]+1)-r)-S(l,r,c-l-r*(cnt[r]+1))+S(l,r,c-l*(cnt[l]+1)-r*(cnt[r]+1))
S(l,r,c−l−r)−S(l,r,c−l∗(cnt[l]+1)−r)−S(l,r,c−l−r∗(cnt[r]+1))+S(l,r,c−l∗(cnt[l]+1)−r∗(cnt[r]+1))
现在考虑计算
S
(
a
,
b
,
c
)
S(a,b,c)
S(a,b,c)
可以发现
S
(
a
,
b
,
c
)
=
c
/
a
+
1
+
f
(
a
,
c
%
a
,
b
,
c
/
a
)
S(a,b,c)=c/a+1+f(a,c \% a,b,c/a)
S(a,b,c)=c/a+1+f(a,c%a,b,c/a)
f
(
a
,
b
,
c
,
n
)
=
∑
i
=
0
n
(
a
∗
i
+
b
)
/
c
f(a,b,c,n)=\sum_{i=0}^{n}{(a*i+b)/c}
f(a,b,c,n)=∑i=0n(a∗i+b)/c
f
f
f函数可以直接用类欧几里得算法求。
代码
#include<bits/stdc++.h>
#define ll long long
#define LG 16
#define N 100100
#define MN 100000
using namespace std;
ll n,m,num[N],cnt[N],qzc[N],qzs[N];
namespace Get
{
ll g[N][20];
inline ll gcd(ll u,ll v)
{
for(;u&&v&&u!=v;)
{
swap(u,v);
u%=v;
}
return max(u,v);
}
void work()
{
ll i,j,k,t,l;
for(i=1;i<=n;i++) g[i][0]=num[i];
for(i=1;i<=LG;i++)
{
for(j=1;j+(1 << (i-1))<=n;j++)
{
g[j][i]=gcd(g[j][i-1],g[j+(1 << (i-1))][i-1]);
}
}
for(i=1;i<=n;i++)
{
for(j=l=i,t=num[i];j<=n;)
{
t=gcd(t,num[j]);
for(k=LG;k>=0;k--)
{
if((j+(1 << k))>n+1) continue;
if(g[j][k]%t==0)
{
j+=(1 << k);
}
}
cnt[t]+=j-l;
l=j;
}
}
}
}
ll f(ll a,ll b,ll c,ll n)
{
if(n<0) return 0;
if(!a) return b/c*(n+1);
if(a>=c || b>=c) return f(a%c,b%c,c,n)+a/c*n*(n+1)/2+b/c*(n+1);
ll m=(a*n+b)/c;
return m*n-f(c,c-b-1,a,m-1);
}
inline ll ask(ll a,ll b,ll c)
{
if(c<0) return 0;
return c/a+1+f(a,c%a,b,c/a);
}
inline ll solve(ll a,ll ca,ll b,ll cb,ll c)
{
if(a*(ca-1)+b*(cb-1)<=c) return ca*cb;
if(c<0) return 0;
return ask(a,b,c)-ask(a,b,c-a*ca)-ask(a,b,c-b*cb)+ask(a,b,c-a*ca-b*cb);
}
inline ll calc(ll u)
{
ll i,j,l,r,t,res=0;
for(i=1;i<=MN;i++)
{
if(!cnt[i]) continue;
t=min(cnt[i],u/i);
res+=cnt[i]*t-t*(t-1)/2;
}
for(l=r=1;l<MN;l++)
{
if(!cnt[l]) continue;
if(l<r) res+=cnt[l]*(qzc[r-1]-qzc[l]);
for(;r<=MN&&qzs[r]-qzs[l]<=u;r++)
{
if(l!=r && cnt[r])
{
res+=solve(l,cnt[l],r,cnt[r],u-(qzs[r-1]-qzs[l])-l-r);
}
}
if(r<=MN)
{
if(l!=r && cnt[r])
{
res+=solve(l,cnt[l],r,cnt[r],u-(qzs[r-1]-qzs[l])-l-r);
}
}
}
return res;
}
int main()
{
ll i,j,l,r,mid;
cin>>n;
for(i=1;i<=n;i++)
{
scanf("%lld",&num[i]);
}
Get::work();
m=(n+1)*n/2;
m=(m+1)*m/2;
m=(m+1)/2;
for(i=1;i<=MN;i++)
{
qzc[i]=qzc[i-1]+cnt[i];
qzs[i]=qzs[i-1]+cnt[i]*i;
}
for(l=1,r=qzs[MN]+1;l<r;)
{
mid=((l+r)>>1);
if(calc(mid)<m) l=mid+1;
else r=mid;
}
cout<<l;
}