将递归反着做,每次相当于选择长度 ≥ L ′ \ge L' ≥L′的连续段 x x x删除,替换成 x + 1 x+1 x+1。
对于区间 [ l : r ] [l:r] [l:r]可能属于多个级别,考虑将其定义为最小的级别,即区间最大值 + 1 +1 +1。
那么我们考虑怎么统计答案。记 [ l : r ] [l:r] [l:r]的最大值为 k k k,那么当 < k <k <k的段都被合并后,若 [ l : r ] [l:r] [l:r]对应的串的长度 ≥ L ′ \ge L' ≥L′,那么对答案会有贡献。
如果暴力做我们就用 [ l , r , i ] [l,r,i] [l,r,i]表示区间 [ l : r ] [l:r] [l:r]的连续段 i i i,每次把 i i i对应的区间并起来,再合并到 i + 1 i+1 i+1去,注意这里我们要用 set \text{set} set存储,我们不加证明的指出复杂度为 O ( n log n ) O(n\log n) O(nlogn)。
接下来的做法我是想不到的 我们考虑将问题进行泛化,每个点有
L
i
,
R
i
L_i,R_i
Li,Ri,合法区间
[
l
:
r
]
[l:r]
[l:r]的贡献是
L
l
R
r
L_lR_r
LlRr。
如何将泛化后的问题应用到原序列上?注意到:
1.1
1.1
1.1 对于一个连续段,贡献是
∑
r
−
l
≥
L
L
l
R
r
\sum_{r-l\ge L}L_lR_r
∑r−l≥LLlRr,可以
O
(
l
e
n
)
O(len)
O(len)计算
1.2
1.2
1.2 如果我们将长度为
m
m
m的
k
k
k合并成
⌊
m
L
′
⌋
\lfloor\frac{m}{L'}\rfloor
⌊L′m⌋个
k
+
1
k+1
k+1,相当于一个点代表原序列的一段,那么如果
[
l
,
r
]
⊆
[
L
,
R
]
[l,r]\subseteq[L,R]
[l,r]⊆[L,R],这样的代替是正确的,因为
[
L
,
R
]
[L,R]
[L,R]包含了将
[
l
,
r
]
[l,r]
[l,r]合并时的结果。如果
[
L
,
R
]
⊆
[
l
,
r
]
[L,R]\subseteq [l,r]
[L,R]⊆[l,r]那么我们已经在上一步计算过方案了,我们只需考虑一个端点落在
[
l
,
r
]
[l,r]
[l,r]的情况。假设
L
∈
[
l
,
r
]
L\in [l,r]
L∈[l,r],那么首先会合并
[
L
,
r
]
[L,r]
[L,r],此时会生成
⌊
r
−
L
+
1
L
′
⌋
\lfloor\frac{r-L+1}{L'}\rfloor
⌊L′r−L+1⌋段,我们找到对应位置累加即可。
复杂度 O ( n log n ) O(n\log n) O(nlogn)。实现比较复杂,可以考虑用链表删除元素,用堆每次找到最小的那一段。
#include<bits/stdc++.h>
#define ll long long
#define inf 0x3f3f3f3f
#define pb push_back
#define pii pair<int,int>
#define fi first
#define se second
using namespace std;
int n,K,a[200005],L[200005],R[200005],nxt[200005];
ll res;
priority_queue<pii,vector<pii>,greater<pii>>q;
struct node{
int L,R,x;
};
vector<node>v,ve;
ll calc(vector<node>&v){
ll tot(0),sum(0);
for(int i=0;i<v.size();i++){
tot+=v[i].R*sum;
if(i-K+2>=0)sum+=v[i-K+2].L;
}return tot;
}
signed main(){
cin>>n>>K;for(int i=1;i<=n;i++)cin>>a[i],L[i]=R[i]=1,nxt[i]=i+1,q.push({a[i],i});nxt[n]=0;
if(K==1){cout<<(ll)n*(n+1)/2;return 0;}
while(q.size()){
int l=q.top().se,k=q.top().fi;q.pop();v.clear(),ve.clear();
if(a[l]!=k)continue;
int r=l;v.pb({L[r],R[r],r});
while(nxt[r]&&a[nxt[r]]==k)r=nxt[r],v.pb({L[r],R[r],r});
if(v.size()<K){
for(auto x:v)a[x.x]=0;
continue;
}
res+=calc(v);int sz=v.size()/K;
ve.resize(sz);
for(int i=0;i<v.size();i++){
int tl=(v.size()-i)/K;
if(tl)ve[sz-tl].L+=v[i].L;
int tr=(i+1)/K;
if(tr)ve[tr-1].R+=v[i].R;
}res-=calc(ve);
for(int i=0;i<sz;i++){
L[v[i].x]=ve[i].L,R[v[i].x]=ve[i].R,a[v[i].x]++;
}
for(int i=sz;i<v.size();i++)a[v[i].x]=0;
nxt[v[sz-1].x]=nxt[r];q.push({k+1,l});
}res+=n;
cout<<res;
}