题意
有一个长度为n的字符串,每一位只会是p或j。你需要取出一个子串S(从左到右或从右到左一个一个取出),使得
不管是从左往右还是从右往左取,都保证每时每刻已取出的p的个数不小于j的个数。你需要最大化|S|。
题解
这题一开始想错了一次。。然后打的时候发现搞错了。。
于是又想了一会,才想出来
然后A了之后发现似乎和别人的做法不一样QAQ
似乎复杂一点,至少代码长一点
首先,考虑只有一个方向,不如说从左往右
设前缀和
g[i]
如果对于一个点i,我们以他为起点,如果一个串
(i,j)
如果是合法的,那么就是
i
到
于是我们可以预处理这个,这个东西显然可以用一个单调栈来弄。。
反过来也是一样的做法于是就得到了两组线段
第一组是正着的,每个线段表示,以这个线段的l开头,到r,任意一个位置都是合法的
第二组是倒着的,每个线段表示,以这个线段的r即为,到l,任意一个位置都是合法的
那么可行的串很明显就是两个相交线段的交集
这个的话,你对于B串按L排序,对于A串扫过去,用一个线段树维护一下B串出现的R,然后求一个R-L的差值最大就可以了
时间复杂度
O(nlogn)
#include<cstdio>
#include<cstdlib>
#include<iostream>
#include<algorithm>
#include<cstring>
#include<stack>
using namespace std;
const int N=1000005;
const int MAX=(1<<30);
int n;
int len;
char ss[N];
int a[N];
int g[N],h[N];//前缀和 后缀和
void print ()
{
printf("a:");for (int u=1;u<=n;u++) printf("%d ",a[u]);printf("\n");
printf("g:");for (int u=1;u<=n;u++) printf("%d ",g[u]);printf("\n");
printf("h:");for (int u=1;u<=n;u++) printf("%d ",h[u]);printf("\n");
}
int s[N];//这个点下一个比他小的在哪里
stack<int> S;
struct qq
{
int l,r;
}l[N],L[N];//线段
int tot,tot1;
void prepare ()
{
tot=0;tot1=0;
S.push(0);
for (int u=1;u<=n;u++)
{
while (!S.empty())
{
int x=S.top();
if (g[x]<=g[u]) break;
s[x]=u;S.pop();
}
S.push(u);
}
while (!S.empty()) s[S.top()]=n+1,S.pop();
for (int u=1;u<=n;u++)
if (s[u-1]>u)
l[++tot].l=u,l[tot].r=s[u-1]-1;
S.push(n+1);
for (int u=n;u>=1;u--)
{
while (!S.empty())
{
int x=S.top();
if (h[x]<=h[u]) break;
s[x]=u;S.pop();
}
S.push(u);
}
while (!S.empty()) s[S.top()]=0,S.pop();
for (int u=n;u>=1;u--)
if (s[u+1]<u)
L[++tot1].l=s[u+1]+1,L[tot1].r=u;
}
bool cmp (qq a,qq b){return a.l<b.l;}
struct qt
{
int l,r;
int s1,s2;
int c;
}tr[N<<1];int num;
void bt (int l,int r)
{
int a=++num;
tr[a].l=l;tr[a].r=r;
tr[a].c=0;
if (l==r) return ;
int mid=(l+r)>>1;
tr[a].s1=num+1;bt(l,mid);
tr[a].s2=num+1;bt(mid+1,r);
}
void change (int now,int x)
{
if (tr[now].l==tr[now].r) {tr[now].c=x;return;}
int s1=tr[now].s1,s2=tr[now].s2;
int mid=(tr[now].l+tr[now].r)>>1;
if (x<=mid) change (s1,x);
else change(s2,x);
tr[now].c=max(tr[s1].c,tr[s2].c);
}
int find (int now,int l,int r)
{
if (tr[now].l==l&&tr[now].r==r) return tr[now].c;
int s1=tr[now].s1,s2=tr[now].s2;
int mid=(tr[now].l+tr[now].r)>>1;
if (r<=mid) return find(s1,l,r);
else if (l>mid) return find(s2,l,r);
else return max(find(s1,l,mid),find(s2,mid+1,r));
}
void solve ()
{
//l数组的l本来已经有序了
/*for (int u=1;u<=tot;u++) printf("%d %d\n",l[u].l,l[u].r);
printf("\n");
system("pause");*/
sort(L+1,L+1+tot1,cmp);
bt(1,n);
/*for (int u=1;u<=tot1;u++) printf("%d %d\n",L[u].l,L[u].r);
printf("\n");*/
int now=1;
int ans=0;
for (int u=1;u<=tot;u++)
{
while (now<=tot1)
{
if (L[now].l<=l[u].l) change(1,L[now].r);
else break;
now++;
}
int t=find(1,l[u].l,l[u].r);
ans=max(ans,t-l[u].l+1);
}
printf("%d\n",ans);
}
int main()
{
scanf("%d",&n);
scanf("%s",ss+1);
for (int u=1;u<=n;u++)
{
if (ss[u]=='j') a[u]=-1;
if (ss[u]=='p') a[u]=1;
}
for (int u=1;u<=n;u++) g[u]=g[u-1]+a[u];
for (int u=n;u>=1;u--) h[u]=h[u+1]+a[u];
prepare();
solve();
return 0;
}