思路:线段树维护当前区间二进制每一位一共有多少个1,和当前区间的所有元素的平方和。
AC代码:
#include <bits/stdc++.h>
//#define int long long
using namespace std;
const int N=4e5+5,mod=998244353;
typedef long long ll;
int n,m;
ll w[N];
struct node{
int l,r;
ll sum;
int x[30];
}tr[N*4];
void pushup(int u){
tr[u].sum=(tr[u<<1].sum+tr[u<<1|1].sum)%mod;
for(int i=0;i<=25;i++)tr[u].x[i]=tr[u<<1].x[i]+tr[u<<1|1].x[i];
}
void build(int u,int l,int r){
if(l==r){
tr[u]={l,r,w[l]*w[l]%mod};
for(int i=0;i<=25;i++)tr[u].x[i]=w[l]>>i&1;
return;
}
tr[u]={l,r};
int mid=l+r>>1;
build(u<<1,l,mid),build(u<<1|1,mid+1,r);
pushup(u);
}
void modify(int u,int l,int r,int x){
if(tr[u].l==tr[u].r){
ll sum=0;
for(int i=0;i<=25;i++)
if(tr[u].x[i]&&(x>>i&1))sum|=1LL<<i;
else tr[u].x[i]=0;
tr[u].sum=sum*sum%mod;
return;
}
if(tr[u].l>=l&&tr[u].r<=r){
int flag=1;
for(int i=0;i<=25;i++)
if(tr[u].x[i]&&(x>>i&1)==0)flag=0;
if(flag)return;
}
int mid=tr[u].l+tr[u].r>>1;
if(l<=mid) modify(u<<1,l,r,x);
if(r>mid) modify(u<<1|1,l,r,x);
pushup(u);
}
ll query(int u,int l,int r){
if(tr[u].l>=l&&tr[u].r<=r) return tr[u].sum;
int mid=tr[u].l+tr[u].r>>1;
int res=0;
if(l<=mid) res=query(u<<1,l,r)%mod;
if(r>mid) res=(res+query(u<<1|1,l,r))%mod;
return res;
}
main(){
scanf("%d",&n);
for(int i=1;i<=n;i++)scanf("%lld",&w[i]);
build(1,1,n);
scanf("%d",&m);
while(m--){
int t,l,r,x;
scanf("%d%d%d",&t,&l,&r);
if(t==1){
scanf("%d",&x);
modify(1,l,r,x);
}
else printf("%lld\n",query(1,l,r));
}
}