按二进制位分别维护的线段树（分位线段树）：支持区间异或 / 或 / 与修改，以及区间求和。
懒标记（laz）的理解：当修改区间恰好覆盖当前节点时，只在该节点上打标记（并更新该节点自身的统计值），与它下面的子树无关。只有在下放（push_down）时才需要同时考虑当前节点和子树：此时当前节点的标记还没有作用到它的子孙，只需把这个影响下放给两个儿子，让它与儿子身上尚未下放的标记合并（复合）即可。本质上，laz 记录的就是"尚未下放的修改"；把当前节点的修改下放到子树之后，自然要把当前节点的标记清零。由于同一次修改对一个节点所覆盖区间内的每个位置都是相同的，所以只需要考虑多个修改先后叠加（复合）后的效果即可。
#include <cstdio>
// lc / rc expand to the (l, r, x) argument triple of the left / right child.
// They expect locals named l, r, mid and x to be in scope at the use site.
#define lc l,mid,x<<1
#define rc mid+1,r,x<<1|1
using namespace std;
using LL = long long;

const int maxn = 100005;   // capacity of a[]; the tree arrays hold 4 * maxn nodes
const int base = 20;       // highest bit index tracked; bits 0..base

// tree[x][i]: number of positions in node x's range whose bit i is 1.
int tree[4*maxn][base+1];
// Pending per-bit tags of node x, not yet pushed to its children:
// laz0 = force the bit to 0, laz1 = force it to 1, laz2 = flip it.
int laz0[4*maxn][base+1];
int laz1[4*maxn][base+1];
int laz2[4*maxn][base+1];
// a[1..n]: the input values.
int a[maxn];
// Decompose x into its bits 0..base: bits[i] receives bit i of x (0 or 1).
void cal( int* bits,int x ){
    int bit = 0;
    while( bit <= base ){
        bits[bit] = (x >> bit) & 1;
        ++bit;
    }
}
// Rebuild node x's per-bit 1-counts as the sum of its two children's counts.
void push_up( int x ){
    const int* lhs = tree[x << 1];
    const int* rhs = tree[x << 1 | 1];
    for( int bit = base; bit >= 0; --bit )
        tree[x][bit] = lhs[bit] + rhs[bit];
}
// Push node x's pending per-bit tags (x covers [l, r]) down to its two children.
// Tags: laz0 = force bit to 0, laz1 = force bit to 1, laz2 = flip bit. They are
// checked in that priority order, and only the first one set is acted on.
// For every bit, the parent's tag is composed into each child's tags, the
// child's 1-count is updated to reflect it, and the parent's tags are cleared.
void push_down( int x,int l,int r ){
    const int mid = (l + r) >> 1;
    const int child[2] = { x << 1, x << 1 | 1 };
    const int len[2] = { mid - l + 1, r - mid };
    for( int bit = 0; bit <= base; ++bit ){
        const bool toZero = laz0[x][bit] != 0;
        const bool toOne = !toZero && laz1[x][bit] != 0;
        const bool toFlip = !toZero && !toOne && laz2[x][bit] != 0;
        for( int k = 0; k < 2; ++k ){
            const int c = child[k];
            if( toZero ){
                laz0[c][bit] = 1; laz1[c][bit] = 0; laz2[c][bit] = 0;
                tree[c][bit] = 0;
            }else if( toOne ){
                laz0[c][bit] = 0; laz1[c][bit] = 1; laz2[c][bit] = 0;
                tree[c][bit] = len[k];
            }else if( toFlip ){
                // A flip on top of a pending force-0/force-1 swaps those two flags.
                // Otherwise it toggles the child's own pending flip.
                if( laz0[c][bit] || laz1[c][bit] ){
                    laz0[c][bit] ^= 1; laz1[c][bit] ^= 1; laz2[c][bit] = 0;
                }else{
                    laz2[c][bit] ^= 1;
                }
                tree[c][bit] = len[k] - tree[c][bit];
            }
        }
        laz0[x][bit] = laz1[x][bit] = laz2[x][bit] = 0;
    }
}
// Build the subtree rooted at node x, which covers [l, r], from a[l..r].
// A leaf stores the bits of its single value; an inner node sums its children.
void build( int l,int r,int x ){
    if( l != r ){
        const int mid = (l + r) >> 1;
        build( l,mid,x << 1 );
        build( mid + 1,r,x << 1 | 1 );
        push_up( x );
        return;
    }
    cal( tree[x],a[l] );
}
// Record operation op with operand v as pending tags on node x (tags only;
// the node's counts are updated by calc2).
//   op 2 (XOR): every bit set in v gets a flip composed onto its tags.
//   op 3 (OR):  every bit set in v is forced to 1.
//   op 4 (AND): every bit clear in v is forced to 0.
void solve( int v,int x,int op ){
    for( int bit = 0; bit <= base; ++bit ){
        const int b = (v >> bit) & 1;
        int& zero = laz0[x][bit];
        int& one = laz1[x][bit];
        int& flip = laz2[x][bit];
        switch( op ){
        case 2:
            if( !b ) break;
            if( zero || one ){
                zero ^= 1; one ^= 1; flip = 0;
            }else{
                flip ^= 1;
            }
            break;
        case 3:
            if( b ){ zero = 0; one = 1; flip = 0; }
            break;
        case 4:
            if( !b ){ zero = 1; one = 0; flip = 0; }
            break;
        default:
            break;
        }
    }
}
// Apply operation op (operand v) directly to node x's per-bit 1-counts.
// sz is the number of positions the node covers.
//   op 2: a bit set in v flips its count to sz - count.
//   op 3: a bit set in v becomes all ones (sz).
//   op 4: a bit clear in v becomes all zeros.
void calc2( int v,int x,int op,int sz ){
    int* cnt = tree[x];
    for( int bit = 0; bit <= base; ++bit ){
        const int b = (v >> bit) & 1;
        switch( op ){
        case 2: if( b ) cnt[bit] = sz - cnt[bit]; break;
        case 3: if( b ) cnt[bit] = sz; break;
        case 4: if( !b ) cnt[bit] = 0; break;
        default: break;
        }
    }
}
// Apply operation op with operand v to every position in [left, right].
// Node x covers [l, r]. A fully covered node is tagged and has its counts
// updated in place; otherwise its tags are pushed down and both halves recurse.
void update( int left,int right ,int op,int v,int l,int r,int x ){
    const bool covered = left <= l && right >= r;
    if( covered ){
        solve( v,x,op );
        calc2( v,x,op,r - l + 1 );
        return;
    }
    push_down( x,l,r );
    const int mid = (l + r) >> 1;
    if( left <= mid ) update( left,right,op,v,l,mid,x << 1 );
    if( right > mid ) update( left,right,op,v,mid + 1,r,x << 1 | 1 );
    push_up( x );
}
// Reconstruct the sum of node x's range from its per-bit 1-counts:
// the sum over bits i of count_i * 2^i.
LL cal( int x ){
    LL total = 0;
    LL weight = 1;
    for( int bit = 0; bit <= base; ++bit, weight <<= 1 )
        total += weight * tree[x][bit];
    return total;
}
// Sum of the values in [left, right]; node x covers [l, r].
// Pending tags on x are pushed down before descending.
LL query( int left,int right,int l,int r,int x ){
    if( left <= l && right >= r ) return cal( x );
    push_down( x,l,r );
    const int mid = (l + r) >> 1;
    const LL fromLeft = left <= mid ? query( left,right,l,mid,x << 1 ) : 0;
    const LL fromRight = right > mid ? query( left,right,mid + 1,r,x << 1 | 1 ) : 0;
    push_up( x );
    return fromLeft + fromRight;
}
int main(){
int n,m,op,l,r;
scanf("%d",&n);
for( int i = 1;i <= n;i++ ) scanf("%d",&a[i]);
scanf("%d",&m);
build( 1,n,1 );
int v;
for( int i = 1;i <= m;i++ ){
scanf("%d%d%d",&op,&l,&r);
if( op == 1 ){
LL res = query(l,r,1,n,1);
printf("%I64d\n",res);
}else if( op == 2 ){
scanf("%d",&v);
update( l,r,2,v,1,n,1 );
}else if( op == 3 ){
scanf("%d",&v);
update( l,r,3,v,1,n,1 );
}else{
scanf("%d",&v);
update( l,r,4,v,1,n,1 );
}
}
return 0;
}
以下是隔壁队写的版本，值得学习。首先，他的懒标记状态定义得比我的好：每个节点的每一位只用一个 lz 值表示当前节点尚未下放的修改（0 = 无，1 = 取反，2 = 置 0，3 = 置 1），不需要三个数组互相配合。
其次是命名习惯：宏 ls / rs 展开为左 / 右儿子的整组递归参数（如 `#define ls u << 1, l, mid`），而 int 变量 lc / rc 只表示儿子的下标（如 `int lc = u << 1`），两者用名字区分开。
#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <cstdio>
#include <iostream>
// ls / rs expand to the (node, l, r) argument triple of the left / right child.
// They expect locals named u, l, r and mid to be in scope at the use site.
#define ls u << 1, l, mid
#define rs u << 1 | 1, mid + 1, r
using namespace std;
// A typed alias and constant instead of object-like macros, so they obey scope
// and type checking and cannot textually rewrite unrelated identifiers.
typedef long long LL;
const int maxn = 400021;   // bound on segment-tree node indices; also sizes a[]
const int kMaxBit = 20;    // bits 0..kMaxBit are tracked per node

