题目链接
http://codeforces.com/problemset/problem/1037/F
题目大意
给出n, k (2≤k≤n≤106) ( 2 ≤ k ≤ n ≤ 10 6 ) , 和长度为n的数组a (0≤ai≤109) ( 0 ≤ a i ≤ 10 9 ) , 求函数值 z(a,n) z ( a , n ) mod 109+7 10 9 + 7 。
function z(array a, integer k):
if length(a) < k:
return 0
else:
b = empty array
ans = 0
for i = 0 .. (length(a) - k):
temp = a[i]
for j = i .. (i + k - 1):
temp = max(temp, a[j])
append temp to the end of b
ans = ans + temp
return ans + z(b, k)
题目思路
z
z
<script type="math/tex" id="MathJax-Element-5">z</script>函数的本质是每次将数组缩短k, a[i]变为以他开头的长度为k的子数组的最大值。
考虑每个a[i]对答案的贡献, 只需找出每个a[i]被计算的次数。
设a[i]的影响区间为[l,r], 即a[l]~a[r]均小于a[i], 在做max操作时会被a[i]覆盖, 则a[i]的贡献次数为长度为k, 2k-1, 3k-2, 4k-3… 的包含i且在[l, r]内的子区间个数。
设cal(l, r)为长度为k, 2k-1, 3k-2, 4k-3…在[l, r]的子区间个数, 则a[i]的贡献次数为cal(l, r) - cal(l, i-1) - cal(i + 1, r)。
对于一个长度为m的线段, 在[l,r]中的方案数为r - (l + m - 1) + 1, 对于cal(l, r)的计算用等差数列求和形式可以O(1)求得。
对于每个a[i]计算期影响区间[l,r]可以用单调栈, 扫两遍O(n)求得。
Code
#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cmath>
#include <queue>
#include <bitset>
#include <map>
#include <stack>
#include <set>
#define ls ch[x][0]
#define rs ch[x][1]
#define ll long long
#define pi pair<int, int>
#define mp make_pair
#define fi first
#define se second
using namespace std;
int gi(){
int ret = 0; char c = getchar();
while (!isdigit(c)) c = getchar();
while (isdigit(c)){
ret = ret * 10 + c - '0';
c = getchar();
}
return ret;
}
const int N = (int)1e6 + 10;
const int mo = (int)1e9 + 7;
int n, k, a[N], l[N], r[N];
stack<pi > S;
ll cal(int l, int r){
if (l > r) return 0;
if (r - l + 1 - k < 0) return 0;
int m = (r - l + 2 - k) / (k - 1);
ll ret = 1ll * (m + 1) * (r - l + 2 - k) % mo - 1ll * (1 + m) * m / 2 % mo * (k - 1) % mo;
return ret;
}
int main()
{
n = gi(); k = gi();
for (int i = 1; i <= n; i ++) a[i] = gi();
for (int i = 1; i <= n; i ++){
if (i == 1) S.push(mp(a[1], 1)), l[1] = 1;
else {
l[i] = i;
while (S.size() && S.top().fi < a[i]) l[i] = S.top().se, S.pop();
S.push(mp(a[i], l[i]));
}
}
while (S.size()) S.pop();
for (int i = n; i >= 1; i --){
if (i == n) S.push(mp(a[n], n)), r[n] = n;
else{
r[i] = i;
while (S.size() && S.top().fi <= a[i]) r[i] = S.top().se, S.pop();
S.push(mp(a[i], r[i]));
}
}
ll ans = 0;
for (int i = 1; i <= n; i ++){
ll res = cal(l[i], r[i]) - cal(l[i], i - 1) - cal(i + 1, r[i]);
res = res % mo * a[i] % mo;
(ans += res) %= mo;
}
if (ans < 0) ans += mo;
printf("%lld\n", ans);
return 0;
}