orz zhf
题意
有n个病毒,每天每个病毒的体积会变大ai,每天必须且只能消除一个病毒,代价是病毒的体积,每个病毒的初始体积是bi,天数一共有k天,问最小的代价是多少。
数据范围
n , k , a i ≤ 1 0 6 , b i ≤ 1 0 11 n,k,ai\le 10^{6},bi\le 10^{11} n,k,ai≤106,bi≤1011
解法
首先有一个比较显然的
O
(
N
2
)
O(N^2)
O(N2)dp,设f[i][j]表示前i天,一共清除了j个病毒的最小代价,转移显然,然后需要注意的是如果确定了一个选择病毒的集合,那么一定是按ai从大到小的顺序清除病毒。所以可以事先将病毒按ai从大到小排好序.
代码:此处f数组优化了一维,path数组表示的是具体方案
#include<bits/stdc++.h>
using namespace std;
#define int long long
const int maxn=2e3+5;
inline int read(){
char c=getchar();int t=0,f=1;
while((!isdigit(c))&&(c!=EOF)){if(c=='-')f=-1;c=getchar();}
while((isdigit(c))&&(c!=EOF)){t=(t<<3)+(t<<1)+(c^48);c=getchar();}
return t*f;
}
int n,k;
struct node{
int a,b;
}a[maxn];
bool cmp(node a,node b){
return a.a==b.a?a.b<b.b:a.a>b.a;
}
int f[maxn],path[maxn][maxn];
signed main(){
freopen("12.in","r",stdin);
freopen("12b.out","w",stdout);
n=read(),k=read();
if(k>n)k=n;
for(int i=1;i<=n;i++){
a[i].a=read(),a[i].b=read();
}
sort(a+1,a+1+n,cmp);
memset(f,0x3f,sizeof(f));
f[0]=0;
for(int i=1;i<=n;i++){
for(int j=k;j>=1;j--){
if(f[j-1]+(j-1)*a[i].a+a[i].b<f[j]){
path[i][j]=1;
f[j]=f[j-1]+(j-1)*a[i].a+a[i].b;
}
}
}
/*int x=n,y=k;
while(x>=1){
if(path[x][y]){printf("%lld %lld\n",a[x].a,a[x].b);y--;}
x--;
}*/
printf("%lld\n",f[k]);
return 0;
}
然后观察path数组的性质,可以发现一个病毒如果被选入了f(n,k)的答案集合,那么一定也会被选入f(n,k+1)的答案集合,所以根据这个性质我们考对每个病毒二分一个最早被选入答案集合的时刻。这个是一个笔者认为比较麻烦的问题,具体要解决的问题有:二分的条件,动态维护需要的信息,幸运的是平衡树可以维护这些东西。
首先我们考虑二分的条件:一个病毒被选入答案集合中,可以认为是原先有一个答案集合,现在要将一个病毒插进去,考虑一个位置i,现在答案集合中的该位置的数是a,b。那么这个病毒对答案造成的影响是第一维比a小的都会贡献一个第一维的值,然后这个a会被算(第一维比a大的次数),注意这里第一维和a相等的随便排就可以,可以不用特意处理。然后新加入的病毒产生的贡献是类似的,但是注意新加入的病毒和原有的答案集合之间的信息并不好维护,所以我们也需要先将病毒按a从大到小排好序,这样,每次新加入的病毒一定比答案集合中其它的病毒的第一维更小,就会方便很多。
然后有关原来的答案集合,每个病毒维护两个信息,比它第一维小的病毒数,以及这些病毒的第一维的数值之和,这个可以用平衡树简单维护。然后我们就有了代码:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn=1e6+5;
inline ll read(){
char c=getchar();ll t=0,f=1;
while((!isdigit(c))&&(c!=EOF)){if(c=='-')f=-1;c=getchar();}
while((isdigit(c))&&(c!=EOF)){t=(t<<3)+(t<<1)+(c^48);c=getchar();}
return t*f;
}
int n,k;
struct node{
int a;
ll b;
}a[maxn];
int b[maxn];
bool cmp(node a,node b){
return a.a==b.a?a.b>b.b:a.a>b.a;
}
int ans;
struct tree{
int val,l,r,sz,tag;
ll tot,tag2;
}t[maxn];
int tot,rt,p1,p2;
#define ls t[rt].l
#define rs t[rt].r
inline void pushup(int rt){
t[rt].sz=t[ls].sz+t[rs].sz+1;
}
inline int newnode(int val){tot++;
t[tot].val=val;t[tot].l=t[tot].r=0;t[tot].sz=1;t[tot].tag=t[tot].tot=t[tot].tag2=0;
return tot;
}
inline void pushdown(int rt){
if(t[rt].tag){
t[ls].tag+=t[rt].tag;
t[rs].tag+=t[rt].tag;
t[ls].tot-=1ll*t[rt].tag*a[t[ls].val].a;
t[rs].tot-=1ll*t[rt].tag*a[t[rs].val].a;
t[rt].tag=0;
}
if(t[rt].tag2){
t[ls].tag2+=t[rt].tag2;
t[rs].tag2+=t[rt].tag2;
t[rs].tot+=t[rt].tag2;
t[ls].tot+=t[rt].tag2;
t[rt].tag2=0;
}
}
inline void split(int rt,int &l,int &r,int k){
if(!rt){l=r=0;return ;}
pushdown(rt);
if(t[ls].sz>=k){
r=rt;
split(ls,l,ls,k);
}
else{
l=rt;
split(rs,rs,r,k-t[ls].sz-1);
}
pushup(rt);
}
inline int merge(int x,int y){
if((!x)||(!y))return x^y;
pushdown(x);
pushdown(y);
if(t[x].sz>t[y].sz){
t[x].r=merge(t[x].r,y);
pushup(x);
return x;
}
else{
t[y].l=merge(x,t[y].l);
pushup(y);
return y;
}
}
inline int find(int k){
int u=rt,num=0;
while(1){
if(!u)return num;
pushdown(u);
int va=t[u].val;
ll tmp1=1ll*a[va].a*(num+t[t[u].l].sz)+a[va].b+t[u].tot,tmp2=1ll*a[k].a*(num+t[t[u].l].sz)+a[k].b;
if(tmp1>tmp2){
u=t[u].l;
}
else if(tmp1==tmp2){
return num;
}
else{
num=num+t[t[u].l].sz+1;
u=t[u].r;
}
}
}
inline void modify(int rt,int k,int val){
pushdown(rt);
if(t[ls].sz>=k){t[rs].tag++;t[rs].tot+=val-a[t[rs].val].a;t[rs].tag2+=val;t[rt].tot+=val-a[t[rt].val].a;modify(ls,k,val);}
else if(t[ls].sz+1==k){t[rt].tot+=val-a[t[rt].val].a;t[rs].tag++;t[rs].tag2+=val;t[rs].tot+=val-a[t[rs].val].a;return ;}
else modify(rs,k-t[ls].sz-1,val);
}
void insert(int i){
int sz=find(i);
split(rt,p1,p2,sz);
rt=merge(p1,merge(newnode(i),p2));
if(t[rt].sz>=sz+2)
modify(rt,sz+2,a[i].a);
}
inline int query(int k){
int u=rt;
while(1){
if(t[t[u].l].sz>=k)u=t[u].l;
else if(t[t[u].l].sz+1==k)return t[u].val;
else{k-=t[t[u].l].sz+1;u=t[u].r;}
}
}
signed main(){
//freopen("12.in","r",stdin);
//freopen("12.out","w",stdout);
n=read(),k=read();
if(k>n)k=n;
for(int i=1;i<=n;i++){
a[i].a=read(),a[i].b=read();
}
sort(a+1,a+1+n,cmp);
for(int i=1;i<=n;i++){
insert(i);
}
ll ans=0;
for(int i=1;i<=k;i++){
int tmp=query(i);
b[i]=a[tmp].a;
ans=ans+a[tmp].b;
}
sort(b+1,b+1+k);
for(int i=k;i>=1;i--){
ans=ans+1ll*b[i]*(k-i);
}
printf("%lld\n",ans);
return 0;
}
//经历了长时间的卡常,终于卡过了。。。
时间复杂度
O
(
n
l
o
g
n
)
O(nlogn)
O(nlogn)