Description
共有m部电影,编号为1~m,第i部电影的好看值为 wi 。在n天之中(从1~n编号)每天会放映一部电影,第i天放映的是第 fi 部。
你可以选择l,r(1<=l<=r<=n),并观看第l,l+1,…,r天内所有的电影。如果同一部电影你观看多于一次,你会感到无聊,于是无法获得这部电影的好看值。所以你希望最大化观看且仅观看过一次的电影的好看值的总和。
Input
- Line 1:两个整数n,m(1<=m<=n<=1000000)。
- Line 2:包含n个整数 f1,f2,…,fn (1≤fi≤m) 。
- Line 3:包含m个整数 w1,w2,…,wm (1≤wj≤106) 。
- 数据范围:
- 有20%的数据,n≤8000;
- 有70%的数据,n≤100000;
- 有100%的数据,n≤1000000。
Output
输出观看且仅观看过一次的电影的好看值的总和的最大值。
Sample Input
9 4
2 3 1 1 4 1 2 4 1
5 3 6 6Sample Output
15
Input Details
观看第2,3, … ,7天内放映的电影,其中看且仅看过一次的电影的编号为2,3,4。
Solution:
本题重点在于下面这句话的运用:“如果同一部电影观看多于一次,就无法获得这部电影的好看值。”
也就是说,对于选定的区间,如果某元素出现重复,那么就不能收集到该元素的权值。一般遇到“区间内元素不重复的问题”我们只需要判断当前元素与下一个相同的元素之间的关系。设当前元素所在位置为pos,下一个相同元素所在位置为nxt[pos],则对于所有 L<pos 的区间,有:
- 若 R<pos||R>=nxt[pos] ,则该区间无法得到val值。
- 若 pos<=R<nxt[pos] ,则该区间可以得到val值。
即对于区间 [pos,nxt[pos]−1] 这段区间的前缀和都会增加一个val值。如果我们定义数据结构维护的是以枚举的 L 作为开头,当前位置结尾的前缀和,那么上述操作就是对该区间进行区间更新。显然,那些位置已经小于L的元素就继续向后推移区间:
#include <bits/stdc++.h>
#define M 1000005
using namespace std;
inline void Rd(int &res){
res=0;char c;
while(c=getchar(),c<48);
do res=(res<<3)+(res<<1)+(c^48);
while(c=getchar(),c>47);
}
/* 维护这样一个数据结构:
* 1)支持查找[L,n]内最大前缀[L,pos]。
* 2)当L向后移动时,删除[L,nxt[L]-1]段的权值,增加[nxt[L],nxt[nxt[L]]-1]
*/
int a[M],w[M];
int nxt[M],pre[M];
bool used[M];
struct Node{
long long sum,add;
}tree[M<<2];
void up(int p){
tree[p].sum=tree[p<<1|1].sum;
if(tree[p].sum<tree[p<<1].sum)tree[p].sum=tree[p<<1].sum;
}
void down(int p){
if(!tree[p].add)return;
tree[p<<1].sum+=tree[p].add;
tree[p<<1|1].sum+=tree[p].add;
tree[p<<1].add+=tree[p].add;
tree[p<<1|1].add+=tree[p].add;
tree[p].add=0;
}
void update(int L,int R,int l,int r,int w,int p){
if(l>r)return;
if(L==l&&R==r){
tree[p].sum+=w;
tree[p].add+=w;
return;
}
down(p);
int mid=L+R>>1;
if(r<=mid)update(L,mid,l,r,w,p<<1);
else if(l>mid)update(mid+1,R,l,r,w,p<<1|1);
else{
update(L,mid,l,mid,w,p<<1);
update(mid+1,R,mid+1,r,w,p<<1|1);
}
up(p);
}
long long query(int L,int R,int l,int r,int p){
if(L==l&&R==r)return tree[p].sum;
down(p);
int mid=L+R>>1;
if(r<=mid)return query(L,mid,l,r,p<<1);
else if(l>mid)return query(mid+1,R,l,r,p<<1|1);
else return max(query(L,mid,l,mid,p<<1),query(mid+1,r,mid+1,r,p<<1|1));
}
int main(){
int n,m;
Rd(n),Rd(m);
for(int i=1;i<=n;i++)Rd(a[i]);
for(int i=1;i<=m;i++)Rd(w[i]);
for(int i=1;i<=n;i++){
if(pre[a[i]])nxt[pre[a[i]]]=i;
nxt[i]=n+1,pre[a[i]]=i;
}
for(int i=1;i<=n;i++)
if(!used[a[i]]){
used[a[i]]=true;
update(1,n,i,nxt[i]-1,w[a[i]],1);
}
long long ans=0;
for(int i=1;i<=n;i++){
ans=max(ans,query(1,n,i,n,1));
update(1,n,i,nxt[i]-1,-w[a[i]],1);
if(nxt[i]!=n+1)update(1,n,nxt[i],nxt[nxt[i]]-1,w[a[i]],1);
}
cout<<ans<<endl;
}