//线段树
个人总结:想清楚怎么更新和下传
查询和更新采用同样的想法,在这里推荐用结构体写
这样代码少又能保证正确;(具体可以看GSS1中的两份代码比较);
题目来源:
POJ 2777
题目描述:
解题思路:
用线段树来维护一个区间有哪几种颜色
考虑到颜色总数只有 30 个,可以用二进制位来表示颜色
每个区间维护一个 mask,若 mask 的第 i 位为 1,就表示
这个区间里有第 i 种颜色。
查询的时候得到这个区间里颜色的二进制串,统计有几个一
即可。
#include<bits/stdc++.h>
using namespace std;
#define ls (x<<1)
#define rs (x<<1|1)
const int N = 1e5 + 10;
int mask[N << 2], tag[N << 2];
void update(int x) {
mask[x] = mask[ls] | mask[rs];
}
void build(int l, int r, int x) {
if (l == r) {
mask[x] = 1;
return;
}
int mid = (l + r) >> 1;
build(l, mid, ls);
build(mid + 1, r, rs);
update(x);
}
void down(int l, int r, int x) {
if (tag[x] != -1) {
tag[ls] = tag[rs] = tag[x];
mask[ls] = mask[rs] = mask[x];
tag[x] = -1;
}
}
void change(int A, int B, int t, int l, int r, int x) {
if (A <= l && r <= B) {
mask[x] = 1 << t;
tag[x] = t;
return;
}
down(l, r, x);
int mid = (l + r) >> 1;
if (A <= mid) change(A, B, t, l, mid, ls);
if (mid < B) change(A, B, t, mid + 1, r, rs);
update(x);
}
int query(int A, int B, int l, int r, int x) {
if (A <= l && r <= B) return mask[x];
down(l, r, x);
int mid = (l + r) >> 1, ret = 0;
if (A <= mid) ret |= query(A, B, l, mid, ls);
if (mid < B) ret |= query(A, B, mid + 1, r, rs);
return ret;
}
int main() {
int n, m, o;
while (scanf("%d%d%d", &n, &m, &o) == 3) {
memset(tag, -1, sizeof(tag));
build(1, n, 1);
char q[2];
int l, r, x;
while (o--) {
//test(1,n,1);
scanf("%s%d%d", q, &l, &r);
if (l > r)swap(l, r);
if (q[0] == 'C') {
scanf("%d", &x);
x--;
change(l, r, x, 1, n, 1);
}
else {
int ans = query(l, r, 1, n, 1);
int cnt = 0;
while (ans) {
if (ans & 1)cnt++;
ans = ans >> 1;
}
printf("%d\n", cnt);
}
}
}
return 0;
}
题目来源:
SP1043 GSS1 - Can you answer these queries I
题目描述:
你有一个长度为 n 的序列 A[1], A[2], …, A[N].
询问:
Query(x, y) = max { A[i] + … + A[j]; x <= i <= j <= y}
给出 M 组 (x, y),请给出 M 次询问的答案。
|A[i]| <= 15007, 1 <= N,M <= 50000
解题思路:
考虑用线段树来维护每个区间的答案。
假设 smax[x] 表示区间 x 的答案。
那么 smax[x] 如何由 smax[ls] 和 smax[rs] 合并得来呢?
不能直接合并。
我们还需要记录每个区间的前缀和最大值 lmax[x] 和后缀
和最大值 rmax[x]。
此时: smax[x] = max(max(smax[ls], smax[rs]), rmax[ls]
- lmax[rs])
- 为了更新 lmax 和 rmax,我们还需要记录每个区间的区间
和 sum。
这样就有:
lmax[x] = max(lmax[ls], sum[ls] + lmax[rs])
rmax[x] = max(rmax[rs], sum[rs] + rmax[ls])
SPOJ GSS1 Can you answer these qu
参考代码:
#include<bits/stdc++.h>
using namespace std;
const int N = 5e4 + 10;
const int inf = 0x3f3f3f3f;
typedef long long ll;
#define ls (x<<1)
#define rs (x<<1|1)
ll lmax[N << 2], rmax[N << 2], smax[N << 2],sum[N<<2],a[N];
int n, m;
void upd(ll x) {
smax[x] = max(max(smax[ls], smax[rs]), rmax[ls] + lmax[rs]);
rmax[x] = max(rmax[rs], sum[rs] + rmax[ls]);
lmax[x] = max(lmax[ls], sum[ls] + lmax[rs]);
sum[x] = sum[ls] + sum[rs];
}
void build(ll l, ll r, ll x) {
if (l == r) {
sum[x]=smax[x] = lmax[x] = rmax[x] = a[l];
return;
}
ll mid = (l + r) >> 1;
build(l, mid, ls);
build(mid + 1, r, rs);
upd(x);
}
void ask(ll A, ll B, ll l, ll r, ll x, ll& S, ll& L, ll& R,ll&SUM) {
if (A <= l && r <= B) {
S = smax[x];
L = lmax[x];
R = rmax[x];
SUM = sum[x];
return;
}
ll mid = (l + r) >> 1;
ll SSl, LLl, RRl,SUMl;//左孩子
ll SSr, LLr, RRr,SUMr;//右孩子
SSl = LLl = RRl = SSr = LLr = RRr = -inf;
SUMl = SUMr = 0;
if (A <= mid)ask(A, B, l, mid, ls, SSl, LLl, RRl,SUMl);
if (mid < B)ask(A, B, mid + 1, r, rs, SSr, LLr, RRr,SUMr);
L = max(LLl, SUMl + LLr);
R = max(RRr, SUMr + RRl);
S = max(max(SSl, SSr), RRl + LLr);
SUM = SUMl + SUMr;
}
int main() {
scanf("%d", &n);
for (int i = 1; i <= n; i++)scanf("%lld", &a[i]);
build(1, n, 1);
scanf("%d", &m);
while (m--) {
int l, r;
scanf("%d%d", &l, &r);
if (l > r)swap(l, r);
ll SS, LL, RR,SUM;
ask(l, r, 1, n, 1, SS, LL, RR,SUM);
printf("%lld\n",SS);
}
return 0;
}
参考代码:
#include<bits/stdc++.h>
using namespace std;
const int N = 3e5 + 10;
typedef long long ll;
#define ls (x<<1)
#define rs (x<<1|1)
const int inf = 0x3f3f3f3f;
struct NODE
{
ll sum, lmax, rmax, smax;
NODE() { sum = 0; lmax = rmax = smax = 0; }
void set(ll a) { sum = lmax = rmax = smax = a; }
void init() { sum = 0; lmax = rmax = smax = -inf; }
//sum: 【x,x+1,x+2,...,y】的和。
//区间的前缀和最大值 lmax[x] 和后缀和最大值 rmax[x]
//区间答案smax
NODE friend operator + (NODE a, NODE b) {
NODE res;
res.sum = a.sum + b.sum;
res.smax = max(max(a.smax, b.smax), a.rmax + b.lmax);
res.lmax = max(a.sum + b.lmax, a.lmax);
res.rmax = max(b.sum + a.rmax, b.rmax);
return res;
}
}e[N << 2];
int a[N];
void upd(ll x) {
e[x] = e[ls] + e[rs];
}
void build(ll l, ll r, ll x) {
if (l == r) {
e[x].set(a[l]);
return;
}
ll mid = (l + r) >> 1;
build(l, mid, ls);
build(mid + 1, r, rs);
upd(x);
}
//在这里询问可以采用返回的形式或者是引用的形式
void ask(ll A, ll B, ll l, ll r, ll x, NODE& ans) {
if (A <= l && r <= B) { ans = e[x]; return; }
ll mid = (l + r) >> 1;
NODE resa, resb;
resa.init(); resb.init();
if (A <= mid)ask(A, B, l, mid, ls, resa);
if (mid < B)ask(A, B, mid + 1, r, rs, resb);
ans = resa + resb;
}
NODE ask(ll A, ll B, ll l, ll r, ll x) {
if (A <= l && r <= B) { return e[x]; }
ll mid = (l + r) >> 1;
NODE res; res.init();
if (A <= mid)res=res+ask(A, B, l, mid, ls);
if (mid < B)res=res+ask(A, B, mid + 1, r, rs);
return res;
}
int main() {
int n; cin >> n;
for (int i = 1; i <= n; i++)cin >> a[i];
build(1, n, 1);
int m; cin >> m;
while (m--) {
int l, r; cin >> l >> r;
if (l > r)swap(l, r);
NODE ans;
ask(l, r, 1, n, 1, ans);
cout << ans.smax << endl;
}
/*
while (m--) {
int l, r; cin >> l >> r;
if (l > r)swap(l, r);
NODE ans=ask(l, r, 1, n, 1);
cout << ans.smax << endl;
}
*/
return 0;
}
线段树解决离线询问
Turing Tree(HDU3333)
思路分析
思考直接用线段树来维护。
无法合并左右两个子节点。
因为不能统计每个节点出现了哪几种数
换一个思路,思考暴力做法。
枚举 x 到 y 中间的每个数,如果是重复的,那么就不加进
答案,否则加入。
如何判断是否重复?
我们可以用一个 flag 数组,表示这个数已经出现过了
for (int i = x;i <= y;i ++){
if (!flag[a[i]]) ans += a[i];
flag[a[i]] = 1;
}
还有一种思路,记录一个数组 left[i],表示左边第一个出
现的相同数字 a[i] 的下标。
这样如果 left[i] < l,就说明 a[i] 是 [x, y] 中第一个
出现的 a[i].
如果 left[i] >= l,就说明在 [x, i - 1] 中已经出现过
一次 a[i]了,不用累计进答案
for (int i = x;i <= y;i ++)
if (left[i] < x) ans += a[i];
预处理 left 数组:
STL中的map 或
离散化
从小到大枚举 x,把当前 left[i] < x 的 a[i] 插入线段
树中的第 i 个位置。
表明:对于任意的 y,询问 [x, y] 的话,如果 i 在 [x,
y]中,那一定有 left[i] < x,所以 a[i] 是要被统计进答
案的。
这样就做到了 O((n + q) log n) 的时间复杂度,离线解决
了所有询问
#include<bits/stdc++.h>
using namespace std;
const int N = 3e5 + 10;
typedef long long ll;
#define ls (x<<1)
#define rs (x<<1|1)
ll sum[N << 2], a[N];
ll ans[N];
struct event {
ll l, r, id;
void set(ll a, ll b, ll c) { l = a; r = b; id = c; }
void show() { cout << "l::" << l << "r::" << r << "id::" << id << endl; }
}p[N], q[N];
void upd(ll x) {
sum[x] = sum[ls] + sum[rs];
}
void build(ll l, ll r, ll x) {
sum[x] = 0;
if (l == r)return;
ll mid = (l + r) >> 1;
build(l, mid, ls);
build(mid + 1, r, rs);
upd(x);
}
void modify(ll pos, ll v, ll l, ll r, ll x) {
if (l == r) {
sum[x] += v;
return;
}
ll mid = (l + r) >> 1;
if (pos <= mid)modify(pos, v, l, mid, ls);
else modify(pos, v, mid + 1, r, rs);
upd(x);
}
ll ask(ll A, ll B, ll l, ll r, ll x) {
if (A <= l && r <= B)return sum[x];
ll mid = (l + r) >> 1;
ll res = 0;
if (A <= mid)res+=ask(A, B, l, mid, ls);
if (mid < B)res+=ask(A, B, mid + 1, r, rs);
return res;
}
bool cmp(event a, event b) {
return a.l < b.l;
}
int main() {
int t; scanf("%d", &t);
while (t--) {
int n; scanf("%d", &n);
build(1, n, 1);
for (int i = 1; i <= n; i++)scanf("%lld", &a[i]);
map<ll, ll> mp;
for (ll i = 1; i <= n; i++) {
p[i].set(mp[a[i]], i, a[i]);
mp[a[i]] = i;
}
int m; scanf("%d", &m);
for (int i = 1; i <= m; i++) {
scanf("%lld%lld", &q[i].l, &q[i].r); q[i].id = i;
}
sort(p + 1, p + 1 + n, cmp);
sort(q + 1, q + 1 +m, cmp);
int j = 1;
for (int i = 1; i <= m; i++) {
while (j <= n && p[j].l < q[i].l) { modify(p[j].r, p[j].id, 1, n, 1); j++; }
ans[q[i].id] = ask(q[i].l, q[i].r, 1, n, 1);
}
for (int i = 1; i <= m; i++)printf("%lld\n", ans[i]);
}
return 0;
}
题目来源:
SP1557 GSS2 - Can you answer these queries II
题目描述:
给出 n 个数,q 次询问,求最大子段和,相同的数只算一次。
输入格式
Line 1: integer N (1 <= N <= 100000);
Line 2: N integers denoting the score of each problem, each of them is a integer in range [-100000, 100000];
Line 3: integer Q (1 <= Q <= 100000);
Line 3+i (1 <= i <= Q): two integers X and Y denoting the _i_th question.
输出格式
Line i: a single integer, the answer to the _i_th question.
解题思路:
离线做。我们将所有询问按r从小到大排序。
我们一次从1到n扫过整个序列。假设现在扫到ii。在线段树中, 第jj个叶子结点我们维护a[j]…a[i]a[j]…a[i]序列的和Sum
#include<bits/stdc++.h>
using namespace std;
const int inf = 0x3f3f3f3f;
const int N = 1e5 + 10;
#define ls (x<<1)
#define rs (x<<1|1)
struct NODE {
int sum, mx, tagsum, tagmx;
//sum表示从这个叶结点所对应的原序列的下标x到y的所有元素和,及a[x]+a[x+1]+a[x+2]+...+a[y],
//mx:表示sum的历史最大值(最小为0)。
void show() { cout << "sum::" << sum << "mx::" << mx << "tagsum::" << tagsum << "tagmx::" << tagmx << endl; }
friend NODE operator + (NODE a, NODE b) {
NODE ans;
ans.sum = max(a.sum, b.sum);
ans.mx = max(a.mx, b.mx);
return ans;
}
NODE() { sum = mx = tagsum = tagmx = 0; }
}e[N << 2];
struct event {
int l, r, id;
}q[N];
int a[N],pre[N],ans[N];
int n, m;
void upd(int x) {
e[x] = e[ls] + e[rs];
}
void down(int x) {
e[ls].mx = max(e[ls].mx, e[ls].sum + e[x].tagmx);
e[ls].sum += e[x].tagsum;
e[ls].tagmx = max(e[ls].tagmx, e[ls].tagsum + e[x].tagmx);
e[ls].tagsum += e[x].tagsum;
e[rs].mx = max(e[rs].mx, e[rs].sum + e[x].tagmx);
e[rs].sum += e[x].tagsum;
e[rs].tagmx = max(e[rs].tagmx, e[rs].tagsum + e[x].tagmx);
e[rs].tagsum += e[x].tagsum;
e[x].tagsum = e[x].tagmx = 0;
}
void add(int A, int B, int v, int l, int r, int x) {
if (A <= l && r <= B) {
e[x].sum += v;
e[x].mx = max(e[x].mx, e[x].sum);
e[x].tagsum += v;
e[x].tagmx = max(e[x].tagmx, e[x].tagsum);
return;
}
down(x);
int mid = (l + r) >> 1;
if (A <= mid)add(A, B,v, l, mid, ls);
if (mid < B)add(A, B, v, mid + 1, r, rs);
upd(x);
}
int ask(int A, int B, int l, int r, int x) {
if (A <= l && r <= B)return e[x].mx;
down(x);
int mid = (l + r) >> 1;
int res = -inf;
if (A <= mid)res = max(res, ask(A, B, l, mid, ls));
if (mid < B)res = max(res, ask(A, B, mid + 1, r, rs));
return res;
}
bool cmp(event a, event b) { return a.r < b.r; }
void mspaint(int l,int r,int x) {
int mid = (l + r) >> 1;
cout << "l::" << l << "r::" << r << "x::" << x;
e[x].show();
if (l == r)return;
mspaint(l, mid, ls);
mspaint(mid + 1, r, rs);
}
int main() {
cin >> n;
for (int i = 1; i <= n; i++)cin >> a[i];
map<int, int> mp;
for (int i = 1; i <= n; i++)pre[i] = mp[a[i]], mp[a[i]] = i;
cin >> m;
for (int i = 1; i <= m; i++)cin >> q[i].l >> q[i].r, q[i].id = i;
sort(q + 1, q + 1 + m,cmp);
int j = 1;
for (int i = 1; i <= n; i++) {
add(pre[i] + 1, i, a[i], 1, n, 1);
for (; j <= m && q[j].r <= i; j++)
ans[q[j].id] = ask(q[j].l, q[j].r, 1, n, 1);
}
for (int i = 1; i <= m; i++)cout << ans[i] << endl;
return 0;
}
SP1716 GSS3 - Can you answer these queries III
解题思路:
与GSS1不同多了一个修改操作。
与 GSS1 不同的是,多了一个修改操作。
void modify(ll pos,ll c, ll l, ll r, ll x) {
if (l == r) {
sum[x] = smax[x] = lmax[x] = rmax[x] = c;
return;
}
ll mid = (l + r) >> 1;
if (pos <= mid)modify(pos, c, l, mid, ls);
else modify(pos, c, mid + 1, r, rs);
upd(x);
}
题目来源:
SP2713 GSS4 - Can you answer these queries IV
题目描述:
解题思路:
注意到这题中,除了区间开方,没有要求区间修改。
也就是说,每个数都在不断变小。
就算是 1018,经过 7 次开方,也会变成 1.
因此,很多数在开方了几次之后,全都会变成 1.
因此我们只要对区间里的数暴力修改,最多每个数都会被暴
力修改 7 次,这部分对整个时间复杂度的贡献是 O(n) 的。
当然如果发现区间里全都是 1,那么就不用修改了。
对每个区间记录一个最大值。
如果当前区间的最大值不是 1,那么继续递归到需要修改的
子节点,暴力修改。
如果当前区间的最大值就是 1,那么可以终止递归。
这样一来,虽然可能存在某一次操作,需要暴力修改 n 个
数,时间复杂度会达到 O(n).
但是总的来看,每个数都只会被暴力最多改 7 次。
最后均摊的时间复杂度仍为 O(n log n)
参考代码:
#include<bits/stdc++.h>
using namespace std;
#define ls (x<<1)
#define rs (x<<1|1)
#define N 100010
typedef long long ll;
ll sum[N << 2], mx[N << 2],a[N];
int n, m;
void upd(int x) {
sum[x] = sum[ls] + sum[rs];
mx[x] = max(mx[ls],mx[rs]);//NOte:更新mx不是sum
}
void build(int l, int r, int x) {
if (l == r) {
sum[x] = mx[x] = a[l];
return;
}
int mid = (l + r) >> 1;
build(l, mid, ls);
build(mid + 1, r, rs);
upd(x);
}
void add(int A, int B, int l, int r, int x) {
if (l == r) {
sum[x] = mx[x] = (ll)(floor(sqrt((double)sum[x])));
return;
}
int mid = (l + r) >> 1;
if (A <= mid && mx[ls] > 1)add(A, B, l, mid, ls);
if (mid < B && mx[rs]>1)add(A, B, mid + 1, r, rs);
upd(x);
}
ll ask(int A, int B, int l, int r, int x) {
if (A <= l && r <= B)return sum[x];
int mid = (l + r) >> 1;
ll res = 0;
if (A <= mid)res += ask(A, B, l, mid, ls);
if (mid < B)res += ask(A, B, mid + 1, r, rs);
return res;
}
int main() {
int tc = 0;
while (scanf("%d", &n)!=EOF) {
printf("Case #%d:\n", ++tc);
for (int i = 1; i <= n; i++)scanf("%lld", &a[i]);
build(1, n, 1);
scanf("%d", &m);
int l, r, q;
while (m--) {
scanf("%d%d%d", &q, &l, &r);
if (l > r)swap(l, r);
if (q == 0)add(l, r, 1, n, 1);
else if(q==1) printf("%lld\n", ask(l, r, 1, n, 1));
}
printf("\n");
}
return 0;
}
题目来源:
SP2916 GSS5 - Can you answer these queries V
题目描述:
You are given a sequence A[1], A[2], …, A[N] . ( |A[i]| <= 10000 , 1 <= N <= 10000 ). A query is defined as follows: Query(x1,y1,x2,y2) = Max { A[i]+A[i+1]+…+A[j] ; x1 <= i <= y1 , x2 <= j <= y2 and x1 <= x2 , y1 <= y2 }. Given M queries (1 <= M <= 10000), your program must output the results of these queries.
你有一个长度为 n 的序列 A[1], A[2], …, A[N].
询问:
Query(x1, y1, x2, y2) = max { A[i] + … + A[j];
x1 <= i <= y1, x2 <= j <= y2}
x1 <= x2, y1 <= y2
给出 M 组操作,输出每次询问的答案
|A[i]| <= 10000, 1 <= N,M <= 10000
解题思路:
这次与以往不同的是,限定了左右端点的范围。
那么需要进行一下分类讨论。
如果 [x1, y1] 和 [x2, y2] 没有交集,即 y1 < x2
答案显然等于:
Rmax([x1, y1]) + Sum(y1 + 1, x2 - 1) + Lmax([x2, y2])
如果 [x1, y1] 和 [x2, y2] 有交集,即 y1 >= x2
这个时候,区间分为三个部分:
[x1, x2 - 1], [x2, y1], [y1 + 1 … y2]
左端点有两种选择,右端点也有两种选择,一共四种情况。
进一步讨论,变为三种情况:
Smax([x2, y1])
Rmax([x1, x2 - 1]) + Lmax([x2, y2])
Rmax([x1, y1]) + Lmax([y1 + 1 … y2])
参考代码:
#include<bits/stdc++.h>
using namespace std;
const int N = 3e5 + 10;
typedef long long ll;
#define ls (x<<1)
#define rs (x<<1|1)
const int inf = 0x3f3f3f3f;
struct NODE
{
ll sum, lmax, rmax, smax;
NODE() { sum = 0; lmax = rmax = smax = 0; }
void set(ll a) { sum = lmax = rmax = smax = a; }
void init() { sum = 0; lmax = rmax = smax = -inf; }
//sum: 【x,x+1,x+2,...,y】的和。
//区间的前缀和最大值 lmax[x] 和后缀和最大值 rmax[x]
//区间答案smax
NODE friend operator + (NODE a, NODE b) {
NODE res;
res.sum = a.sum + b.sum;
res.smax = max(max(a.smax, b.smax), a.rmax + b.lmax);
res.lmax = max(a.sum + b.lmax, a.lmax);
res.rmax = max(b.sum + a.rmax, b.rmax);
return res;
}
}e[N << 2];
int a[N];
void upd(ll x) {
e[x] = e[ls] + e[rs];
}
void build(ll l, ll r, ll x) {
if (l == r) {
e[x].set(a[l]);
return;
}
ll mid = (l + r) >> 1;
build(l, mid, ls);
build(mid + 1, r, rs);
upd(x);
}
NODE ask(ll A, ll B, ll l, ll r, ll x) {
if (A <= l && r <= B) { return e[x]; }
ll mid = (l + r) >> 1;
NODE res; res.init();
if (A <= mid)res=res+ask(A, B, l, mid, ls);
if (mid < B)res=res+ask(A, B, mid + 1, r, rs);
return res;
}
int main() {
int t; cin >> t;
while (t--) {
int n; cin >> n;
for (int i = 1; i <= n; i++)cin >> a[i];
build(1, n, 1);
int m; cin >> m;
while (m--) {
int x1, y1, x2, y2;
cin >> x1 >> y1 >> x2 >> y2;
//进行分类讨论
if (x2 > y1) {
NODE a = ask(x1, y1, 1, n, 1);
NODE c = ask(x2, y2, 1, n, 1);
ll ans = a.rmax + c.lmax;
if (x2 - y1 > 1)ans += ask(y1 + 1, x2 - 1, 1, n, 1).sum;
cout << ans << endl;
}
else if (y1>=x2) {
ll ansa, ansb, ansc;
ansa = ansb = ansc = -inf;
NODE a = ask(x2, y1, 1, n, 1);
ansa = a.smax;
if(x2-1>=1)
ansb = ask(x1, x2 - 1, 1, n, 1).rmax + ask(x2, y2, 1, n, 1).lmax;
if(y1+1<=n)
ansc = ask(x1, y1, 1, n, 1).rmax + ask(y1 + 1, y2, 1, n, 1).lmax;
ll ans = max(max(ansa, ansb), ansc);
cout << ans << endl;
}
}
}
return 0;
}