首先考虑怎么安排攻击顺序。显然如果攻击了某台兵器就应该一直连续攻击直到将其破坏,破坏所需时间可以直接算出来,设其为b。假设确定了某个破坏顺序,如果交换相邻两个兵器,显然不会对其他兵器造成影响,两种顺序的代价则分别为a1(b1-1)+a2(b1+b2-1)和a1(b1+b2-1)+a2(b2-1),那么当a2b1<a1b2时先破坏1较优。于是按b/a从小到大排序。
然后考虑怎么秒杀。如果只能秒杀一个显然直接枚举即可。假设已确定要秒杀的是第i个,则需要找到j>i最小化Σax(Bx-1)-ai(Bi-1)-aj(Bj-1)-(An-Ai)bi-(An-Aj)bj+ajbi (Ax=Σay Bx=Σby(y=1~x))(注意是护甲值<=0时被破坏,题面错了,开始这个式子半天没过样例还以为锅了)。设ci=ai(Bi-1)+(An-Ai)bi,则要最小化ajbi-ci-cj。考虑类似斜率优化的东西,若i固定时j比k优,则ajbi-cj<akbi-ck,即bi(aj-ak)<cj-ck,若aj>ak则bi<(cj-ck)/(aj-ak)。这里的a和b都没有单调性,还要保证i编号小于j,那么同样用斜率优化dp的思路,cdq分治,对左边按b从大到小排序,右边按a小到大排序造出上凸壳。只会log^2。
#include<iostream> #include<cstdio> #include<cmath> #include<cstdlib> #include<cstring> #include<algorithm> using namespace std; #define ll long long #define N 300010 char getc(){char c=getchar();while ((c<'A'||c>'Z')&&(c<'a'||c>'z')&&(c<'0'||c>'9')) c=getchar();return c;} int gcd(int n,int m){return m==0?n:gcd(m,n%m);} int read() { int x=0,f=1;char c=getchar(); while (c<'0'||c>'9') {if (c=='-') f=-1;c=getchar();} while (c>='0'&&c<='9') x=(x<<1)+(x<<3)+(c^48),c=getchar(); return x*f; } int n,m,q[N]; ll ans,tot,c[N]; struct data{int x,y;ll z; }a[N],b[N]; bool cmp(const data&a,const data&b) { return b.x*a.y<a.x*b.y; } bool cmp2(const data&a,const data&b) { return a.y>b.y; } bool cmp3(const data&a,const data&b) { return a.x<b.x; } ll calc(int x,int y) { return tot-a[x].z-a[y].z+a[y].x*a[x].y; } double slope(int i,int j) { return (double)(a[j].z-a[i].z)/(a[j].x-a[i].x); } void solve(int l,int r) { if (l>=r) return; int mid=l+r>>1; solve(l,mid); solve(mid+1,r); sort(a+l,a+mid+1,cmp2); sort(a+mid+1,a+r+1,cmp3); int head=0,tail=0; for (int i=mid+1;i<=r;i++) { while (head<tail&&slope(q[tail-1],q[tail])<slope(q[tail],i)) tail--; q[++tail]=i; } for (int i=l;i<=mid;i++) { while (head<tail&&slope(q[head],q[head+1])>a[i].y) head++; ans=min(ans,calc(i,q[head])); } } int main() { #ifndef ONLINE_JUDGE freopen("bzoj4700.in","r",stdin); freopen("bzoj4700.out","w",stdout); const char LL[]="%I64d\n"; #else const char LL[]="%lld\n"; #endif n=read(),m=read(); ll A=0,B=0; for (int i=1;i<=n;i++) (A+=a[i].x=read()),a[i].y=(read()-1)/m+1; sort(a+1,a+n+1,cmp); for (int i=1;i<=n;i++) { A-=a[i].x,B+=a[i].y; a[i].z=a[i].x*(B-1)+A*a[i].y; tot+=a[i].x*(B-1); } ans=tot; solve(1,n); cout<<ans; return 0; }