// lz[u][i]: pending tag of bit i at node u (0 none, 1 flip, 2 force 0, 3 force 1).
int lz[maxn][kMaxBit + 1];
// sum[u][i][b]: number of positions under node u whose bit i equals b.
int sum[maxn][kMaxBit + 1][2];
// ans[i]: per-bit 1-counts accumulated by query().
// tmp[i]: bit i of the current update operand.
int ans[kMaxBit + 1], tmp[kMaxBit + 1];
int n, m;
int a[maxn];   // a[1..n]: input values
// Recompute node u's per-bit 0-counts and 1-counts from its two children.
void push_up(int u){
    const int kidL = u << 1, kidR = u << 1 | 1;
    for(int bit = 0; bit <= 20; ++bit)
        for(int b = 0; b < 2; ++b)
            sum[u][bit][b] = sum[kidL][bit][b] + sum[kidR][bit][b];
}
void push_down(int u){
int lc = u << 1, rc = u << 1 | 1;
for(int i = 0; i<= 20; i++){
if(!lz[u][i])continue;
if(lz[u][i] == 1){
swap(sum[lc][i][0], sum[lc][i][1]);
swap(sum[rc][i][0], sum[rc][i][1]);
if(lz[lc][i] == 1 || lz[lc][i] == 0)lz[lc][i] = !lz[lc][i];
else if(lz[lc][i] == 2)lz[lc][i] = 3;
else if(lz[lc][i] == 3)lz[lc][i] = 2;
if(lz[rc][i] == 1 || lz[rc][i] == 0)lz[rc][i] = !lz[rc][i];
else if(lz[rc][i] == 2)lz[rc][i] = 3;
else if(lz[rc][i] == 3)lz[rc][i] = 2;
}else if(lz[u][i] == 2){
sum[lc][i][0] += sum[lc][i][1];sum[lc][i][1]=0;
sum[rc][i][0] += sum[rc][i][1];sum[rc][i][1]=0;
lz[lc][i] = lz[rc][i] = 2;
}else if(lz[u][i] == 3){
sum[lc][i][1] += sum[lc][i][0];sum[lc][i][0]=0;
sum[rc][i][1] += sum[rc][i][0];sum[rc][i][0]=0;
lz[lc][i] = lz[rc][i] = 3;
}
lz[u][i] = 0;
}
}
// Apply mode `pos` to every position in [x, y]; node u covers [l, r]
// (assumes l <= x <= y <= r). tmp[i] selects the affected bits:
//   1:     flip bit i where tmp[i] is set (XOR)
//   2:     force bit i to 0 where tmp[i] != 1 (AND)
//   other: force bit i to 1 where tmp[i] != 0 (OR)
void update(int u, int l, int r, int x, int y, int pos){
    if(l != x || r != y){
        int mid = (l + r) >> 1;
        push_down(u);
        const bool goLeft = x <= mid, goRight = y > mid;
        if(goLeft && goRight){
            update(ls, x, mid, pos);
            update(rs, mid + 1, y, pos);
        }else if(goLeft){
            update(ls, x, y, pos);
        }else{
            update(rs, x, y, pos);
        }
        push_up(u);
        return;
    }
    for(int bit = 0; bit <= 20; ++bit){
        int &tag = lz[u][bit];
        int *cnt = sum[u][bit];
        if(pos == 1){
            if(!tmp[bit]) continue;
            swap(cnt[0], cnt[1]);
            // Compose a flip onto the existing tag: 0<->1, 2<->3; other values stay.
            if(tag >= 0 && tag <= 3) tag ^= 1;
        }else if(pos == 2){
            if(tmp[bit] == 1) continue;
            cnt[0] += cnt[1];
            cnt[1] = 0;
            tag = 2;
        }else{
            if(tmp[bit] == 0) continue;
            cnt[1] += cnt[0];
            cnt[0] = 0;
            tag = 3;
        }
    }
}
// Add to ans[i] the number of 1s in bit i over [x, y]; node u covers [l, r]
// (assumes l <= x <= y <= r). This only adds to ans[]; the caller resets it.
void query(int u, int l, int r, int x, int y){
    if(l != x || r != y){
        int mid = (l + r) >> 1;
        push_down(u);
        const bool goLeft = x <= mid, goRight = y > mid;
        if(goLeft && goRight){
            query(ls, x, mid);
            query(rs, mid + 1, y);
        }else if(goLeft){
            query(ls, x, y);
        }else{
            query(rs, x, y);
        }
        return;
    }
    for(int bit = 0; bit <= 20; ++bit) ans[bit] += sum[u][bit][1];
}
// Build the subtree of node u covering [l, r] from a[l..r].
// A leaf marks, for each bit, the count matching its value's bit as 1;
// it relies on sum[][][] starting zeroed.
void build(int u, int l, int r){
    if(l != r){
        int mid = (l + r) >> 1;
        build(ls);
        build(rs);
        push_up(u);
        return;
    }
    for(int bit = 0; bit <= 20; ++bit){
        const int b = (a[l] >> bit) & 1;
        sum[u][bit][b] = 1;
    }
}
int main(){
scanf("%d", &n);
for(int i = 1; i <= n; i++)scanf("%d", a + i);
build(1, 1, n);
scanf("%d", &m);
int pos, x, y, z;
while(m--){
scanf("%d%d%d", &pos, &x, &y);
if(pos == 1){
memset(ans, 0, sizeof(ans));
query(1, 1, n, x, y);
LL all = 0;
for(int i = 0; i <= 20; i++){
all += (LL)ans[i] * (1ll << i);
}
printf("%lld\n", all);
}else{
scanf("%d", &z);
for(int i = 0; i <= 20; i++)tmp[i] = (z >> i) & 1;
if(pos == 2)pos = 1;
else if(pos == 3)pos = 3;
else if(pos == 4)pos = 2;
update(1, 1, n, x, y, pos);
}
}
return 0;
}