题意
有 n n n 个位置,现在要往上面填数,给定 k k k,每个数小于 2 k 2^k 2k。现在有 m m m 个限制条件,每个限制条件给定 l i , r i , x i l_i,r_i,x_i li,ri,xi,要求满足 a [ l i ] & a [ l i + 1 ] & … & a [ r i ] = x i a[l_i] \& a[l_i + 1] \& \dots \& a[r_i] = x_i a[li]&a[li+1]&…&a[ri]=xi。求合法的填数方案数。
分析
因为
&
\&
& 运算是每位独立的,所以这题填的数每一位也是独立的,我们可以一位一位考虑。
那么限制条件转化为以下两种情况:
- 整个区间都填 1 1 1
- 整个区间中存在至少一个 0 0 0
那么我们只用考虑哪些位置填
0
0
0 即可。
设
f
i
,
j
f_{i,j}
fi,j 表示考虑到第
i
i
i 位,最后一个填
0
0
0 的位置为
j
j
j 的方案数。
设
g
i
,
j
=
∑
j
=
0
i
f
i
,
j
g_{i,j}=\sum\limits_{j=0}^{i}f_{i,j}
gi,j=j=0∑ifi,j
转移分两种情况:
- 不在第
i
i
i 位填
0
0
0
f i , j = f i − 1 , j f_{i,j}=f_{i-1,j} fi,j=fi−1,j - 在第
i
i
i 位填
0
0
0
f i , i = ∑ j = 0 i − 1 f i − 1 , j = g i − 1 , i − 1 f_{i,i}=\sum\limits_{j=0}^{i-1}f_{i-1,j}=g_{i-1,i-1} fi,i=j=0∑i−1fi−1,j=gi−1,i−1
看起来做完了,其实还没有!
如果
i
i
i 是某个存在
0
0
0 的限制的右端点,如图:
这个图是多个以
i
i
i 为右顶点的存在
0
0
0 的区间,
l
m
a
x
lmax
lmax 是最大的左端点。
那么最近一个
0
0
0 的位置肯定是在
[
l
m
a
x
,
i
]
[lmax,i]
[lmax,i] 中的。因此,
f
0
,
l
m
a
x
−
1
f_{0,lmax-1}
f0,lmax−1 的值都应该被清
0
0
0。
到了这里,我们得到了一个
O
(
k
n
2
)
O(kn^2)
O(kn2) 的做法:
每次暴力更新
f
i
,
j
f_{i,j}
fi,j,且每次暴力清
0
0
0
考虑优化。
看一下转移方程,设目前来到第
i
i
i 位,如果用滚动数组,转移方程将变为:
- 不在第
i
i
i 位填
0
0
0
f j = f j f_j=f_{j} fj=fj - 在第
i
i
i 位填
0
0
0
f i = g i − 1 f_i=g_{i-1} fi=gi−1
那么转移变成
O
(
1
)
O(1)
O(1) 的了。
考虑清
0
0
0 操作,我们维护上一次清
0
0
0 的点为
p
p
p,这次清
0
0
0,相当于将
[
p
,
l
m
a
x
−
1
]
[p,lmax-1]
[p,lmax−1] 清
0
0
0,然后将
p
p
p 移动到
l
m
a
x
lmax
lmax。这样,每个点只会被清
0
0
0 一次。
总的复杂度为
O
(
k
n
+
k
m
)
O(kn+km)
O(kn+km)。
代码如下
#include <bits/stdc++.h>
#define N 500005
using namespace std;
typedef long long LL;
typedef unsigned long long uLL;
const int mod = 998244353;
LL z = 1;
int read(){
int x, f = 1;
char ch;
while(ch = getchar(), ch < '0' || ch > '9') if(ch == '-') f = -1;
x = ch - '0';
while(ch = getchar(), ch >= '0' && ch <= '9') x = (x << 3) + (x << 1) + ch - 48;
return x * f;
}
int cnt[N], mx[N], f[N], g[N];
struct node{
int l, r, x;
}d[N];
int main(){
int t, i, j, k, n, m, l, r, ans = 1, p;
n = read(); k = read(); m = read();
for(i = 1; i <= m; i++) d[i].l = read(), d[i].r = read(), d[i].x = read();
for(t = 0; t < k; t++){
for(i = 1; i <= n + 1; i++) f[i] = g[i] = mx[i] = cnt[i] = 0;
for(i = 1; i <= m; i++){
l = d[i].l; r = d[i].r;
if(1 << t & d[i].x) cnt[l]++, cnt[r + 1]--;
else mx[r] = max(mx[r], l);
}
for(i = 1; i <= n; i++) cnt[i] += cnt[i - 1];
f[0] = g[0] = 1;
p = 0;
for(i = 1; i <= n; i++){
g[i] = g[i - 1];
if(!cnt[i]){
f[i] = g[i - 1];
g[i] = (g[i] + f[i]) % mod;
}
while(p < mx[i]) g[i] = (g[i] - f[p]) % mod, p++;
}
ans = z * ans * g[n] % mod;
}
ans = (ans % mod + mod) % mod;
printf("%d", ans);
return 0;
}