注意到每次可能多出来的就是 j + 1这个位置的元素,就维护一个堆(其实就是维护一堆数),每次决策取一个最大的c。可以用set实现,set最大值的堆顶是rbegin(),最小值的为begin()
#include<bits/stdc++.h>
typedef long long ll;
using namespace std;
int main(){
//std::ios::sync_with_stdio(false);
//std::cin.tie(nullptr);
ll n, k;
cin >> n >> k;
ll ans = 0;
vector<pair<ll, ll>> a(n + 1);
for (int i = 1; i <= n; i ++ ) {
cin >> a[i].first;
a[i].second = i;
}
set<pair<ll, ll>> se;
for (int i = 1; i <= k; i ++ ) {
se.insert(a[i]);
}
int j = k + 1;
while(se.size())
{
if(j <= n) {
se.insert(a[j]);
}
int pos = se.rbegin()->second; // 迭代器直接访问要用 ->
se.erase(a[pos]); // set删除详见下面注释
ans += (j * j - pos * pos) * a[pos].first;
j ++ ;
}
cout << ans << '\n';
return 0;
}
/*
st.erase( const T item); //prototype 1 删除这个值
or
st.erase(iterator position) //prototype 删除迭代器,注意是迭代器,不是下标
*/