Description
You have to restore the wall. The wall consists of N pillars of bricks, the height of the ii-th pillar is initially equal to hihi, the height is measured in number of bricks. After the restoration all the N pillars should have equal heights.
You are allowed the following operations:
- put a brick on top of one pillar, the cost of this operation is A;
- remove a brick from the top of one non-empty pillar, the cost of this operation is R;
- move a brick from the top of one non-empty pillar to the top of another pillar, the cost of this operation is M.
You cannot create additional pillars or ignore some of pre-existing pillars even if their height becomes 0.
What is the minimal total cost of restoration, in other words, what is the minimal total cost to make all the pillars of equal height?
Input
The first line of input contains four integers N, A, R, M (1≤N≤10^5, 0≤A,R,M≤10^4) — the number of pillars and the costs of operations.
The second line contains N integers hi (0≤hi≤10^9) — initial heights of pillars.
Ouput
Print one integer — the minimal cost of restoration.
Examples
input
3 1 100 100
1 3 8
output
12
input
3 100 1 100
1 3 8
output
9
input
3 100 100 1
1 3 8
output
4
input
5 1 2 4
5 5 3 6 5
output
4
input
5 1 2 2
5 5 3 6 5
output
3
题目大意:
给定N堆砖头,每堆砖有hi个,要求用以下三种操作使得所有堆的砖块个数相同(即高度相同),同时要使得花费最小。
1. 在一堆砖上加一块砖,花费为A;
2. 拿掉一堆砖顶上的一块砖,花费为R;
3. 将一块砖从一个堆顶移到另一个堆顶,花费为M;
分析:
假设最终的高度为H,此时的花费最小。
考虑从H开始不断增加高度,那么就会有更多的hi小于H,必然会有更多的1操作,花费会增加;
同理,考虑从H开始不断降低高度,那么必然会有更多的2操作,花费会增加;
所以,可以看出答案是一个下凸函数。
下凸函数求极值,利用三分。
具体解释见代码。
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <string>
#include <cmath>
#include <vector>
#include <map>
#include <set>
#include <queue>
#define INF 0x3f3f3f3f
#define mst(a,num) memset(a,num,sizeof a)
#define rep(i,a,b) for(int i=a;i<=b;i++)
#define repd(i,a,b) for(int i=a;i>=b;i--)
using namespace std;
typedef long long ll;
typedef pair<int,int> PII;
typedef pair<ll,ll> PLL;
typedef vector<int> VI;
const ll mod = 1e9 + 7;
const int maxn = 100000 + 5;
int h[maxn];
int n,a,r,m;
ll solve(int H){
ll dwn=0,up=0;
ll res=0;
rep(i,1,n){
if(h[i]<H){
dwn+=0ll+H-h[i];
}
else{
up+=0ll+h[i]-H;
}
}
int mn=min(a+r,m);
if(up>dwn){
res+=1ll*dwn*mn;
res+=1ll*(0ll+up-dwn)*r;
}
else{
res+=1ll*up*mn;
res+=1ll*(0ll+dwn-up)*a;
}
return res;
}
int main() {
scanf("%d%d%d%d",&n,&a,&r,&m);
int l=1e9,r=0;
rep(i,1,n){
scanf("%d",h+i);
l=min(l,h[i]);
r=max(r,h[i]);
}
int ansh=0;
ll ans=1e18;
while(l<=r){
int lmid=l+(r-l)/3;
int rmid=r-(r-l)/3;
ll tmp1=solve(lmid);
ll tmp2=solve(rmid);
if(tmp1>tmp2){
ansh=rmid;
ans=tmp2;
l=lmid+1;
}
else{
ansh=lmid;
ans=tmp1;
r=rmid-1;
}
}
// printf("%d\n",ansh);
printf("%lld\n",ans);
return 0;
}