链接:题目
题目大意:给出n个整数.有三种操作:
1.询问[l r]的总和
2.[l r]里面所有数减去x&(-x)也就是去掉最低位
3.[l r]里面所有数加上2^k( 2^k<=x<= 2^k+1),也就是说加上一个二进制的最高位,比如1010+1000=10010,观察不免发现是最高位左移一位.
思路:
操作2就是去掉最低位,操作3就是最高位左移,那么我们就可以将一个数字拆成两部分 一部分是a1也就是最高位代表的数字,一部分是a2也就是剩下位代表的数字 (如:10110,那么a1=10000,a2=110) 我们可以写一个线段树去维护这两部分 而且容易知道a1这部分的维护是可以用lazy标记的,因为修改[l r]这一段相当于[l r]这个区间的所有a1都左移一位也就是乘2 区间乘法可以用lazy标记求和 其他的就和普通的写法一样了
#include<iostream>
#define LL long long
using namespace std;
const LL mod = 998244353;
const int N = 100005;
LL a1[N], a2[N];
int T, n, q;
LL lowbit(LL x) { return x & (-x); }
struct SegmentTree1 {
LL l, r, sum1, sum2, la, tg;
#define ls x<<1
#define rs x<<1|1
#define l(x) tree[x].l
#define r(x) tree[x].r
#define sum1(x) tree[x].sum1
#define sum2(x) tree[x].sum2
#define la(x) tree[x].la
#define tg(x) tree[x].tg
}tree[N << 2];
//向上是线段树合并
void pushup(int x) {
sum1(x) = (sum1(ls) + sum1(rs)) % mod;
sum2(x) = (sum2(ls) + sum2(rs)) % mod;
tg(x) = tg(ls) & tg(rs);//如果两个区间都是1的话就不继续递归下去了
}
void pushdown(int x) {
la(ls) = la(ls) * la(x) % mod;
la(rs) = la(rs) * la(x) % mod;
sum1(ls) = sum1(ls) * la(x) % mod;
sum1(rs) = sum1(rs) * la(x) % mod;
tg(ls) |= tg(x);
tg(rs) |= tg(x);
if (tg(ls))sum2(ls) = 0;
if (tg(rs))sum2(rs) = 0;
la(x) = 1;//标记
}
void build(int x,int l,int r) {
la(x) = 1, tg(x) = 0;
l(x)=l,r(x)=r;
if (l(x) == r(x)) {
sum1(x) = a1[l(x)], sum2(x) = a2[l(x)];
return;
}
int mid=(l+r)>>1;
build(ls,l,mid);
build(rs,mid+1,r);
pushup(x);
}
LL query(int x, int l, int r) {
if (l <= l(x) && r >= r(x)) {
return (sum1(x) + sum2(x)) % mod;
}
pushdown(x);
int mid = (l(x) + r(x)) >> 1;
LL ans = 0;
if (l <= mid)ans += query(ls, l, r);
if (r > mid)ans += query(rs, l, r);
ans%=mod;
return ans;
}
void up1(int x, int l, int r) {
if (l(x) == r(x)) {
if (sum2(x)) {
sum2(x) -= lowbit(sum2(x));
}
else {
sum1(x) = 0;
tg(x) = 1;
}
return;
}
pushdown(x);
int mid = (l(x) + r(x)) >> 1;
if (l <= mid && !tg(ls))up1(ls, l, r);
if (r > mid && !tg(rs))up1(rs, l, r);
pushup(x);
}
void up2(int x, int l, int r) {
if (l <= l(x) && r >= r(x)) {
sum1(x) = sum1(x) * 2 % mod;
la(x) = la(x) * 2 % mod;
return;
}
pushdown(x);
int mid = (l(x) + r(x)) >> 1;
if (l <= mid)up2(ls, l, r);
if (r > mid)up2(rs, l, r);
pushup(x);
}
int main() {
int T;
scanf("%d", &T);
while (T--) {
scanf("%d", &n);
for (int i = 1; i <= n; i++) {
LL x;
scanf("%lld", &x);
for (int k = 30; k >= 0; k--) {
if ((1ll << k) <= x) {
a1[i] = (1ll << k);
a2[i] = x - a1[i];
break;
}
}
}
build(1,1,n);
scanf("%d", &q);
while (q--) {
int opt, l, r;
scanf("%d%d%d", &opt, &l, &r);
if (opt == 1) {
printf("%lld\n", query(1, l, r));
}
else if (opt == 2) {
up1(1, l, r);
}
else if (opt == 3) {
up2(1, l, r);
}
}
}
return 0;
}