题目
给定一个长度为
n
的非负整数列
a
1
, a
2
, . . . , a
n
和非负整数
x
, 求有多少个非空子序列
1
≤
b
1
< b
2
<
· · ·
< b
k
≤
n
,满足对任意的
(
i, j
) (1
≤
i < j
≤
k
)
都有
a
b
i
⊕
a
b
j
≥
x
。其中
⊕
表示按位异或。
由于这个数可能非常大,你只需要回答这个数除以
998 244 353
的余数即可。
Input
第一行两个整数
n
,
x
(
1
≤
n
≤
3
·
10
5
,
0
≤
x <
2
60
)
。
第二行
n
个整数
a
1
, a
2
, . . . , a
n
(
0
≤
a
i
<
2
60
)
。
题解&思路
我的思路
有可能是dp,不会。。。。
正解
首先可以证明选出的序列按a排序后异或的最小值只可能是相邻的两个数,即令a<b<c,需要证明min(a⊕b,b⊕c)<=a⊕c , 从高位往低位看,如果a,b, c的某一位是相等的,那么一定满足上述条件,如果不相等,因为是从高位到低位,即可以发现只有几种情况:
令a的二进制下从左到右的第i位是ai。bi,ci同理
1.ai=0, bi = 0 , ci = 1 , 满足a⊕b<a⊕c
2. ai = 0 , bi = 1 , ci = 1 , b⊕c<a⊕c
证毕
然后就可以将a排序,然后进行dp即可,方程为:
dp[i] = ∑dp[j] + 1 ( a[j] ⊕a[i] >= x )
但是这个东西是n^2的,需要优化。因为有异或,所以不可能是单调队列之内的东西,考虑用字典树,发现这就是个二叉树,然后在字典树上维护即可,即查找子树的和,根据判断x第j位的大小与a[i]第j位大小判断即可
总结
首先要发现子序列的异或最小值要>=x,然后考虑是哪两个数,发现规律后,就可列出dp式,然后因为有异或,才需要考虑用数据结构优化,字典树可做
代码
#include <bits/stdc++.h>
using namespace std;
const int MAXN = 3e5 + 3;
const int mod = 998244353;
#define ll long long
int n ,ncnt , trie[MAXN*60][2];
ll X , sum[MAXN*60] , a[MAXN] , ans;
ll mo( ll x ){
if( x >= mod ) x -= mod;
return x;
}
int root;
ll find_( ll p ){
ll ans1 = 0;
root = 0;
for( int i = 60 ; i >= 0 ; i -- ){
if( (X & ( 1ll << i )) ){
if((p & ( 1ll << i )) )
root = trie[root][0];
else{
root = trie[root][1];
}
}
else{
if( (p & ( 1ll << i )) ){
// printf( "i%d x%d\n" ,i ,p );
if( trie[root][0] )
ans1 = mo( ans1 + sum[trie[root][0]] );
root = trie[root][1];
}
else{
if( trie[root][1] )
ans1 = mo( ans1 + sum[trie[root][1]] );
root = trie[root][0];
}
}
if( !root ) break;
}
if( root ) ans1 = mo( ans1 + sum[root] );
return ans1;
}
void insert_( ll x , ll f ){
root = 0;
for( int i = 60 ; i >= 0 ; i -- ){
if( (x & ( 1ll << i ) )){
if( !trie[root][1] ){
trie[root][1] = ++ncnt;
sum[root] = mo( sum[root] + f );
root = ncnt;
}
else{
sum[root] = mo( sum[root] + f ) ;
root = trie[root][1];
}
}
else{
if( !trie[root][0] ){
trie[root][0] = ++ncnt;
sum[root] = mo( sum[root] + f );
root = ncnt;
}
else{
sum[root] = mo( sum[root] + f );
root = trie[root][0];
}
}
}
sum[root] = mo( sum[root] + f );
}
int main(){
scanf( "%d%lld" , &n , &X );
for( int i = 1 ; i <= n ; i ++ ) scanf( "%lld" , &a[i] );
sort( a + 1 , a + n + 1 );
for( int i = n ; i ; i -- ){
ll f = mo( find_( a[i] ) + 1 );
ans = mo( ans + f );
insert_( a[i] , f );
}
printf( "%lld\n" , ans );
}
第一次写字典树,代码很丑