题目
思路
首先可以拆位,问题转化为在 O ( n ) \mathcal O(n) O(n) 的复杂度内确定长度为 n n n 的 01 01 01 串,有两种要求:
- [ l , r ] [l,r] [l,r] 不可以全为 1 1 1 。即: [ l , r ] [l,r] [l,r] 至少有一个 0 0 0 。
- [ l , r ] [l,r] [l,r] 全部都为 1 1 1 。即: [ l , r ] [l,r] [l,r] 不能有一个 0 0 0 。
那么请问,在这道题里面, 0 0 0 和 1 1 1 谁起到关键作用呢?都考虑一下吧。
如果 1 1 1 被作为关键元素?
条件一该如何理解?额……差分数组的不等式?做不来。
条件二该如何理解? 1 1 1 覆盖整个区间。
如果 0 0 0 被作为关键元素?
条件一该如何理解?相邻的两个零之间不能夹住一个区间。
条件二该如何理解?零不能出现在该区间内。
结论: 0 0 0 是关键元素!
主要是因为, 0 0 0 只需要一个就能判断(像是立即 b r e a k \tt break break 的感觉),而 1 1 1 需要整个区间填完才能判断。
然后我们考虑对 0 0 0 进行 d p \tt dp dp 。用 f ( i ) f(i) f(i) 表示,填充前 i i i 个位置,且第 i i i 个位置是 0 0 0 。
对于条件二,将 i ∈ [ l , r ] i\in[l,r] i∈[l,r] 的 f ( i ) f(i) f(i) 置 0 0 0 即可。
对于条件一,考虑前一个 0 0 0 在 j j j 处,不能存在 j ≤ l p ≤ r p < i j\le l_p\le r_p<i j≤lp≤rp<i ——只需要求 l p l_p lp 的最大值即可。
这样一来,可转移的 d p dp dp 为区间,用前缀和即可。答案是 f ( n + 1 ) f(n+1) f(n+1) 。
代码
#include <cstdio>
#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;
inline int readint(){
int a = 0; char c = getchar(), f = 1;
for(; c<'0'||c>'9'; c=getchar())
if(c == '-') f = -f;
for(; '0'<=c&&c<='9'; c=getchar())
a = (a<<3)+(a<<1)+(c^48);
return a*f;
}
template < typename T >
void getMax(T&a,const T&b){if(a<b)a=b;}
template < typename T >
void getMin(T&a,const T&b){if(b<a)a=b;}
const int Mod = 998244353;
const int MaxN = 500005;
int dp[MaxN], n, k, m;
struct Node{
int l, r, val;
bool operator < (const Node &that) const {
return r < that.r;
}
void input(){
l = readint(), r = readint();
val = readint();
}
} node[MaxN];
void input(){
n = readint(), k = readint();
m = readint();
for(int i=1; i<=m; ++i)
node[i].input();
sort(node+1,node+m+1);
}
void solve(){
int ans = 1;
while(k --){
dp[0] = 1, dp[n+1] = -1;
int t = n+1, j = m;
for(int i=n; i; --i){
while(j && i <= node[j].r){
while(j && !(node[j].val&1))
-- j;
if(j && i <= node[j].r)
getMin(t,node[j--].l);
}
if(t <= i) dp[i] = 0;
else dp[i] = -1;
}
t = 0, j = 1;
for(int i=1; i<=n+1; ++i){
if(dp[i] != -1){
dp[i] = dp[i-1];
continue;
}
while(j <= m && node[j].r < i){
while(j <= m && (node[j].val&1))
++ j;
if(j <= m && node[j].r < i)
getMax(t,node[j++].l);
}
dp[i] = dp[i-1];
if(t) dp[i] -= dp[t-1];
dp[i] = (dp[i]+Mod)%Mod;
// 处理前缀和
dp[i] = (dp[i]+dp[i-1])%Mod;
}
dp[n+1] = dp[n+1]+Mod-dp[n];
ans = 1ll*ans*dp[n+1]%Mod;
for(int i=1; i<=m; ++i)
node[i].val >>= 1;
}
printf("%d\n",ans);
}
int main(){
input(), solve();
return 0;
}