两道题的链接
洛谷 https://www.luogu.com.cn/problem/P3834
牛客 https://ac.nowcoder.com/acm/problem/19427
本来是写专题写到自闭想去牛客写点简单的题目放松一下
结果找到了这道题
拿线段树写了下 t掉了
看题解说是主席树裸题
然后就开始了一整天的主席树学习
又是自闭的一天
到现在其实对主席树的板子的理解还是有一点模模糊糊的
推荐一个主席树视频吧
https://www.bilibili.com/video/BV1C4411u7rK?p=2
让我从完全看不懂代码到现在基本上能够理解了
主席树真是一个让人自闭的数据结构
它是由多个权值线段树优化而来
用来记录一个权值线段树的所有历史版本
权值线段树的博客推荐
https://www.cnblogs.com/fusiwei/p/12234435.html
以下是主席树的板子分解
首先是应为主席树是权值线段树变来的
所以叶子节点记录的都是某个数值出现的次数
但题目数据过大时我们就需要使用离散化来处理线段树的数据
vector<int>v;
int getid(int x)
{
return lower_bound(v.begin(),v.end(),x)-v.begin()+1;
}
for(int i=1;i<=n;i++)
{
scanf("%d",&a[i]);
v.push_back(a[i]);
}
sort(v.begin(),v.end());
v.erase(unique(v.begin(),v.end()),v.end());
//之后直接用getid函数即可获得离散化后的数据
然后是关于主席树的操作
首先是建树
void Insert(int l,int r,int pre,int &now,int p)//这里的now表示的是新建的树 直接在数组里面赋值 所以要有&
{
hjt[++cnt]=hjt[pre];//将之前的树的数据拷贝给新的
now=cnt;
hjt[now].sum++;//含有新树插入值的区间需要加一
//之后就是线段树的正常建树方法
if(l==r)return;
int mid=(l+r)>>1;
if(p<=mid)
Insert(l,mid,hjt[pre].l,hjt[now].l,p);
else
Insert(mid+1,r,hjt[pre].r,hjt[now].r,p);
}
然后是查询
int query(int l,int r,int L,int R,int k)
{
if(l==r) return l;
int mid=(l+r)>>1;
//令我们要求的区间的右端点为r 左端点为l 将r的树减去l-1的树 剩下的就是从l到r的树了 可以类比前缀和的思想进行理解 这里用的是l是因为在查询时做了处理
int tem=hjt[hjt[R].l].sum-hjt[hjt[L].l].sum;
if(k<=tem)return query(l,mid,hjt[L].l,hjt[R].l,k);//与权值线段树的查询方式相同
else return query(mid+1,r,hjt[L].r,hjt[R].r,k-tem);
}
主席树初步的代码就这么多 主要还是难在理解它建树时与之前树连接起来的操作
洛谷p3834 ac代码
#include <stdio.h>
#include <iostream>
#include <algorithm>
#include <math.h>
#include <string.h>
#include <vector>
#include <stack>
#include <queue>
#include <map>
#include <utility>
#define pi 3.1415926535898
#define ll long long
#define lson rt<<1
#define rson rt<<1|1
#define eps 1e-6
#define ms(a,b) memset(a,b,sizeof(a))
#define legal(a,b) a&b
#define print1 printf("111\n")
using namespace std;
const int maxn = 2e5+10;
const int inf = 0x1f1f1f1f;
const int mod = 2333;
int a[maxn];
int n,m;
struct node
{
int l,r,sum;
}hjt[maxn<<5];
int cnt,root[maxn];
vector<int>v;
int getid(int x)
{
return lower_bound(v.begin(),v.end(),x)-v.begin()+1;
}
void Insert(int l,int r,int pre,int &now,int p)
{
hjt[++cnt]=hjt[pre];
now=cnt;
hjt[now].sum++;
if(l==r)return;
int mid=(l+r)>>1;
if(p<=mid)
Insert(l,mid,hjt[pre].l,hjt[now].l,p);
else
Insert(mid+1,r,hjt[pre].r,hjt[now].r,p);
}
int query(int l,int r,int L,int R,int k)
{
if(l==r) return l;
int mid=(l+r)>>1;
int tem=hjt[hjt[R].l].sum-hjt[hjt[L].l].sum;
if(k<=tem)return query(l,mid,hjt[L].l,hjt[R].l,k);
else return query(mid+1,r,hjt[L].r,hjt[R].r,k-tem);
}
int main()
{
//freopen("input.txt","r",stdin);
//freopen("output.txt","w",stdout);
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)
{
scanf("%d",&a[i]);
v.push_back(a[i]);
}
sort(v.begin(),v.end());
v.erase(unique(v.begin(),v.end()),v.end());
for(int i=1;i<=n;i++)
{
Insert(1,n,root[i-1],root[i],getid(a[i]));
}
while(m--)
{
int l,r,k;
scanf("%d%d%d",&l,&r,&k);
int tem=query(1,n,root[l-1],root[r],k);
printf("%d\n",v[tem-1]);//a数组是从1开始 而vector则是从0开始 所以要减1
}
}
牛客 换个角度思考
这个题因为问的是a[i]<=k的个数
所以不离散化
直接查询
查询的函数也有些变化
ac代码
#include <stdio.h>
#include <iostream>
#include <algorithm>
#include <math.h>
#include <string.h>
#include <vector>
#include <stack>
#include <queue>
#include <map>
#include <utility>
#define pi 3.1415926535898
#define ll long long
#define lson rt<<1
#define rson rt<<1|1
#define eps 1e-6
#define ms(a,b) memset(a,b,sizeof(a))
#define legal(a,b) a&b
#define print1 printf("111\n")
using namespace std;
const int maxn = 2e5+10;
const int inf = 0x1f1f1f1f;
const int mod = 2333;
int a[maxn];
int n,m;
struct node
{
int l,r,sum;
}hjt[maxn<<5];
int cnt,root[maxn];
vector<int>v;
int getid(int x)
{
return lower_bound(v.begin(),v.end(),x)-v.begin()+1;
}
void Insert(int l,int r,int pre,int &now,int p)
{
hjt[++cnt]=hjt[pre];
now=cnt;
hjt[now].sum++;
if(l==r)return;
int mid=(l+r)>>1;
if(p<=mid)
Insert(l,mid,hjt[pre].l,hjt[now].l,p);
else
Insert(mid+1,r,hjt[pre].r,hjt[now].r,p);
}
int query(int l,int r,int pre,int now,int L,int R)
{
if(r<=R) return hjt[now].sum-hjt[pre].sum;
int mid=(l+r)>>1;
if(R<=mid)return query(l,mid,hjt[pre].l,hjt[now].l,L,R);
else if(L>mid)return query(mid+1,r,hjt[pre].r,hjt[now].r,L,R);
else return query(l,mid,hjt[pre].l,hjt[now].l,L,mid)+query(mid+1,r,hjt[pre].r,hjt[now].r,mid+1,R);
}
int main()
{
//freopen("input.txt","r",stdin);
//freopen("output.txt","w",stdout);
scanf("%d%d",&n,&m);
int maxx=0;
for(int i=1;i<=n;i++)
{
scanf("%d",&a[i]);
maxx=max(maxx,a[i]);
v.push_back(a[i]);
}
sort(v.begin(),v.end());
v.erase(unique(v.begin(),v.end()),v.end());
for(int i=1;i<=n;i++)
{
Insert(1,maxx,root[i-1],root[i],getid(a[i]));
}
while(m--)
{
int l,r,k;
scanf("%d%d%d",&l,&r,&k);
int tem=query(1,maxx,root[l-1],root[r],1,k);
printf("%d\n",tem);
}
}