此题巧妙的地方有两点,一点是利用线段树维护区间异或和的时候,直接判断区间长度的奇偶性,为奇数就异或给定值。另一点就是在删除的时候,巧妙的利用插入函数先插入再删除,这样一来删除就变得简单得多。
这道题对代码能力要求较高,笔者来来回回折腾了近四个小时才调出来,果然对代码能力的训练是一个永恒的话题。
代码如下
#include<bits/stdc++.h>
using namespace std;
#define lson k<<1
#define rson k<<1|1
const int maxn=1e5+5;
const int inf=0x3f3f3f3f;
map<int,int> mp;
set<pair<int,int>> st[maxn];
int cnt=0;
int n,m;
struct node
{
int siz;
int sum;
int lazy;
}tree[maxn<<2];
void work(int k,int v)
{
if(tree[k].siz&1)
tree[k].sum^=v;
tree[k].lazy^=v;
}
void pushup(int k)
{
tree[k].sum=tree[lson].sum^tree[rson].sum;
}
void pushdown(int k)
{
if(tree[k].lazy)
{
work(lson,tree[k].lazy);
work(rson,tree[k].lazy);
tree[k].lazy=0;
}
}
void build(int k,int l,int r)
{
tree[k].siz=r-l+1;
if(l==r)
return ;
int mid=l+r>>1;
build(lson,l,mid);
build(rson,mid+1,r);
}
void update(int k,int l,int r,int L,int R,int x)
{
if(L<=l&&r<=R)
{
work(k,x);
return ;
}
pushdown(k);
int mid=l+r>>1;
if(L<=mid)
update(lson,l,mid,L,R,x);
if(R>mid)
update(rson,mid+1,r,L,R,x);
pushup(k);
}
int query(int k,int l,int r,int L,int R)
{
if(L<=l&&r<=R)
{
return tree[k].sum;
}
pushdown(k);
int ans=0;
int mid=l+r>>1;
if(L<=mid)
ans^=query(lson,l,mid,L,R);
if(R>mid)
ans^=query(rson,mid+1,r,L,R);
return ans;
}
void add(int l,int r,int x)
{
set<pair<int,int>> &s=st[mp[x]];
auto it=s.lower_bound(make_pair(l,r));
if(it!=s.begin())
{
it--;
if(it->second>r)
return ;
if(it->second>=l)
{
l=min(l,it->first);
r=max(r,it->second);
update(1,1,n,it->first,it->second,x);
s.erase(it);
}
}
it=s.lower_bound(make_pair(l,r));
while((it!=s.end())&&((it->second>=l&&it->second<=r)||(it->first>=l&&it->first<=r)))
{
l=min(l,it->first);
r=max(r,it->second);
update(1,1,n,it->first,it->second,x);
s.erase(it++);
}
update(1,1,n,l,r,x);
s.insert(make_pair(l,r));
}
void del(int l,int r,int x)
{
if(!mp[x]) return ;
set<pair<int,int>> &s=st[mp[x]];
add(l,r,x);
update(1,1,n,l,r,x);
auto it=s.lower_bound(make_pair(l,r));
if(it==s.end()||it->first>l)
it--;
if(it->first<l)
s.insert(make_pair(it->first,l-1));
if(it->second>r)
s.insert(make_pair(r+1,it->second));
s.erase(it);
}
int main()
{
cin>>n>>m;
build(1,1,n);
while(m--)
{
int op,l,r,x;
cin>>op>>l>>r>>x;
if(op==1)
{
if(!mp[x])
mp[x]=++cnt;
add(l,r,x);
}
else if(op==2)
del(l,r,x);
else printf("%d\n",query(1,1,n,l,r));
}
}