當然題目的測試集很水,是可以暴力過的。
那有沒有其他思路呢? 我們可以嘗試從簡化公式入手,看看能否降低時間消耗。
對於Max而言,我們可以透過對v進行排序,一旦v是有序的,我們就可以知道
Vi < Vj, where i < j
那麼棘手的就剩下絕對值符號了。
那有什麼辦法可以幫我們找出兩段區間內的大小關係呢?
這裡可以試一下從歸併排序的理念出發。
歸併排序怎麼幫助我們呢?
他可以幫你把區間分成兩半,並且兩半都是服從升序關係的。
並且最重要的是,我們是遞歸到底部再返回來,就算排序了,也不會打亂左邊的v值永遠比右邊的v值要小這個原則。
假設遞歸層數第 i 層 [1, 3, 5 | 6, 7, 8]
遞歸層數第 i + 1層 [1 | 3, 5] [6 | 7, 8]
你不難發現i + 1層中就算排序了,他回到第i層元素並不會跨越 中間的劃分線。
根據這個神奇的特性,我們可以繼續套用剛剛推出的公式了。
由於使用了歸併排序,我們知道 兩邊區間都是升序的,那麼一定存在一個分界點,左邊比右邊小。
我們先留意公式的一部分。注意,這個時候的k和j都是全局的,也就是說他還沒把把歸併排序的思路融入近這個公式。
那我們根據上面的推論,嘗試把這個公式簡化。
再進行拆解
那這裡我們就開始疑惑了,到底再歸併排序的時候,這個i,j,p,k到底應該代表誰呢?
我們從v的角度出發,由於我們永遠先選v大的,也就是說我們是要從右半區間進行循環先的。
然後我們再從左區間找到p,p是一條劃分線,p的左側也就是從left 到 mid,都是要比a[i]小的,右側是比a[i]大的。
外循環是僅限右區間不再像一開始的整個數組循環,內循環是去左邊區間找大小。
所以外循環和內循環除了比較大小外就沒有任何干涉了,而v是跟右區間的,所以他可以搬出來,像下面的公式。
完整版的
我們簡化一下。
在代碼中如何實現呢?
int p = left;
for(int i = mid + 1; i <= right; i++){
while(p <= mid && a[p] < a[i]){
xxxxx;
}
ans+= xxxxxxx
}
這兩項也很好處理,我們首先求一個xi的總和,然後減掉掉p的前綴和,那麼就的到後綴和了。
long long sum = 0;
for(int i = left; i <= mid; i++){
sum += a[i].x;
}
int p = left;
long long s1 = 0, s2 = sum;
for(int i = mid + 1; i <= right; i++){
while(p <= mid && a[p].x < a[i].x){
s2 -= a[p].x;
s1 += a[p].x;
p++;
}
ans+= a[i].v * ((2 * p - left - mid - 1) * a[i].x - s1 + s2);
}
歸併排序一起:
#include <iostream>
#include <algorithm>
using namespace std;
int n;
long long ans = 0;
struct node{
int v, x;
}a[100500], temp[100500];
bool cmp(node a, node b){
if(a.v < b.v)return true;
else return false;
}
void merge(int left, int right){
if(left >= right)return;
int mid = (left + right)/2;
merge(left, mid);
merge(mid + 1, right);
long long sum = 0;
for(int i = left; i <= mid; i++){
sum += a[i].x;
}
int p = left;
long long s1 = 0, s2 = sum;
for(int i = mid + 1; i <= right; i++){
while(p <= mid && a[p].x < a[i].x){
s2 -= a[p].x;
s1 += a[p].x;
p++;
}
ans+= a[i].v * ((2 * p - left - mid - 1) * a[i].x - s1 + s2);
}
for(int i = left; i <= right; i++){
temp[i] = a[i];
}
int i = left, j = mid + 1;
for(int curr = left; curr <= right; curr++){
if(i >= mid + 1)a[curr] = temp[j++];
else if(j > right) a[curr] = temp[i++];
else if(temp[i].x <= temp[j].x) a[curr] = temp[i++];
else a[curr] = temp[j++];
}
return;
}
int main(){
cin >> n;
for(int i = 1; i <= n; i++){
cin >> a[i].v >> a[i].x;
}
sort(a + 1, a + n + 1, cmp);
merge(1, n);
cout << ans << endl;
return 0;
}