题目
给定两个序列a和b,每个序列中可能含有重复的数字。
一个配对(i,j)是一个好配对当从第一个序列中选出一个数ai,再从第二个序列中选出一个数bj且满足ai>bj。
给出两个序列,问存在多少个好配对。
题目链接: 好配对
有题目要求,知道题目的数据量比较大:a和b中分别最多有10^5种不同数字,每个数字最多有10^4个。因此,要求算法有O(nlogn)的时间复杂度。
一开始使用了两个map,map1为序列a中的数字以及对应的个数构成的数对;map2为对于序列a中的数字x,序列b中小于x的数字的个数。这样在第一次输入序列a,时候创建map1,以及将map2中的value均设置为0;在输入序列b时,若当前读取数值为x,个数为y,从map1的末尾向前查找直到map1中当前的key值小于等于x,在经过的那些(key, value)对中,value均加上y,表示在序列b中小于key值的数字个数增加y个。
最后,从头到尾遍历一遍 map1和map2, 求和map1[key]*map2[key]就得到最终结果。
结果华丽的超时了: 在对b序列中的每个数字,从末尾到首部遍历map1,构成了O(n^2)的复杂度了。。
超时之后,朝着 O(nlogn)的复杂度方向努力:使用平衡二叉树节点维持数值x,节点中等于x的个数,节点所代表的子树的总数字的个数。在读取序列a的时候,构建这棵平衡二叉树,复杂度为O(nlogn);在读取序列b的时候,对b中的每个数字x,从该平衡二叉树上获得大于x的数字的总个数sum(时间复杂度O(logn),最终结果加上 y*sum.
总时间复杂度为 O(nlogn)
平衡二叉树使用treap来实现。
实现
#include<stdio.h>
#include<string.h>
#include<iostream>
#include<string>
#include<set>
#include<map>
#include<vector>
#include<queue>
#include<stack>
#include<unordered_map>
#include<unordered_set>
#include<algorithm>
using namespace std;
struct Node{
int val;
int count;
int sum;
int priority;
Node* childs[2];
Node(){
val = count = sum = 0;
childs[0] = childs[1] = NULL;
priority = rand();
}
void Update(){
sum = count;
if (childs[0])
sum += childs[0]->sum;
if (childs[1])
sum += childs[1]->sum;
}
};
struct Treap{
Node* root;
Treap(){
root = NULL;
}
void Delete(Node*& node){
if (!node)
return;
if (node->childs[0])
Delete(node->childs[0]);
if (node->childs[1])
Delete(node->childs[1]);
delete node;
node = NULL; //注意赋值为NULL,否则在反复使用treap时出错
}
void Rotate(Node*& node, bool dir){
Node* ch = node->childs[dir];
node->childs[dir] = ch->childs[!dir];
ch->childs[!dir] = node;
node->Update(); //注意更新,因为此时修改了树的结构
node = ch;
}
void Insert(Node*& node, int val, int count){
if (node == NULL){
node = new Node();
node->val = val;
node->sum = node->count = count;
return;
}
if (node->val == val){
node->count += count;
node->sum += count;
return;
}
bool ch = node->val < val;
Insert(node->childs[ch], val, count);
if (node->childs[ch]->priority > node->priority){
Rotate(node, ch);
}
node->Update(); //更新,此时修改了树的结构
}
int Bigger(Node* node, int val){
if (!node)
return 0;
if (node->val == val)
return (node->childs[1]? node->childs[1]->sum:0);
else if (node->val < val)
return Bigger(node->childs[1], val);
else{
return (node->childs[1] ? node->childs[1]->sum : 0) + node->count + Bigger(node->childs[0], val);
}
}
};
int main(){
int T, n, m, x, y;
scanf("%d", &T);
Treap treap;
while (T--){
scanf("%d %d", &n, &m);
treap.Delete(treap.root);
for (int i = 0; i < n; i++){
scanf("%d %d", &x, &y);
treap.Insert(treap.root, x, y);
}
long long result = 0;
for (int i = 0; i < m; i++){
scanf("%d %d", &x, &y);
long long int bigger = treap.Bigger(treap.root, x);
result += y*bigger;
}
printf("%lld\n", result);
}
return 0;
}