There are n points on a coordinate axis OX. The i-th point is located at the integer point xi and has a speed vi. It is guaranteed that no two points occupy the same coordinate. All n points move with the constant speed, the coordinate of the i-th point at the moment t (t can be non-integer) is calculated as xi+t⋅vi.
Consider two points i and j. Let d(i,j) be the minimum possible distance between these two points over any possible moments of time (even non-integer). It means that if two points i and j coincide at some moment, the value d(i,j) will be 0.
Your task is to calculate the value ∑1≤i<j≤n d(i,j) (the sum of minimum distances over all pairs of points).
Input
The first line of the input contains one integer n (2≤n≤2⋅105) — the number of points.
The second line of the input contains n integers x1,x2,…,xn (1≤xi≤108), where xi is the initial coordinate of the i-th point. It is guaranteed that all xi are distinct.
The third line of the input contains n integers v1,v2,…,vn (−108≤vi≤108), where vi is the speed of the i-th point.
Output
Print one integer — the value ∑1≤i<j≤n d(i,j) (the sum of minimum distances over all pairs of points).
Examples
Input
3
1 3 2
-100 2 3
Output
3
Input
5
2 1 4 3 5
2 2 2 3 4
Output
19
Input
2
2 1
-3 0
Output
0
思路:这个题目,读懂题是很关键的,我当初就是没读懂题目。。这个最小距离不是按照时间来算的,而是按照点对来算的。如果有一种点对,xi<xj&&vi<vj,这一种永远也不可能相遇而且越来越远的,最初距离就是贡献的。但是剩下的,肯定有一个时间会让他们相遇(不一定是整数)。因此我们用树状数组来处理这一问题。
代码如下:
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int maxx=2e5+100;
struct node{
int x,v;
bool operator<(const node &a)const{
return x<a.x;
}
}p[maxx];
ll num[maxx],sum[maxx];
int b[maxx];
int n;
inline void init()
{
memset(num,0,sizeof(num));
memset(sum,0,sizeof(sum));
}
inline int lowbit(int x){return x&-x;}
inline ll query(ll c[],int x)
{
ll ans=0;
while(x)
{
ans+=c[x];
x-=lowbit(x);
}
return ans;
}
inline void add(ll c[],int x,ll v)
{
while(x<maxx)
{
c[x]+=v;
x+=lowbit(x);
}
}
int main()
{
scanf("%d",&n);
for(int i=1;i<=n;i++) scanf("%d",&p[i].x);
for(int i=1;i<=n;i++) scanf("%d",&p[i].v),b[i]=p[i].v;
sort(p+1,p+1+n);
sort(b+1,b+1+n);
int len=unique(b+1,b+1+n)-b-1;//离散化,因为和v没有直接关系。
init();
ll ans=0;
for(int i=1;i<=n;i++)
{
int pos=lower_bound(b+1,b+1+len,p[i].v)-b;//计算之前有多少个v小于当前v
ans+=query(num,pos)*(ll)p[i].x-query(sum,pos);//计算出个数以及前缀和
add(num,pos,1);
add(sum,pos,p[i].x);//边处理边添加
}
cout<<ans<<endl;
return 0;
}
努力加油a啊,(o)/~