出题人有两个数组A,B,请你把两个数组归并起来使得$cost=\sum i c_i$最小.
归并要求原数组的数的顺序在新数组中不改变.
贪心水题
对于一段序列$A_i,A_{i+1},...,A_r$, 我们考虑向$A_k,A_{k+1}$间中间插入一个$x$.
贡献为$iA_i+(i+1)A_{i+1}+...+kA_k+(k+1)x+(k+2)A_{k+1}+...+(r+1)A_r$.
由$x$插入到$A_i$左端更优可以得到$\frac{A_i+...+A_k}{k-i+1}<x$.
由$x$插入到$A_r$右端更优可以得到$\frac{A_{k+1}+...+A_r}{r-k}>x$.
也就是说若某段序列的前一部分平均值小于后一部分, 那么这段序列看做一个整体会更优.
所以先将$A,B$按平均值合并后, 再贪心归并, 归并的贪心证明同理.
要注意数组开两倍
#include <iostream>
#include <sstream>
#include <algorithm>
#include <cstdio>
#include <math.h>
#include <set>
#include <map>
#include <queue>
#include <string>
#include <string.h>
#include <bitset>
#define REP(i,a,n) for(int i=a;i<=n;++i)
#define PER(i,a,n) for(int i=n;i>=a;--i)
#define hr putchar(10)
#define pb push_back
#define lc (o<<1)
#define rc (lc|1)
#define mid ((l+r)>>1)
#define ls lc,l,mid
#define rs rc,mid+1,r
#define x first
#define y second
#define io std::ios::sync_with_stdio(false)
#define endl '\n'
#define DB(a) ({REP(__i,1,n) cout<<a[__i]<<' ';hr;})
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
const int P = 1e9+7, P2 = 998244353, INF = 0x3f3f3f3f;
ll gcd(ll a,ll b) {return b?gcd(b,a%b):a;}
ll qpow(ll a,ll n) {ll r=1%P;for (a%=P;n;a=a*a%P,n>>=1)if(n&1)r=r*a%P;return r;}
ll inv(ll x){return x<=1?1:inv(P%x)*(P-P/x)%P;}
inline int rd() {int x=0;char p=getchar();while(p<'0'||p>'9')p=getchar();while(p>='0'&&p<='9')x=x*10+p-'0',p=getchar();return x;}
//head
const int N = 2e5+10;
int n, m, nn, mm, a[N], b[N], c[N];
int c1[N], c2[N];
ll w1[N], w2[N];
void solve(int *a, int *c, ll *w, int n, int &tot) {
REP(i,1,n) {
++tot, c[tot]=1, w[tot]=a[i];
while (tot>1&&w[tot-1]*c[tot]<w[tot]*c[tot-1]) {
w[tot-1]+=w[tot],c[tot-1]+=c[tot];
--tot;
}
}
}
void add(int *a, int l, int r) {
REP(i,l,r) c[++*c]=a[i];
}
int main() {
scanf("%d%d", &n, &m);
REP(i,1,n) scanf("%d", a+i);
REP(i,1,m) scanf("%d", b+i);
solve(a,c1,w1,n,nn);
solve(b,c2,w2,m,mm);
int now = 1, la = 0, lb = 0;
REP(i,1,nn) {
while (now<=mm&&w1[i]*c2[now]<w2[now]*c1[i]) {
add(b,lb+1,lb+c2[now]),lb+=c2[now];
++now;
}
add(a,la+1,la+c1[i]),la+=c1[i];
}
while (now<=mm) {
add(b,lb+1,lb+c2[now]),lb+=c2[now];
++now;
}
ll ans = 0;
REP(i,1,*c) ans+=(ll)i*c[i];
printf("%lld\n", ans);
}