基本操作
静态线段树、维护区间信息、优化信息可合并的DP等
拓展操作
线段树上二分
int quary( int p , int l , int r , int d ) // 当前节点 p , 询问区间为 [ql,qr], 查 >=d 的第一个下标
{
if( l <= t[p].l && t[p].r <= r ) { // 找到答案所在子节点,往叶子递归
if( t[p].maxx < d ) return n + 1 ;
if( t[p].l == t[p].r ) return t[p].l ; // 缩到一点
if( t[p<<1].maxx >= d ) return quary( p<<1 , l , r , d ) ;// 在左边
return quary( p<<1|1 , l , r , d ) ;
}
int mid = ( t[p].l + t[p].r ) >> 1 ;
if( r <= mid ) return quary( p<<1 , l , r , d ) ;
else if( l > mid ) return quary( p<<1|1 , l , r , d ) ;
else {
int pos = quary( p<<1 , l , r , d ) ; // 先向左,保证下标最小
if( pos == n + 1 ) return quary( p<<1|1 , l , r , d ) ;
return pos ;
}
}
核心思想是 先将 二分区间 转化为 l o g 2 n log_2n log2n 个线段树节点,通过 节点整体 的信息 与 所求值 比较,快速找到答案所在的 线段树节点
接下来在这个节点中递归,分别考虑答案在 左 o r or or右子树 中,最多 l o g 2 n log_2n log2n 层到达到叶节点
类似于分块思想,将单点信息合并为整块处理,前提是答案具有单调性( 指的是子节点对父节点的单调性)
动态开点
某些问题中,给出区间的总长度特别长,但实际上大部分节点都是无用的,因此只在操作时将需要用到的节点在线段树上建出来
struct Segtree
{
int ls , rs , sum ;
}t[N*32] ; // 这里空间要开到 (操作次数)*log(值域)
int root , tot ; // 节点数量
inline int build() // 创建一个新的节点 并返回编号
{
tot ++ ;
t[tot].ls = t[tot].rs = t[tot].sum = 0 ;
return tot ;
}
void insert( int p , int l , int r , int x , int d )// l,r 为理论区间左右端点,在 x 位置加 d
{
if( l == r ) {
t[p].sum += d ;
return ;
}
int mid = ( l + r ) >> 1 ;
if( x <= mid ) {
if( !t[p].ls ) t[p].ls = build() ;
insert( t[p].ls , l , mid , x , d ) ;
}
else {
if( !t[p].rs ) t[p].rs = build() ;
insert( t[p].rs , mid+1 , r , x , d ) ;
}
t[p].sum = t[t[p].ls].sum + t[t[p].rs].sum ;
}
int ask( int p , int l , int r , int lq , int rq )
{
if( lq <= l && r <= rq ) {
return t[p].sum ;
}
int mid = ( l + r ) >> 1 , res = 0 ;
if( lq <= mid ) {
if( !t[p].ls ) t[p].ls = build() ;
res += ask( t[p].ls , l , mid , lq , rq ) ;
}
if( rq > mid ) {
if( !t[p].rs ) t[p].rs = build() ;
res += ask( t[p].rs , mid+1 , r , lq , rq ) ;
}
t[p].sum = t[t[p].ls].sum + t[t[p].rs].sum ;
return res ;
}
写代码时理解成一棵 “虚拟的线段树”,递归前判断一下左右儿子是否存在即可
需要注意的几点:
1.值域线段树 注意 右端点可能不是
n
n
n !
2.动态开点对空间要求很大, (操作次数)*(总长度)
结合其他算法 (如DP) 时可以有效地节省空间 例[USACO15FEB Cow Hopscotch G]
线段树合并
需要结合动态开点,因为合并时要保证两棵线段树结构一致,因此通常用在 值域线段树 中
可以很高效地维护合并过程
int merge( int p , int q , int l , int r ) // 同时递归两棵线段树,节点为 p , q
{
if( !p ) return q ;
if( !q ) return p ;
if( l == r ) { // 到达叶子,合并
t[p].val += t[q].val ;
return p ;
}
int mid = ( l + r ) >> 1 ;
t[p].ls = merge( t[p].ls , t[q].ls , l , mid ) ;
t[p].rs = merge( t[p].rs , t[q].rs , mid+1 , r ) ;
update(p) ; // 递归向上更新
return p ; // 节点 q 的信息被添加到 p中,只返回 p 即可
}
两棵线段树合并 的时间复杂度为 两棵子树的节点交 ,空间复杂度为 两棵子树节点并(画图理解)
但多个合并时整体空间仍要开到 (操作次数)*log(长度) ,因为合并过程中会有某些节点重复出现
基本例题 雨天的尾巴
拓展一点,线段树合并的过程中实际上还可以很方便地 统计两棵树之间构成某种关系的信息( eg : 逆序对)
例 [POI2011] ROT-Tree Rotations
观察这个 Merge 函数
int Merge( int p , int q , int l , int r , LL &ans )
{
if( !p ) return q ;
if( !q ) return p ;
if( l == r ) {
t[p].sum += t[q].sum ;
return p ;
}
int mid = ( l + r ) >> 1 ;
ans += 1LL * t[t[p].ls].sum * t[t[q].rs].sum ; // 统计逆序对
t[p].ls = Merge( t[p].ls , t[q].ls , l , mid , ans ) ;
t[p].rs = Merge( t[p].rs , t[q].rs , mid+1 , r , ans ) ;
t[p].sum = t[t[p].ls].sum + t[t[p].rs].sum ;
return p ;
}
线段树优化建图
鸽一下下~~
主席树
区间拷贝
为 避免当前的拷贝操作 对原来在 q树 上进行的操作 产生影响,将 q 拷贝,不影响原来;同时在新树上进行本次操作
int copy( int p , int q , int l , int r , int lq , int rq ) // 将 q 拷贝一份,p 的 [l,r] 复制到拷贝后的
{
if( !p ) return 0 ;
if( lq <= l && r <= rq ) {
return p ;
}
int nw = build() ; t[nw] = t[q] ; // 拷贝 q
int mid = ( l + r ) >> 1 ;
if( lq <= mid ) t[nw].ls = copy( t[p].ls , t[nw].ls , l , mid , lq , rq ) ;
if( rq > mid ) t[nw].rs = copy( t[p].rs , t[nw].rs , mid+1 , r , lq , rq ) ;
t[nw].sum = t[t[nw].ls].sum + t[t[nw].rs].sum ;
return nw ;
}
维护前缀信息,区间查询需要在值域上完成
如求第 K K K 大
#include<bits/stdc++.h>
using namespace std ;
typedef long long LL ;
const int N = 1e6 + 10 ;
int read()
{
int x = 0 ; char c = getchar() ;
while( !isdigit(c) ) c = getchar() ;
while( isdigit(c) ) x = (x<<1)+(x<<3)+(c^48) , c = getchar() ;
return x ;
}
int n , a[N] , lst[N] , bar[N] , m ;
struct Segtree
{
int ls , rs , sum ;
}t[24*N] ;
int tot , root[N] ;
int build() { return ++tot ; }
void build0( int p , int l , int r )
{
if( l == r ) {
return ;
}
int mid = ( l + r ) >> 1 ;
t[p].ls = build() ;
build0( t[p].ls , l , mid ) ;
t[p].rs = build() ;
build0( t[p].rs , mid+1 , r ) ;
}
int Insert( int p , int l , int r , int x )
{
int nw = build() ;
t[nw] = t[p] ;
if( l == r ) {
t[nw].sum ++ ;
return nw ;
}
int mid = ( l + r ) >> 1 ;
if( x <= mid ) t[nw].ls = Insert( t[p].ls , l , mid , x ) ;
else t[nw].rs = Insert( t[p].rs , mid+1 , r , x ) ;
t[nw].sum = t[t[nw].ls].sum + t[t[nw].rs].sum ;
return nw ;
}
int query( int p , int q , int l , int r , int x ) // sum_p - sum_q
{
if( 1 <= l && r <= x ) {
return t[p].sum - t[q].sum ;
}
int mid = ( l + r ) >> 1 , res = 0 ;
if( 1 <= mid ) res += query( t[p].ls , t[q].ls , l , mid , x ) ;
if( x > mid ) res += query( t[p].rs , t[q].rs , mid+1 , r , x ) ;
return res ;
}
int main()
{
n = read() ;
for(int i = 1 ; i <= n ; i ++ ) {
a[i] = read() ;
lst[i] = bar[a[i]] ;
bar[a[i]] = i ;
}
root[0] = build() ;
build0( root[0] , 1 , n ) ;
for(int i = 1 ; i <= n ; i ++ ) {
root[i] = Insert( root[i-1] , 1 , n , lst[i]+1 ) ;
}
m = read() ;
int l , r ;
for(int i = 1 ; i <= m ; i ++ ) {
scanf("%d%d" , &l , &r ) ;
printf("%d\n" , query( root[r] , root[l-1] , 1 , n , l ) ) ;
}
return 0 ;
}