【题目链接】
【思路要点】
- 答案 = 总区间数量 − 不合法区间数量,因此只需统计不合法区间的个数。
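- 由于 \(2^k \equiv (-1)^k \pmod 3\),设区间中有 \(c_1\) 个 1、\(c_0\) 个 0:\(c_1\) 为偶数时总能把 1 平均分到奇数位和偶数位上,从而重排成 3 的倍数;\(c_1\) 为奇数时,奇偶位上 1 的个数之差至少为 3,需要 \(c_1 \ge 3\) 且 \(c_0 \ge 2\)。因此不合法区间分为互不相交的两种: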
- 1、区间和为1的区间。
- 2、区间和为大于 1 的奇数,且区间内 0 的个数不足两个(即至多一个 0)。
- 分别用线段树统计上述区间个数即可。
- 时间复杂度 \(O(N + Q\log N)\):建树 \(O(N)\),每次修改或询问 \(O(\log N)\)。
【代码】
#include <cctype>
#include <cstdio>
#include <vector>

namespace {

// Reads the next integer from stdin into x. Non-digit characters before the
// number are skipped; every '-' seen while skipping flips the sign.
// Returns false (leaving x == 0) if EOF is reached before any digit, instead
// of spinning forever on truncated input.
template <typename T>
bool read(T &x) {
    x = 0;
    int sign = 1;
    int c = std::getchar();  // int, not char: must be able to hold EOF
    while (c != EOF && !std::isdigit(c)) {
        if (c == '-') sign = -sign;
        c = std::getchar();
    }
    if (c == EOF) return false;
    while (c != EOF && std::isdigit(c)) {
        x = x * 10 + (c - '0');
        c = std::getchar();
    }
    x *= sign;
    return true;
}

// Summary of a segment of the 0/1 array.
//
// A subinterval is "bad" iff it contains exactly one 1, or an odd number
// (> 1) of 1s together with at most one 0 (the two disjoint cases from the
// write-up above). The answer to a query is (#all subintervals - #bad ones).
struct info {
    long long sum = 0;  // number of bad subintervals lying entirely inside
    int len = 0;        // segment length
    int l = 0, r = 0;   // length of the leading / trailing run of 1s
    // Prefixes containing exactly one 0 have lengths l+1 .. l+nl: nl is the
    // distance from the first 0 to the second 0 (or to the segment end), and
    // 0 if the segment is all 1s. nr is the mirror image for suffixes.
    int nl = 0, nr = 0;
    int zl = 0, zr = 0;    // length of the leading / trailing run of 0s
    int znl = 0, znr = 0;  // like nl / nr with the roles of 0 and 1 swapped
};

// Splits the consecutive integers base, base+1, ..., base+cnt-1 by parity.
void parity_split(int base, int cnt, int &even, int &odd) {
    const int same = (cnt + 1) / 2;  // members with the same parity as base
    const int other = cnt - same;
    if (base & 1) {
        even = other;
        odd = same;
    } else {
        even = same;
        odd = other;
    }
}

// Number of pairs (p, q), p from a set P and q from a set Q, whose sum is
// odd, given the counts of even / odd members in each set.
long long odd_sum_pairs(int p_even, int p_odd, int q_even, int q_odd) {
    return 1LL * p_even * q_odd + 1LL * p_odd * q_even;
}

// Recomputes a (run, next) pair such as (l, nl) for the concatenation of two
// segments. `first` is the child the run is read from (the left child for
// prefixes, the right child for suffixes); `second` is the other child.
void merge_run(int first_run, int first_next, int first_len,
               int second_run, int second_next, int &run, int &next) {
    if (first_next == 0) {
        // `first` consists solely of the run value: the run continues.
        run = first_run + second_run;
        next = second_next;
    } else {
        run = first_run;
        // If the block after the first "other" value reaches the end of
        // `first`, it extends through the run at the start of `second`.
        next = (first_run + first_next == first_len) ? first_next + second_run
                                                     : first_next;
    }
}

// Combines the summaries of two adjacent segments (a on the left, b on the
// right). A subinterval crossing the boundary is a non-empty suffix of a
// followed by a non-empty prefix of b.
info operator+(const info &a, const info &b) {
    info ans;
    ans.len = a.len + b.len;
    ans.sum = a.sum + b.sum;

    // Crossing intervals with exactly one 1: the 1 is in b's prefix (b.znl
    // choices) and a's suffix is all 0s (a.zr choices), or the 1 is in a's
    // suffix (a.znr choices) and b's prefix is all 0s (b.zl choices).
    ans.sum += 1LL * b.znl * a.zr;
    ans.sum += 1LL * b.zl * a.znr;

    // Crossing intervals with an odd number (> 1) of 1s and at most one 0.
    int a_ones_even, a_ones_odd;  // parities of i = #1s taken from a, 1..a.r
    int b_ones_even, b_ones_odd;  // parities of j = #1s taken from b, 1..b.l
    parity_split(1, a.r, a_ones_even, a_ones_odd);
    parity_split(1, b.l, b_ones_even, b_ones_odd);

    // No 0 at all: both sides are pure 1s; i + j >= 2 always holds.
    ans.sum += odd_sum_pairs(a_ones_even, a_ones_odd, b_ones_even, b_ones_odd);

    // The single 0 lies in b: b's prefix holds k ones, k in b.l .. b.l+b.nl-1.
    if (b.nl) {
        int k_even, k_odd;
        parity_split(b.l, b.nl, k_even, k_odd);
        ans.sum += odd_sum_pairs(a_ones_even, a_ones_odd, k_even, k_odd);
        // (i, k) = (1, 0) has exactly one 1 -- that is the first kind of bad
        // interval, already counted above, not this kind.
        if (b.l == 0 && a.r != 0) --ans.sum;
    }

    // The single 0 lies in a: a's suffix holds k ones, k in a.r .. a.r+a.nr-1.
    if (a.nr) {
        int k_even, k_odd;
        parity_split(a.r, a.nr, k_even, k_odd);
        ans.sum += odd_sum_pairs(k_even, k_odd, b_ones_even, b_ones_odd);
        // Same exclusion for (k, j) = (0, 1).
        if (a.r == 0 && b.l != 0) --ans.sum;
    }

    merge_run(a.l, a.nl, a.len, b.l, b.nl, ans.l, ans.nl);
    merge_run(b.r, b.nr, b.len, a.r, a.nr, ans.r, ans.nr);
    merge_run(a.zl, a.znl, a.len, b.zl, b.znl, ans.zl, ans.znl);
    merge_run(b.zr, b.znr, b.len, a.zr, a.znr, ans.zr, ans.znr);
    return ans;
}

// Segment tree over positions 1..n storing an `info` per node.
struct SegmentTree {
    struct Node {
        int lc = -1, rc = -1;
        info ans;
    };
    // Node pool sized from the input (the old fixed array overflowed for
    // n > 100002).
    std::vector<Node> a;
    int root = -1, n = 0;

    // Summary of a single element x (0 or 1).
    static info leaf(int x) {
        info s;
        s.sum = x;  // [1] has exactly one 1 -> bad; [0] is not bad
        s.len = 1;
        s.l = s.r = x;
        s.nl = s.nr = !x;
        s.zl = s.zr = !x;
        s.znl = s.znr = x;
        return s;
    }

    // Builds the subtree for [l, r], reading the leaf values from stdin from
    // left to right, and returns the index of its root node. Indices are
    // returned (not written through references) so that vector growth can
    // never leave a dangling reference.
    int build(int l, int r) {
        const int id = static_cast<int>(a.size());
        a.emplace_back();
        if (l == r) {
            int x;
            read(x);
            a[id].ans = leaf(x);
            return id;
        }
        const int mid = (l + r) / 2;
        const int lc = build(l, mid);  // left half first: preserves input order
        const int rc = build(mid + 1, r);
        a[id].lc = lc;
        a[id].rc = rc;
        a[id].ans = a[lc].ans + a[rc].ans;
        return id;
    }

    // Reads x values from stdin and builds the tree over them.
    void init(int x) {
        n = x;
        a.clear();
        root = -1;
        if (n < 1) return;  // nothing to build (avoids unbounded recursion)
        a.reserve(2 * static_cast<std::size_t>(n));  // 2n-1 nodes suffice
        root = build(1, n);
    }

    // Flips the bit at position pos within the subtree `id` covering [l, r].
    void modify(int id, int l, int r, int pos) {
        if (l == r) {
            a[id].ans = leaf(!a[id].ans.sum);
            return;
        }
        const int mid = (l + r) / 2;
        if (pos <= mid)
            modify(a[id].lc, l, mid, pos);
        else
            modify(a[id].rc, mid + 1, r, pos);
        a[id].ans = a[a[id].lc].ans + a[a[id].rc].ans;
    }

    // Flips the bit at position pos (1-based).
    void modify(int pos) {
        if (root < 0) return;
        modify(root, 1, n, pos);
    }

    // Summary of [ql, qr], which must lie within the node's range [l, r].
    info query(int id, int l, int r, int ql, int qr) const {
        if (l == ql && r == qr) return a[id].ans;
        const int mid = (l + r) / 2;
        if (qr <= mid) return query(a[id].lc, l, mid, ql, qr);
        if (ql > mid) return query(a[id].rc, mid + 1, r, ql, qr);
        return query(a[id].lc, l, mid, ql, mid) +
               query(a[id].rc, mid + 1, r, mid + 1, qr);
    }

    // Number of bad subintervals of [l, r]; 0 for an empty range.
    long long query(int l, int r) const {
        if (root < 0 || l > r) return 0;
        return query(root, 1, n, l, r).sum;
    }
};

}  // namespace

// Input: n, the n bits, m, then m operations:
//   "1 x"    flips bit x;
//   other l r  prints the number of non-bad subintervals of [l, r].
int main() {
    SegmentTree st;
    int n;
    if (!read(n)) return 0;
    st.init(n);
    int m;
    if (!read(m)) return 0;
    for (int i = 0; i < m; ++i) {
        int opt;
        if (!read(opt)) break;
        if (opt == 1) {
            int x;
            if (!read(x)) break;
            st.modify(x);
        } else {
            int l, r;
            if (!read(l) || !read(r)) break;
            // All subintervals of [l, r] minus the bad ones.
            const long long len = (l <= r) ? r - l + 1LL : 0LL;
            std::printf("%lld\n", len * (len + 1) / 2 - st.query(l, r));
        }
    }
    return 0;
}