In this problem, you need to write a segment tree that finds the segment with the maximum sum and supports point updates. The segment may be empty, so the answer is never negative.
code
#include <iostream>
#include <cctype>
#include <cstring>
#include <cstdio>
#include <vector>
#include <set>
#include <algorithm>
// Shorthand loop macro: runs i over 0 .. n-1 (unused by the code below).
#define forn(i, n) for (int i = 0; i < int(n); i++)
using namespace std;
// Shorthand for long long.
using ll = long long;
/// Reads the next decimal integer from stdin, skipping any non-digit
/// characters before it. A '-' seen among the skipped characters makes the
/// result negative. Returns 0 if end of input is reached before any digit
/// (previously this case spun forever).
/// NOTE: digits are accumulated in `int`; values outside int range overflow.
inline int read(){
    int x = 0, op = 1;
    // Keep the character as int: getchar() returns EOF (negative) at end of
    // input, and isdigit() is only defined for unsigned-char values and EOF.
    int ch = getchar();
    while (ch != EOF && !isdigit(ch)){
        if (ch == '-') op = -1;
        ch = getchar();
    }
    while (ch != EOF && isdigit(ch)){
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x * op;
}
/// Summary of one contiguous range of the array, as stored in each tree node.
/// The empty subsegment always counts as a candidate, so seg, pre and suf are
/// never below 0 for summaries produced by segtree.
struct item
{
    long long seg;  // best sum of any contiguous subsegment of the range
    long long pre;  // best sum of a prefix of the range
    long long suf;  // best sum of a suffix of the range
    long long sum;  // total sum of the range
};

/// Segment tree over positions 1..size answering "maximum-sum contiguous
/// subsegment" queries on a range, with point assignment.
struct segtree
{
    int size;             // number of leaves: smallest power of two >= n (at least 1)
    std::vector<item> v;  // node k has children 2k and 2k+1; the root is node 1
    // Summary of an empty range; merging it with any summary built by this
    // tree leaves that summary unchanged.
    const item NEUTURAL_ELEMNT = {0, 0, 0, 0};

    /// Combines the summaries of two adjacent ranges, `a` directly left of `b`.
    static item merge(item a, item b){
        return {
            // the best subsegment lies in a, in b, or straddles the boundary
            std::max(a.seg, std::max(b.seg, a.suf + b.pre)),
            std::max(a.pre, a.sum + b.pre),
            std::max(b.suf, a.suf + b.sum),
            a.sum + b.sum
        };
    }

    /// Summary of a one-element range holding `val`. A non-positive element is
    /// best skipped (empty subsegment), so only its sum is recorded.
    static item single(int val){
        if (val > 0) return {val, val, val, val};
        return {0, 0, 0, val};
    }

    /// Sizes the tree for `n` positions and resets every node to a zero summary
    /// (assign, unlike resize, also clears nodes left over from a previous init).
    void init(int n){
        size = 1;
        while (size < n) size <<= 1;
        v.assign(size << 1, item{});
    }

    /// Fills the subtree rooted at node `rt`, which covers positions [l, r],
    /// from a[l..r]. `a` is indexed from 1; leaves whose position is not a
    /// valid index of `a` are not written.
    void build(const std::vector<int> &a, int rt, int l, int r){
        if (l == r)
        {
            if (l < static_cast<int>(a.size())) v[rt] = single(a[l]);
            return;
        }
        int m = (l + r) >> 1;
        build(a, rt << 1, l, m);
        build(a, rt << 1 | 1, m + 1, r);
        v[rt] = merge(v[rt << 1], v[rt << 1 | 1]);
    }

    /// Builds the whole tree from the 1-indexed array `a` (a[0] is ignored).
    void build(const std::vector<int> &a){
        build(a, 1, 1, size);
    }

    /// Assigns `val` at position `pos` inside the subtree of node `rt` covering
    /// [l, r], then recomputes the summaries on the way back up.
    void set(int rt, int pos, int val, int l, int r){
        if (l == r)
        {
            v[rt] = single(val);
            return;
        }
        int m = (l + r) >> 1;
        if (pos <= m)
            set(rt << 1, pos, val, l, m);
        else
            set(rt << 1 | 1, pos, val, m + 1, r);
        v[rt] = merge(v[rt << 1], v[rt << 1 | 1]);
    }

    /// Assigns `val` at 1-based position `pos`. Positions outside [1, size] are
    /// ignored: descending with them would silently overwrite the first or
    /// last leaf instead.
    void set(int pos, int val){
        if (pos < 1 || pos > size) return;
        set(1, pos, val, 1, size);
    }

    /// Summary of the intersection of [lb, rb] with node `rt`, which covers
    /// [l, r]. A disjoint node contributes NEUTURAL_ELEMNT; this guard also makes
    /// empty (lb > rb) and out-of-range queries terminate instead of recursing
    /// below the leaves and indexing past the end of `v`.
    item calc(int rt, int l, int r, int lb, int rb){
        if (l > rb || r < lb) return NEUTURAL_ELEMNT;
        if (lb <= l && r <= rb) return v[rt];
        int m = (l + r) >> 1;
        item res = NEUTURAL_ELEMNT;
        // merge left part first so the ranges stay in left-to-right order
        if (m >= lb) res = merge(res, calc(rt << 1, l, m, lb, rb));
        if (m < rb) res = merge(res, calc(rt << 1 | 1, m + 1, r, lb, rb));
        return res;
    }

    /// Summary for positions [lb, rb] (1-based, inclusive). Parts of the range
    /// outside [1, size] contribute nothing; an empty range yields NEUTURAL_ELEMNT.
    item calc(int lb, int rb){
        return calc(1, 1, size, lb, rb);
    }
};
int main(int argc, char const *argv[])
{
    // Input layout: n and m, then n array values, then m updates "i w"
    // where i is a 0-based index.
    const int n = read();
    const int m = read();

    std::vector<int> a(n + 1);  // 1-indexed copy of the array; a[0] is unused
    for (int pos = 1; pos <= n; ++pos)
        a[pos] = read();

    segtree st;
    st.init(n);
    st.build(a);

    // One answer for the initial array, then one after each update.
    printf("%lld\n", st.calc(1, n).seg);
    for (int left = m; left != 0; --left)
    {
        const int idx = read();
        const int val = read();
        st.set(idx + 1, val);  // tree positions are 1-based
        printf("%lld\n", st.calc(1, n).seg);
    }
    return 0;
}