Kanade's trio
Time Limit: 4000/2000 MS (Java/Others) Memory Limit: 524288/524288 K (Java/Others)Total Submission(s): 1079 Accepted Submission(s): 398
There are T test cases.
1≤T≤20
1≤∑n≤5∗105
0≤A[i]<230
For each test case , the first line consists of one integer n ,and the second line consists of n integers which means the array A[1..n]
1 5 1 2 3 4 5
6
题意
给出 A[1..n] , (1<=∑n<=5∗105) , (0<=A[i]<230) ,要求统计三元组 (i,j,k) 的个数使其满足 i<j<k 并且 (A[i]xorA[j])<(A[j]xorA[k]) 。
思路
事先把所有数字插入字典树中,用字典树维护
A[k]
的信息,接着对每一个
A[i]
,枚举其二进制最高位小于
A[k]
的位数,考虑这样一个情况:若当前枚举到了
A[i]
的二进制第5位比
A[k]
小,那么
A[i]
与
A[k]
的第30位到第6位都是相同的,此时就不用考虑
A[j]
的第30位到第6位如何,只考虑第5位的情况就好。
考虑当前位置
A[i]
的情况,若当前位置
A[i]
为0,那么
A[j]
的相同位置要为0才能使两者异或值为0,此时
A[k]
为1,这时满足条件的
A[j]
与
A[k]
对数可以计入答案。当前位置
A[i]
为1的情况同理(
A[j]
为1,
A[k]
为0)。
对于
A[j]
与
A[k]
对数的统计,在插入
A[k]
时,之前插入的数就都成了
A[j]
。因此用一个cnt[i][j]
数组记录下第i位为j的数之前出现了几次,那么在插入时,对于这一位置
A[k]
为0的情况,之前有多少的
A[j]
在这一位为1,就是此时满足条件的
A[j]
的个数。代码里的cnt[i][nxt ^ 1]
就是此时符合条件的
A[j]
个数。
当我们把一个数从字典树中去掉时,也要考虑去掉这个数留下来的统计值。
这题特殊的地方在于,插入是连续的,之后是连续的删除,所以在插入完成后可以把cnt[i][j]
数组清空一遍,用来记录第i位为j的数被删除了几次。
考虑两个方面:
- 一个是这个数作为
A[k]
直接被去掉带来的影响,像之前一样减去其前面已经被删去的
A[j]
的个数(依然是cnt[now][nxt ^ 1]
)就好。(这一步操作在Trie::Insert()
里面,与插入时的操作类似)
- 还有一个是这个数作为
A[j]
带来的影响,因为这个数已经不能和后面的
A[k]
组合产生贡献了,考虑到在统计时,当前位的
A[k]
已经把可以与其组合的
A[j]
个数统计在了sum[tmp]
中,这里面还需去掉被删去的
A[j]
,被删去的
A[j]
已经被统计在了cnt[i][nxt]
中,现有的
A[k]
被存在了val[tmp]
中,这一部分不能被计入答案,相乘,减去。(sum[tmp] - val[tmp] * cnt[i][nxt]
这一步在函数solve()
里)
#include <iostream>
#include <algorithm>
#include <cstdio>
#include <cmath>
#include <cstring>
#include <string>
#include <string.h>
#include <map>
#include <set>
#include <queue>
#include <deque>
#include <list>
#include <bitset>
#include <stack>
#include <stdlib.h>
#define lowbit(x) (x&-x)
#define e exp(1.0)
#define eps 1e-8
//ios::sync_with_stdio(false);
// auto start = clock();
// cout << (clock() - start) / (double)CLOCKS_PER_SEC<<endl;
typedef long long ll;
typedef long long LL;
using namespace std;
typedef unsigned long long ull;
const int maxn=5e5+10;
ll a[maxn];
int bits[32];
struct tire
{
int tot,root;
int val[maxn*30],ch[maxn*30][2];
ll sum[maxn*30],cnt[maxn][2];
int newnode()
{
val[tot]=sum[tot]=0;
ch[tot][0]=ch[tot][1]=-1;
return tot++;
}
void init()
{
tot=0;
root=newnode();
memset(cnt,0,sizeof(cnt));
}
void insert(int x,int v)
{
int now=root,nxt;
for(int i=30;i>=0;i--)
{
nxt=!!(x & bits[i]);
if(ch[now][nxt]==-1)
ch[now][nxt]=newnode();
now=ch[now][nxt];
cnt[i][nxt]++;
sum[now]+=v*cnt[i][nxt^1];
val[now]+=v;
}
}
ll solve(int x)
{
ll ret=0;
int now=root,tmp,nxt;
for(int i=30;i>=0;--i)
{
nxt=!!(x & bits[i]);
tmp=ch[now][nxt^1];
now=ch[now][nxt];
if(tmp!=-1)
ret+=sum[tmp]-val[tmp]*cnt[i][nxt];
if(now==-1) break;
}
return ret;
}
}tire;
int n;
int main()
{
ios::sync_with_stdio(false);
int T;
cin>>T;
bits[0]=1;
for(int i=1;i<32;i++)
bits[i]=bits[i-1]<<1;
while(T--)
{
cin>>n;
for(int i=1;i<=n;i++)
cin>>a[i];
ll ans=0;
tire.init();
for(int i=1;i<=n;i++)
tire.insert(a[i],1);
memset(tire.cnt,0,sizeof(tire.cnt));
for(int i=1;i<=n;i++)
{
tire.insert(a[i],-1);
ans+=tire.solve(a[i]);
}
cout<<ans<<endl;
}
return 0;
}