Description
给出 n n 个序列,并从这个序列选出 m m 个构成一个新序列(每个元素是一个序列),问这个新序列的最大子段和
Input
第一行两个整数表示序列个数以及新序列所用序列的个数,之后 n n 行每行首先输入该序列的长度,保证每个序列长度不超过,之后输入该序列,最后输入 m m 个整数表示新序列中所用序列的编号
(1≤n≤50,1≤m≤250000) ( 1 ≤ n ≤ 50 , 1 ≤ m ≤ 250000 )
Output
输出新序列的最大子段和
Sample Input
3 4
3 1 6 -2
2 3 3
2 -5 1
2 3 1 3
Sample Output
9
Solution
首先求出每个序列中元素的最大值 mx[i] m x [ i ] 以及前缀最大值 Pre[i] P r e [ i ] 和后缀最大值 Next[i] N e x t [ i ] ,若整个序列非正则输出元素最大值即可,否则考虑最大子段的构造,必然为一个序列的后缀最大值加上中间几个完整序列再加上一个序列的前缀最大值,以 dp[i] d p [ i ] 表示以这 m m 个序列中第个序列(实际编号为 ai a i )为结尾的最大子段和,那么有转移
dp[i]=max(Next[aj]+sum(j+1,i−1))+Pre[ai],j<i d p [ i ] = m a x ( N e x t [ a j ] + s u m ( j + 1 , i − 1 ) ) + P r e [ a i ] , j < i
其中 sum(j+1,i−1) s u m ( j + 1 , i − 1 ) 表示第 j+1 j + 1 个序列到第 i−1 i − 1 个序列的和
故只要维护 Next[a[j]]+sum(j+1,i−1) N e x t [ a [ j ] ] + s u m ( j + 1 , i − 1 ) 最大值即可,其中 sum s u m 的维护只需每次区间加上一个序列的和,故用线段树解决该区间修改和区间查询最大值的操作,时间复杂度 O(mlogm) O ( m l o g m )
Code
#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<vector>
#include<queue>
#include<map>
#include<set>
#include<ctime>
using namespace std;
typedef long long ll;
typedef pair<int,int>P;
const int INF=0x3f3f3f3f,maxn=250005;
int n,m,a[5005],Val[5005],Pre[5005],Next[5005],b[maxn],mx[5005],res[5005];
#define ls (t<<1)
#define rs ((t<<1)|1)
ll Max[maxn<<2],Lazy[maxn<<2];
void push_up(int t)
{
Max[t]=max(Max[ls],Max[rs]);
}
void build(int l,int r,int t)
{
Max[t]=Lazy[t]=0;
if(l==r)
{
Max[t]=Next[b[l-1]];
return ;
}
int mid=(l+r)/2;
build(l,mid,ls);
build(mid+1,r,rs);
push_up(t);
}
void push_down(int l,int r,int t)
{
if(Lazy[t])
{
Max[ls]+=Lazy[t];
Max[rs]+=Lazy[t];
Lazy[ls]+=Lazy[t];
Lazy[rs]+=Lazy[t];
Lazy[t]=0;
}
}
void update(int L,int R,int l,int r,int t,int val)
{
if(L<=l&&r<=R)
{
Lazy[t]+=val;
Max[t]+=val;
return ;
}
push_down(l,r,t);
int mid=(l+r)/2;
if(L<=mid)update(L,R,l,mid,ls,val);
if(R>mid)update(L,R,mid+1,r,rs,val);
push_up(t);
}
ll query(int L,int R,int l,int r,int t)
{
if(L<=l&&r<=R)return Max[t];
push_down(l,r,t);
int mid=(l+r)/2;
ll ans=0;
if(L<=mid)ans=max(ans,query(L,R,l,mid,ls));
if(R>mid)ans=max(ans,query(L,R,mid+1,r,rs));
push_up(t);
return ans;
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)
{
int l;
scanf("%d",&l);
for(int j=1;j<=l;j++)scanf("%d",&a[j]);
Pre[i]=Next[i]=0;
int temp=0;
for(int j=1;j<=l;j++)
{
if(j==1)mx[i]=a[j];
else mx[i]=max(mx[i],a[j]);
temp+=a[j];
Pre[i]=max(Pre[i],temp);
}
temp=0;
for(int j=l;j>=1;j--)
{
temp+=a[j];
Next[i]=max(Next[i],temp);
}
Val[i]=temp;
int sum=0;
res[i]=0;
for(int j=1;j<=l;j++)
{
sum+=a[j];
if(sum<0)sum=0;
res[i]=max(res[i],sum);
}
}
int flag=0;
for(int i=1;i<=m;i++)
{
scanf("%d",&b[i]);
if(mx[b[i]]>=0)flag=1;
}
if(!flag)
{
int ans=mx[b[1]];
for(int i=2;i<=m;i++)ans=max(ans,mx[b[i]]);
printf("%d\n",ans);
}
else
{
build(1,m,1);
ll ans=max(Pre[b[1]],Next[b[m]]);
for(int i=1;i<=m;i++)ans=max(ans,(ll)res[b[i]]);
for(int i=2;i<=m;i++)
{
ll temp=query(1,i,1,m,1)+Pre[b[i]];
ans=max(ans,temp);
update(1,i,1,m,1,Val[b[i]]);
}
printf("%I64d\n",ans);
}
return 0;
}