Subsequence Count
Time Limit: 10000/5000 MS (Java/Others) Memory Limit: 256000/256000 K (Java/Others)Total Submission(s): 782 Accepted Submission(s): 288
There are two types of queries:
1. Flipping the bits (i.e., changing all 1 to 0 and 0 to 1) between l and r (inclusive).
2. Counting the number of distinct subsequences in the substring S[l,...,r] .
For each test, the first line contains two integers N and Q .
The second line contains the string S .
Then Q lines follow, each with three integers type , l and r , denoting the queries.
1≤T≤5
1≤N,Q≤105
S[i]∈{0,1},∀1≤i≤N
type∈{1,2}
1≤l≤r≤N
2 4 4 1010 2 1 4 2 2 4 1 2 3 2 1 4 4 4 0000 1 1 2 1 2 3 1 3 4 2 1 4
11 6 8 10
给一个01串,要求完成两种操作:
1.区间修改,把一段区间的数值取反
2.区间查询,查询某一段的不同的01串子序列数量
如果我们用dp[i][0]和dp[i][1]表示dp到第i位,以0、1结尾的不同的子序列数量,则很容易得到DP公式:
dp[i][0]=dp[i-1][0]+dp[i-1][1]+1,dp[i][1]=dp[i-1][1] 第i位为0
dp[i][1]=dp[i-1][0]+dp[i-1][1]+1,dp[i][0]=dp[i-1][0] 第i位为1
那么把dp[i][0],dp[i][1],1当做矩阵的一列,可以很容易地把DP过程表示成矩阵乘法的形式。
dp[i][0] 1 1 1 dp[i-1][0]
dp[i][1] = 0 1 0 * dp[i-1][1]
dp[i][0] 1 0 0 dp[i-1][0]
dp[i][1] = 1 1 1 * dp[i-1][1]
那么,我们可以用线段树维护转移矩阵,查询的时候查询转移矩阵就好了。
而修改时,反转一段相当于某一段区间异或1,我们可以用异或标记在线段树上标明。而此时的转移矩阵,恰好把第一列和第二列交换、第一行和第二行交换,就可以了。(草稿纸推理演算得到)
矩阵乘法的常数比较大,我们发现转移矩阵最后一行只能是0,0,1,可以由此优化一下,减少时间复杂度。
最后跑了2873ms
#include <cstdio>
#include <string.h>
#include <algorithm>
#define mem0(a) memset(a,0,sizeof(a))
#define meminf(a) memset(a,0x3f,sizeof(a))
#define size 3
using namespace std;
typedef long long ll;
typedef long double ld;
typedef double db;
const int maxn=100005,inf=0x3f3f3f3f;
const ll llinf=0x3f3f3f3f3f3f3f3f,mod=1e9+7;
char s[maxn];
int num;
struct Matrix {
ll a[size][size];
};
Matrix m0=(Matrix){1,1,1,0,1,0,0,0,1};
Matrix m1=(Matrix){1,0,0,1,1,1,0,0,1};
void print(Matrix v) {
int i,j;
for (i=0;i<size;i++) {
for (j=0;j<size;j++) {
printf("%lld ",v.a[i][j]);
}
printf("\n");
}
printf("\n");
}
Matrix operator*(const Matrix &x,const Matrix &y) {
int i,j,k;
Matrix ans;
for (i=0;i<size-1;i++) {
for (j=0;j<size;j++) {
ans.a[i][j]=0;
for (k=0;k<size;k++) {
ans.a[i][j]+=x.a[i][k]*y.a[k][j];
ans.a[i][j]%=mod;
}
}
}
ans.a[2][0]=ans.a[2][1]=0;ans.a[2][2]=1;
return ans;
}
void flip(Matrix &x) {
for (int i=0;i<size;i++) swap(x.a[i][0],x.a[i][1]);
for (int i=0;i<size;i++) swap(x.a[0][i],x.a[1][i]);
}
struct Tree {
int lc,rc,l,r,isxor;
Matrix sum;
};
Tree tree[4*maxn];
void pushdown(int now) {
if (tree[now].isxor==0) return;
int l=tree[now].lc,r=tree[now].rc;
tree[l].isxor=tree[l].isxor^tree[now].isxor;
tree[r].isxor=tree[r].isxor^tree[now].isxor;
flip(tree[l].sum);flip(tree[r].sum);
tree[now].isxor=0;
}
void build(int now,int l,int r) {
tree[now].l=l;
tree[now].r=r;
tree[now].isxor=0;
if (l!=r) {
num++;
tree[now].lc=num;
build(num,l,(l+r)/2);
num++;
tree[now].rc=num;
build(num,(l+r)/2+1,r);
tree[now].sum=tree[tree[now].rc].sum*tree[tree[now].lc].sum;
} else {
if (s[l]=='0') tree[now].sum=m0;
else tree[now].sum=m1;
}
}
void update (int now,int l,int r) {
if (tree[now].l>=l&&tree[now].r<=r) {
tree[now].isxor^=1;
flip(tree[now].sum);
} else {
pushdown(now);
if (l<=(tree[now].l+tree[now].r)/2)
update(tree[now].lc,l,r);
if (r>(tree[now].l+tree[now].r)/2)
update(tree[now].rc,l,r);
tree[now].sum=tree[tree[now].rc].sum*tree[tree[now].lc].sum;
}
}
Matrix findsum(int now,int l,int r) {
// cout << now << ' ' << tree[now].l << ' ' << tree[now].r << ' ' << tree[now].tag << endl;
if (tree[now].l>=l&&tree[now].r<=r) {
return tree[now].sum;
} else {
pushdown(now);
if (r>(tree[now].l+tree[now].r)/2) {
Matrix f;
f=findsum(tree[now].rc,l,r);
if (l<=(tree[now].l+tree[now].r)/2)
f=f*findsum(tree[now].lc,l,r);
return f;
} else return findsum(tree[now].lc,l,r);
}
}
bool findval(int now,int pos) {
// cout << now << ' ' << tree[now].l << ' ' << tree[now].r << ' ' << tree[now].tag << endl;
if (tree[now].l>=pos&&tree[now].r<=pos) {
return tree[now].isxor^(s[pos]-'0');
} else {
pushdown(now);
if (pos<=(tree[now].l+tree[now].r)/2)
return findval(tree[now].lc,pos);
if (pos>(tree[now].l+tree[now].r)/2)
return findval(tree[now].rc,pos);
}
}
int main() {
int cas;
scanf("%d",&cas);
while (cas--) {
int n,m,i,j,l,r,t;
scanf("%d%d",&n,&m);
scanf("%s",s+1);
num=1;
build(1,1,n);
for (i=1;i<=m;i++) {
scanf("%d%d%d",&t,&l,&r);
if (t==1) update(1,l,r); else {
if (l==r) {
printf("1\n");
continue;
}
Matrix v,t;
if (findval(1,l)) t=(Matrix){0,0,0,1,0,0,1,0,0};
else t=(Matrix){1,0,0,0,0,0,1,0,0};
v=findsum(1,l+1,r);
// print(v);
v=v*t;
ll q=(v.a[0][0]+v.a[1][0])%mod;
printf("%lld\n",q);
}
}
}
return 0;
}