题意:
你有一个长度为n的数组a,并且每个数都有一个权值pi,如果要删掉这个数,要付出pi的代价,定义函数f(a)如下:
一开始有一个空数组,
从1到n,如果
∀
j
<
i
(
a
i
>
a
j
)
∀_{j<i}(a_i>a_j)
∀j<i(ai>aj)那么将ai放入数组末尾
返回目标数组
样例:
现在给你长度为m的数组b,问你付出代价最少是多少使得f(a)=b
题解:
好像是第一次做到主席树+dp的题目,有点简单?
有段时间没写代码了,重新确定一下思路:
首先数据范围是5e5,然后又是一个升序,可能是维护一个值或者是数据结构。
然后由于数组b的存在,那么它是一个特定的升序,考虑可以维护到每个位置每个值的状态,然后进行转移。
这样一想就有可能是dp,然后发现状态转移返程是dp[b[i]]=dp[b[i-1]]+花里胡哨的操作
冷静思考发现是
dp[i]=dp[j]+红色区间中a的值大于a[j]的所有数的p的和+红色区间中a的值小于等于a[j]的所有p是负数的p的和。
发现好像是两个主席树,两个主席树太麻烦了,有没有可能优化。
冷静思考发现可以变成这样:
dp[i]=dp[j]+红色区间中a的值大于a[j]的所有p是正数的p的和+红色区间中所有p是负数的p的和。
这样子就变成了一个主席树+一个前缀和
wa了两发是因为两处判断写的有点问题,没有考虑到所有情况。
#include<bits/stdc++.h>
using namespace std;
#define ll long long
const ll inf=1e18;
const int N=5e5+5;
int a[N],pre[N],b[N];
ll p[N],sum[N];
ll s[N*25];
int ls[N*25],rs[N*25],rt[N],tot;
void update(int l,int r,int root,int last,int pos,ll v){
ls[root]=ls[last];
rs[root]=rs[last];
s[root]=s[last]+v;
if(l==r)
return ;
int mid=l+r>>1;
if(mid>=pos)
update(l,mid,ls[root]=++tot,ls[last],pos,v);
else
update(mid+1,r,rs[root]=++tot,rs[last],pos,v);
}
ll query(int l,int r,int root,int last,int ql,int qr){
if(l>=ql&&r<=qr)
return s[root]-s[last];
int mid=l+r>>1;
ll ans=0;
if(mid>=ql)
ans=query(l,mid,ls[root],ls[last],ql,qr);
if(mid<qr)
ans+=query(mid+1,r,rs[root],rs[last],ql,qr);
return ans;
}
ll dp[N];
int pos[N];
int main()
{
int n;
scanf("%d",&n);
for(int i=1;i<=n;i++)
scanf("%d",&a[i]);
for(int i=1;i<=n;i++)
scanf("%lld",&p[i]);
int m;
scanf("%d",&m);
memset(pre,-1,sizeof(pre));
for(int i=1;i<=m;i++)
scanf("%d",&b[i]),pre[b[i]]=b[i-1];
ll ans=inf;
for(int i=1;i<=n;i++){
if(p[i]>=0)
update(1,n,rt[i]=++tot,rt[i-1],a[i],p[i]);
else
rt[i]=rt[i-1],sum[i]=p[i];
sum[i]+=sum[i-1];
dp[i]=inf;
}
for(int i=1;i<=n;i++){
if(~pre[a[i]]){
int v=pre[a[i]];
if(dp[v]>=inf)continue;
ll val=dp[v]+sum[i-1]-sum[pos[v]]+query(1,n,rt[i-1],rt[pos[v]],v+1,n);
if(!pos[a[i]]||val<dp[a[i]]+sum[i]-sum[pos[a[i]]]+query(1,n,rt[i],rt[pos[a[i]]],a[i]+1,n))
dp[a[i]]=val,pos[a[i]]=i;
if(a[i]==b[m])
if(a[i]<n)
ans=min(ans,val+query(1,n,rt[n],rt[i],a[i]+1,n)+sum[n]-sum[i]);
else
ans=min(ans,val+sum[n]-sum[i]);
}
}
if(ans<inf)
printf("YES\n%lld\n",ans);
else
printf("NO\n");
return 0;
}