Description
给你一个N*N的矩阵,不用算矩阵乘法,但是每次询问一个子矩形的第K小数。
Input
第一行两个数N,Q,表示矩阵大小和询问组数;
接下来N行N列一共N*N个数,表示这个矩阵;
再接下来Q行每行5个数描述一个询问:x1,y1,x2,y2,k表示找到以(x1,y1)为左上角、以(x2,y2)为右下角的子矩形中的第K小数。
接下来N行N列一共N*N个数,表示这个矩阵;
再接下来Q行每行5个数描述一个询问:x1,y1,x2,y2,k表示找到以(x1,y1)为左上角、以(x2,y2)为右下角的子矩形中的第K小数。
Output
对于每组询问输出第K小的数。
Sample Input
2 2
2 1
3 4
1 2 1 2 1
1 1 2 2 3
2 1
3 4
1 2 1 2 1
1 1 2 2 3
Sample Output
1
3
3
HINT
矩阵中数字是109以内的非负整数;20%的数据:N<=100,Q<=1000;
40%的数据:N<=300,Q<=10000;
60%的数据:N<=400,Q<=30000;
100%的数据:N<=500,Q<=60000。
Solve
整体二分,二分答案,二位树状数组求区间数的个数。
#include<algorithm>
#include<iostream>
#include<cstdio>
#define lowbit(x) (x&(-x))
using namespace std;
int sum[505][505],ss,n,m,cnt,xh[60005],ans[60005],step,tmp[60005],maxn;
bool mark[60005];
struct node{
int val,x,y;
friend bool operator < (node i,node j){
return i.val<j.val;
}
}t[250005];
struct orz{int x1,x2,y1,y2,k;}q[60005];
inline void add(int x,int y,int key){
for (int i=x;i<=n;i+=lowbit(i))
for (int j=y;j<=n;j+=lowbit(j))
sum[i][j]+=key;
}
inline int get_sum(int x,int y){
ss=0;
for (int i=x;i;i-=lowbit(i))
for (int j=y;j;j-=lowbit(j))
ss+=sum[i][j];
return ss;
}
inline int query(orz i){
return get_sum(i.x2,i.y2)+get_sum(i.x1-1,i.y1-1)-get_sum(i.x1-1,i.y2)-get_sum(i.x2,i.y1-1);
}
inline void solve(int l,int r,int L,int R){
if (l>r || L==R)return;
int mid=(L+R)>>1,len=0,ll[2];
for (;step<cnt && t[step+1].val<=mid;++step)add(t[step+1].x,t[step+1].y,1);
for (;step>0 && t[step].val>mid;--step)add(t[step].x,t[step].y,-1);
for (int i=l;i<=r;++i)
if (query(q[xh[i]])>=q[xh[i]].k)
ans[xh[i]]=mid,mark[i]=++len;
else mark[i]=0;
ll[1]=l;ll[0]=l+len;
for (int i=l;i<=r;++i)tmp[ll[mark[i]]++]=xh[i];
for (int i=l;i<=r;++i)xh[i]=tmp[i];
solve(l,ll[1]-1,L,mid);solve(ll[1],r,mid+1,R);
}
int main (){
scanf ("%d%d",&n,&m);
for (int i=1;i<=n;++i)
for (int j=1;j<=n;++j){
scanf ("%d",&t[++cnt].val);
t[cnt].x=i;t[cnt].y=j;
maxn=max(maxn,t[cnt].val);
}
sort(t+1,t+cnt+1);
for (int i=1;i<=m;++i)
scanf ("%d%d%d%d%d",&q[(xh[i]=i)].x1,&q[i].y1,&q[i].x2,&q[i].y2,&q[i].k);
solve(1,m,0,maxn+1);
for (int i=1;i<=m;++i)printf ("%d\n",ans[i]);
return 0;
}