Codeforces Round #851 (Div. 2) E. Sum Over Zero
Denote by $p$ the prefix sums of $a$, with $p_0 = 0$ and $p_i = a_1 + \dots + a_i$. The segment $[x+1, y]$ has sum $p_y - p_x$, so it can be an element of $S$ only if $p_x \le p_y$.
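A minimal sketch of this check, assuming $a_1..a_n$ are stored 0-based in a `std::vector<long long>`. The helper names `prefix_sums` and `segment_allowed` are hypothetical and only illustrate the condition $p_x \le p_y$.

```cpp
#include <bits/stdc++.h>
using namespace std;

// p[0] = 0 and p[i] = a_1 + ... + a_i, where a_1..a_n are stored in a[0..n-1].
vector<long long> prefix_sums(const vector<long long>& a) {
    vector<long long> p(a.size() + 1, 0);
    for (size_t i = 1; i <= a.size(); i++) p[i] = p[i - 1] + a[i - 1];
    return p;
}

// Segment [x+1, y] (1-based, x < y) has sum p[y] - p[x],
// so its sum is non-negative iff p[x] <= p[y].
bool segment_allowed(const vector<long long>& p, int x, int y) {
    return p[x] <= p[y];
}
```

For example, with $a = (1, -3, 2)$ the prefix sums are $(0, 1, -2, 0)$: the whole array $[1, 3]$ is allowed (sum $0$), while $[2, 3]$ is not (sum $-1$).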
Let $dp_i$ be the maximum total length of the segments of $S$ that are smaller than $i$, where a segment $[x, y]$ is smaller than $i$ if $y \le i$; in particular $dp_0 = 0$, and the answer is $dp_n$. If no segment ends at $i$, then $dp_i = dp_{i-1}$. If the segment $[k+1, i]$ is in $S$, then $dp_i = \max_{p_k \le p_i}(dp_k + i - k)$. Combining the two cases,
$$dp_i = \max\Big(dp_{i-1},\ \max_{0 \le k < i,\ p_k \le p_i} (dp_k + i - k)\Big).$$
Evaluating this recurrence directly gives an $O(N^2)$ solution.
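The quadratic recurrence can be sketched directly, which is also handy as a brute-force reference when testing the faster version. This assumes $a_1..a_n$ are stored 0-based in a vector; `solve_quadratic` is a hypothetical name.

```cpp
#include <bits/stdc++.h>
using namespace std;

// O(N^2) DP: dp[i] = max(dp[i-1], max over 0 <= k < i with p[k] <= p[i] of dp[k] + i - k).
// a_1..a_n are stored in a[0..n-1]; returns dp_n.
long long solve_quadratic(const vector<long long>& a) {
    int n = a.size();
    vector<long long> p(n + 1, 0);
    for (int i = 1; i <= n; i++) p[i] = p[i - 1] + a[i - 1];
    vector<long long> dp(n + 1, 0);  // dp[0] = 0: no segments yet
    for (int i = 1; i <= n; i++) {
        dp[i] = dp[i - 1];  // case 1: no segment ends at i
        for (int k = 0; k < i; k++)
            if (p[k] <= p[i])  // case 2: segment [k+1, i] has non-negative sum
                dp[i] = max(dp[i], dp[k] + i - k);
    }
    return dp[n];
}
```

For $a = (3, -3, -2, 5, -4)$ the prefix sums are $(0, 3, 0, -2, 3, -1)$ and this returns $4$ (for instance the single segment $[1, 4]$ with sum $3$).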
Now let's speed up the DP transition with a segment tree. First, apply coordinate compression to $\{p_i\}$: we only ever compare two prefix sums, so only their relative order matters. We then maintain a segment tree that stores $dp_k - k$ at position $p_k$ (keeping the maximum when several $k$ share the same prefix sum).
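A minimal sketch of the compression step, assuming ranks are 1-based (convenient for a Fenwick tree). `compress` is a hypothetical helper name. Removing duplicates with `std::unique` gives equal sums the same rank, so $p_k \le p_i$ exactly when $\mathrm{rank}_k \le \mathrm{rank}_i$.

```cpp
#include <bits/stdc++.h>
using namespace std;

// Replace each prefix sum by its 1-based rank among the distinct values.
// Equal sums get equal ranks, so p_k <= p_i iff rank_k <= rank_i.
vector<int> compress(const vector<long long>& p) {
    vector<long long> v = p;
    sort(v.begin(), v.end());
    v.erase(unique(v.begin(), v.end()), v.end());  // keep distinct values only
    vector<int> ranks(p.size());
    for (size_t i = 0; i < p.size(); i++)
        ranks[i] = int(lower_bound(v.begin(), v.end(), p[i]) - v.begin()) + 1;
    return ranks;
}
```

For $p = (0, 3, 0, -2, 3, -1)$ the distinct sorted values are $(-2, -1, 0, 3)$, so the ranks are $(3, 4, 3, 1, 4, 2)$.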
Compute $dp_i$ in increasing order of $i$:
$$dp_i = \max\Big(dp_{i-1},\ \max_{p_k \le p_i} (dp_k - k) + i\Big).$$
The inner term $\max_{p_k \le p_i}(dp_k - k)$ is a range-maximum query over $[0, p_i]$ (in compressed coordinates) on the segment tree. After computing $dp_i$, write $dp_i - i$ into position $p_i$, keeping the larger value if that position already holds one. So each $dp_i$ is computed in $O(\log N)$.
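The editorial describes a segment tree, so here is a sketch of the $O(N \log N)$ DP with an iterative (bottom-up) max segment tree over 0-based compressed positions. `MaxSegTree` and `solve_segtree` are hypothetical names; this is one way to implement the idea, not the editorial's own code.

```cpp
#include <bits/stdc++.h>
using namespace std;

// Iterative max segment tree over positions 0..m-1 (leaves at t[m..2m-1]).
struct MaxSegTree {
    int m;
    vector<long long> t;
    explicit MaxSegTree(int m) : m(m), t(2 * m, LLONG_MIN) {}
    // t[pos] = max(t[pos], v), then refresh the ancestors.
    void chmax(int pos, long long v) {
        pos += m;
        t[pos] = max(t[pos], v);
        for (pos >>= 1; pos >= 1; pos >>= 1)
            t[pos] = max(t[2 * pos], t[2 * pos + 1]);
    }
    // Maximum over positions [l, r); LLONG_MIN if nothing was inserted there.
    long long query(int l, int r) const {
        long long res = LLONG_MIN;
        for (l += m, r += m; l < r; l >>= 1, r >>= 1) {
            if (l & 1) res = max(res, t[l++]);
            if (r & 1) res = max(res, t[--r]);
        }
        return res;
    }
};

// a_1..a_n are stored in a[0..n-1]; returns dp_n.
long long solve_segtree(const vector<long long>& a) {
    int n = a.size();
    vector<long long> p(n + 1, 0);
    for (int i = 1; i <= n; i++) p[i] = p[i - 1] + a[i - 1];
    // 0-based ranks of the distinct prefix sums
    vector<long long> v = p;
    sort(v.begin(), v.end());
    v.erase(unique(v.begin(), v.end()), v.end());
    auto rank_of = [&](long long x) {
        return int(lower_bound(v.begin(), v.end(), x) - v.begin());
    };
    MaxSegTree st((int)v.size());
    st.chmax(rank_of(p[0]), 0);  // k = 0 contributes dp_0 - 0 = 0
    long long dp = 0;            // holds dp_{i-1} at the start of iteration i
    for (int i = 1; i <= n; i++) {
        int r = rank_of(p[i]);
        long long best = st.query(0, r + 1);  // max of dp_k - k over p_k <= p_i
        if (best != LLONG_MIN) dp = max(dp, best + i);
        st.chmax(r, dp - i);  // dp now equals dp_i
    }
    return dp;
}
```

The bottom-up layout works for any size `m` because `max` is commutative; a Fenwick tree is a lighter alternative here, since every query is a prefix.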
The entire problem is therefore solved in $O(N \log N)$.
There is also an alternative solution that maintains the pairs $(dp_k - k, p_k)$ monotonically in a set. This solution also runs in $O(N \log N)$.
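The editorial does not spell out the set-based version, so the following is only one possible realization, using a `std::map` keyed by $p_k$ instead of a set of pairs. It keeps just the non-dominated pairs, so both $p_k$ and $dp_k - k$ strictly increase along the map. The best value among keys $\le p_i$ is then stored at the largest such key. `solve_monotone_map` is a hypothetical name.

```cpp
#include <bits/stdc++.h>
using namespace std;

// Keeps pairs (p_k, dp_k - k) with both coordinates strictly increasing, so the
// maximum of dp_k - k over p_k <= p_i is stored at the largest key <= p_i.
// a_1..a_n are stored in a[0..n-1]; returns dp_n.
long long solve_monotone_map(const vector<long long>& a) {
    int n = a.size();
    map<long long, long long> f;  // prefix sum p_k -> dp_k - k
    f[0] = 0;                     // k = 0
    long long p = 0, dp = 0;      // p = p_i, dp = dp_{i-1} before the update
    for (int i = 1; i <= n; i++) {
        p += a[i - 1];
        auto it = f.upper_bound(p);  // first key > p_i
        if (it != f.begin())         // some key <= p_i exists
            dp = max(dp, prev(it)->second + i);
        long long val = dp - i;
        // Skip the insert if a key <= p_i already has a value >= val (dominated).
        it = f.upper_bound(p);
        if (it != f.begin() && prev(it)->second >= val) continue;
        // Erase entries with key >= p_i whose value <= val; they are now dominated.
        it = f.lower_bound(p);
        while (it != f.end() && it->second <= val) it = f.erase(it);
        f[p] = val;
    }
    return dp;
}
```

Each pair is inserted and erased at most once, so the total work stays $O(N \log N)$.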
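#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 200010;
// Fenwick tree (BIT) for prefix maxima over compressed prefix sums. It replaces
// the segment tree above, since every query is a prefix [1, x].
// tr[i] holds the maximum of dp_k - k over the positions it covers.
int tr[N];
int lowbit(int x)
{
    return x&-x;
}
// Point update at 1-based position x: keep the maximum with v
void add(int x,int v)
{
    for(int i=x;i<N;i+=lowbit(i))
        tr[i]=max(tr[i],v);
}
// Maximum of dp_k - k over positions 1..x, i.e. over all k with p_k <= p_i
int query(int x)
{
    int res=-0x3f3f3f3f;
    for(int i=x;i>0;i-=lowbit(i))
        res=max(res,tr[i]);
    return res;
}
int main()
{
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int n;cin>>n;
    vector<int> a(n+1);
    for(int i=1;i<=n;i++) cin>>a[i];
    // s[i] = p_i = a[1] + ... + a[i], s[0] = 0; long long avoids overflow
    vector<LL> s(n+1);
    for(int i=1;i<=n;i++) s[i]=s[i-1]+a[i];
    // Coordinate compression: sorted copy of all prefix sums;
    // lower_bound maps equal sums to the same index
    auto v=s;
    sort(v.begin(),v.end());
    vector<int> dp(n+1); // dp[0] = 0
    memset(tr,-0x3f,sizeof tr); // every entry starts as a large negative value
    // k = 0: insert dp_0 - 0 = 0 at the position of p_0
    int x=lower_bound(v.begin(),v.end(),s[0])-v.begin()+1;
    add(x,0);
    for(int i=1;i<=n;i++)
    {
        x=lower_bound(v.begin(),v.end(),s[i])-v.begin()+1;
        // Either no segment ends at i, or the last segment is [k+1, i] with p_k <= p_i
        dp[i]=max(dp[i-1],i+query(x));
        add(x,dp[i]-i);
    }
    cout<<dp[n]<<endl;
    return 0;
}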