题目链接: http://acm.hdu.edu.cn/showproblem.php?pid=6155
题意: 给你一个长度为n的01序列a, m个操作, 分两种, 一是区间[l,r]取反, 二是询问区间[l,r]的不同子序列个数。 ( n,m≤105 )
思路: 先考虑如何直接求不同子序列个数。 令f[i][j]表示考虑到第i位最后一位是j的不同子序列个数。 若a[i] == 0, 则转移为f[i][0] = f[i - 1][0] + f[i - 1][1] + 1, f[i][1] = f[i - 1][1], a[i] == 1时类似。 答案即为f[n][0] + f[n][1]。 观察这个递推式, 对于一段区间[l, r], 它最终的答案只与f[l-1][0], f[l-1][1]和一个常数有关, 将前两者看成未知数, 考虑用线段树维护每个区间的各个系数, 区间合并的时候推一下系数之间的转移即可, 类似于矩阵乘法。
#include <queue>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>
#define ls (x << 1)
#define rs (x << 1 | 1)
#define mid ((l + r) >> 1)
#define ll long long
using namespace std;
const int N = (int)1e5 + 10;
const int mo = (int)1e9 + 7;
struct Nod{
ll fac[2][3];
}nod[N << 2];
int n, m, a[N];
int bj[N << 2]; ll f[N][2];
Nod unit(Nod x, Nod y){
Nod ret;
ret.fac[0][0] = (x.fac[0][0] * y.fac[0][0] + x.fac[1][0] * y.fac[0][1]) % mo;
ret.fac[0][1] = (x.fac[0][1] * y.fac[0][0] + x.fac[1][1] * y.fac[0][1]) % mo;
ret.fac[0][2] = (x.fac[0][2] * y.fac[0][0] + x.fac[1][2] * y.fac[0][1] + y.fac[0][2]) % mo;
ret.fac[1][0] = (x.fac[0][0] * y.fac[1][0] + x.fac[1][0] * y.fac[1][1]) % mo;
ret.fac[1][1] = (x.fac[0][1] * y.fac[1][0] + x.fac[1][1] * y.fac[1][1]) % mo;
ret.fac[1][2] = (x.fac[0][2] * y.fac[1][0] + x.fac[1][2] * y.fac[1][1] + y.fac[1][2]) % mo;
return ret;
}
void flip(int x){
for (int i = 0; i < 3; i ++)
swap(nod[x].fac[0][i], nod[x].fac[1][i]);
for (int i = 0; i < 2; i ++)
swap(nod[x].fac[i][0], nod[x].fac[i][1]);
}
void pushdown(int x){
if (bj[x]){
bj[ls] ^= 1, bj[rs] ^= 1;
flip(ls), flip(rs);
bj[x] = 0;
}
}
void build(int x, int l, int r){
bj[x] = 0;
if (l == r){
memset(nod[x].fac, 0, sizeof(nod[x].fac));
if (a[l] == 0){
nod[x].fac[0][0] = nod[x].fac[0][1] = nod[x].fac[0][2] = 1;
nod[x].fac[1][1] = 1;
}
else{
nod[x].fac[1][0] = nod[x].fac[1][1] = nod[x].fac[1][2] = 1;
nod[x].fac[0][0] = 1;
}
return;
}
build(ls, l, mid); build(rs, mid + 1, r);
nod[x] = unit(nod[ls], nod[rs]);
}
void modf(int x, int l, int r, int L, int R){
if (l == L && r == R){
bj[x] ^= 1; flip(x); return;
}
pushdown(x);
if (R <= mid) modf(ls, l, mid, L, R);
else if (L > mid) modf(rs, mid + 1, r, L, R);
else modf(ls, l, mid, L, mid), modf(rs, mid + 1, r, mid + 1, R);
nod[x] = unit(nod[ls], nod[rs]);
}
Nod query(int x, int l, int r, int L, int R){
if (l == L && r == R) return nod[x];
pushdown(x);
if (R <= mid) return query(ls, l, mid, L, R);
else if (L > mid) return query(rs, mid + 1, r, L, R);
else return unit(query(ls, l, mid, L, mid), query(rs, mid + 1, r, mid + 1, R));
}
int main(){
int T; T = 0;
for (scanf("%d", &T); T --; ){
scanf("%d %d", &n, &m);
for (int i = 1; i <= n; i ++) scanf("%1d", a + i);
build(1, 1, n);
while (m --){
int opt, l, r;
scanf("%d %d %d", &opt, &l, &r);
if (opt == 1){
modf(1, 1, n, l, r);
}
else{
Nod ret = query(1, 1, n, l, r);
printf("%lld\n", (ret.fac[0][2] + ret.fac[1][2]) % mo);
}
}
}
return 0;
}