Problem
Hint
对于 30% 的数据,n ≤ 3000;
对于另外 20% 的数据,数列 a 为随机生成;
对于 100% 的数据,1 ≤ n ≤ 3 × 10^5 , 1 ≤ k ≤ 10^6 , 1 ≤ ai ≤ 10^9。
Solution
- 考虑分治。对于区间[l,r],我们找出其中最大值的位置m,则可以计算出左端点在[l,m],右端点在[m,r]的合法区间数量,然后分治处理子区间[l,m-1],[m+1,r]。
- 对于找最大值,我们显然不能暴力枚。因为如果最大值一直在l(在r同理),则分治时一直将l++,那么区间长度即为n~1,因此,时间复杂度为 O(n2) O ( n 2 ) 级别。
- 注意到并不会修改ai,因此,可以考虑RMQ。可以用ST算法 O(nlog2n) O ( n l o g 2 n ) 预处理一下,然后对于每个区间, O(1) O ( 1 ) 找出最大值的位置m。
- 若m-l<r-m,我们可以暴枚左端点,再在右边查询合法的右端点的数量;否则枚举右端点,然后再左边查询合法的左端点的数量。
- 分析一下时间复杂度。若一个位置被枚举到了一次,则它所在的区间至少会缩减成原来的 12 1 2 ,因此一个位置最多被枚举 O(log2n) O ( l o g 2 n ) 次。
- 因此,暴枚左/右端点的复杂度为 O(nlog2n) O ( n l o g 2 n ) 。
- 先推一波式子。若区间[x,y]合法,当且仅当: sum[y]−sum[x−1]−am≡0(mod k) s u m [ y ] − s u m [ x − 1 ] − a m ≡ 0 ( m o d k ) 。
- 若我们枚举了一个左端点x,则合法的y必须满足:
m≤y≤r∧sum[y]≡sum[x−1]+am(mod k) m ≤ y ≤ r ∧ s u m [ y ] ≡ s u m [ x − 1 ] + a m ( m o d k )
- 若我们枚举了y,则x须满足:
l≤x≤m∧sum[x−1]=sum[y]−am(mod k) l ≤ x ≤ m ∧ s u m [ x − 1 ] = s u m [ y ] − a m ( m o d k )
- 于是,我们得到了
O(nlog2n)
O
(
n
l
o
g
2
n
)
个形如“x在区间[l,r]中出现了几次”的询问。
于是这道题变成了主席树模板题。可以离线处理。
- 对于每个形如“x在区间[l,r]中出现了几次”的询问,我们可以将其插到第r、第l-1个位置。
- 我们顺序扫一遍,维护一个桶cnt。若当前我们扫到i,则cnt[x]表示x在区间[1,i]出现的次数。
- 于是,我们扫到第l-1个位置时,查询一下,此时贡献为负;扫到第r个位置时,查询一下,此时贡献为正。
时间复杂度: O(nlog2n) O ( n l o g 2 n ) 。
Code
#include <bits/stdc++.h>
#define fo(i,a,b) for(i=a;i<=b;i++)
#define fd(i,a,b) for(i=a;i>=b;i--)
#define rep(i,x) for(query *i=x; i; i=i->ne)
using namespace std;
typedef long long ll;
const int N=6e5+1,inf=0x7FFFFFFF;
int i,j,n,k,a[N],sum[N],f[N][19],x,y,cnt[N<<1];
ll ans;
struct query
{
int x,v;
query *ne;
query(int x,int v,query *ne):x(x),v(v),ne(ne){}
}*fin[N];
inline void newquery(int l,int r,int x)
{
fin[r ]=new query(x, 1,fin[r ]);
if(l)fin[l-1]=new query(x,-1,fin[l-1]);
}
void solve(int l,int r)
{
if(l>r) return;
int l2=log2(r-l+1);
x=f[l][l2]; y=f[r-(1<<l2)+1][l2];
int m=( a[x]>a[y] ? x : y );
if(m-l<r-m)
fo(i,l,m) newquery(m,r,(sum[i-1]+a[m])%k);
else fo(i,m,r) newquery(l-1,m-1,(sum[i]-a[m]%k+k)%k);
solve(l,m-1); solve(m+1,r);
}
int main()
{
freopen("interval.in","r",stdin);
freopen("interval.out","w",stdout);
scanf("%d%d",&n,&k);
fo(i,1,n) scanf("%d",&a[i]), sum[i]=(sum[i-1]+a[i])%k, f[i][0]=i;
fo(j,1,18)
fo(i,1,n)
{
if(!f[x=i+(1<<j-1)][j-1]) break;
f[i][j]=( a[f[i][j-1]]>a[f[x][j-1]] ? f[i][j-1] : f[x][j-1]);
}
solve(1,n);
fo(x,0,n)
{
cnt[sum[x]]++;
rep(i,fin[x])
ans+=1ll*cnt[i->x]*i->v;
}
printf("%lld",ans-n);
}