解析:
线段树维护区间矩阵乘法。每个节点都要维护矩阵
因为本题有两个公式 A=A+B; B=A+B
矩阵A:
[
A
B
]
\begin{bmatrix}A &B\end{bmatrix}
[AB] *
[
1
0
1
1
]
\begin{bmatrix}1 & 0\\ 1 & 1\\\end{bmatrix}
[1101] =
[
A
+
B
B
]
\begin{bmatrix}A+B & B \end{bmatrix}
[A+BB]
矩阵B: [ A B ] \begin{bmatrix}A &B\end{bmatrix} [AB] * [ 1 1 0 1 ] \begin{bmatrix}1 & 1\\ 0 & 1\\\end{bmatrix} [1011] = [ A A + B ] \begin{bmatrix}A & A+B \end{bmatrix} [AA+B]
对于操作1 因为要交换A和B 所以我们只要替换A与B的矩阵 然后下传懒惰标记。
对于操作2 因为要修改A和B的值 、
我们把值分开来写,先处理矩阵乘,最后再乘上A和B的值就是矩阵[a,b]
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=1e5+10;
const int MOD=1e9+7;
int n,m,l,r;
char s[N];
ll x,y;
struct lxw
{
ll m[3][3];
};
struct node
{
int l,r,lazy;
lxw ans,rans;
}tr[N<<2];
lxw mult(lxw a,lxw b,int n)
{
lxw tmp;
memset(tmp.m,0,sizeof tmp.m);
for(int i=0;i<n;i++)
for(int j=0;j<n;j++)
for(int k=0;k<n;k++){
tmp.m[i][j]=(tmp.m[i][j]+a.m[i][k]*b.m[k][j])%MOD;
}
return tmp;
}
void push_down(int u)
{
if(tr[u].lazy)
{
swap(tr[u<<1].ans,tr[u<<1].rans);
swap(tr[u<<1|1].ans,tr[u<<1|1].rans);
tr[u<<1].lazy^=1;
tr[u<<1|1].lazy^=1;
tr[u].lazy=0;
}
}
void push_up(int u)
{
tr[u].ans=mult(tr[u<<1].ans,tr[u<<1|1].ans,2);
tr[u].rans=mult(tr[u<<1].rans,tr[u<<1|1].rans,2);
}
void build(int u,int l,int r)
{
tr[u].l=l;tr[u].r=r;tr[u].lazy=0;
if(l==r){
if(s[l]=='A')
{
tr[u].ans.m[0][0]=1;tr[u].ans.m[0][1]=0;
tr[u].ans.m[1][0]=1;tr[u].ans.m[1][1]=1;
tr[u].rans.m[0][0]=1;tr[u].rans.m[0][1]=1;
tr[u].rans.m[1][0]=0;tr[u].rans.m[1][1]=1;
}
else if(s[l]=='B')
{
tr[u].ans.m[0][0]=1;tr[u].ans.m[0][1]=1;
tr[u].ans.m[1][0]=0;tr[u].ans.m[1][1]=1;
tr[u].rans.m[0][0]=1;tr[u].rans.m[0][1]=0;
tr[u].rans.m[1][0]=1;tr[u].rans.m[1][1]=1;
}
return ;
}
int mid=l+r>>1;
build(u<<1,l,mid);build(u<<1|1,mid+1,r);
push_up(u);
}
lxw query(int u,int l,int r)
{
lxw res;
if(l<=tr[u].l&&tr[u].r<=r)
{
return tr[u].ans;
}
push_down(u);
int mid=(tr[u].l+tr[u].r)>>1;
if(r<=mid ) res=query(u<<1,l,r);
else if(l>mid) res=query(u<<1|1,l,r);
else
{
res=query(u<<1,l,mid);
res=mult(res,query(u<<1|1,mid+1,r),2);
}
push_up(u);
return res;
}
void update(int u,int l,int r )
{
if(l<=tr[u].l&&tr[u].r<=r){
tr[u].lazy^=1;
swap(tr[u].ans,tr[u].rans);
return ;
}
push_down(u);
int mid=(tr[u].l+tr[u].r)>>1;
if(r<=mid) update(u<<1,l,r);
else if(l>mid) update(u<<1|1,l,r);
else
{
update(u<<1,l,mid);
update(u<<1|1,mid+1,r);
}
push_up(u);
}
int main()
{
scanf("%d %d",&n,&m);
scanf("%s",(s+1));
build(1,1,n);
while(m--)
{
int op;
scanf("%d",&op);
if(op==1)
{
scanf("%d %d",&l,&r);
update(1,l,r);
}
else if(op==2)
{
scanf("%d %d %lld %lld",&l,&r,&x,&y);
lxw c;
c.m[0][0]=x;c.m[0][1]=y;
lxw tmp=query(1,l,r);
c=mult(c,tmp,2);
printf("%lld %lld\n",c.m[0][0]%MOD,c.m[0][1]%MOD);
}
}
}