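/*
 * Count all rectangles (each spanned by two of the given points) that
 * intersect the query rectangle [l, r] x [d, u].
 *
 * Complement counting: answer = (number of all rectangles) - (rectangles that
 * miss the query). A rectangle misses the query exactly when both of its
 * points lie strictly to the left of, strictly to the right of, strictly
 * below, or strictly above the query. With k points in such a region there
 * are C(k, 2) = k(k-1)/2 pairs.
 *
 * The side regions overlap in the four corner regions (left-below,
 * left-above, right-below, right-above): a pair whose points both lie in,
 * say, the left-below corner is both "left" and "below", so it was
 * subtracted twice. Add C(k, 2) back once for each corner:
 *
 *   answer = C(n,2) - C(left) - C(right) - C(below) - C(above)
 *          + C(left-below) + C(left-above) + C(right-below) + C(right-above)
 *
 * A persistent segment tree ("chairman tree") answers the corner counts:
 * how many of the values at positions [l, r] lie in [d, u]?
 */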
#include <iostream>
#include <algorithm>
#include <sstream>
#include <string>
#include <queue>
#include <cstdio>
#include <map>
#include <set>
#include <utility>
#include <stack>
#include <cstring>
#include <cmath>
#include <vector>
#include <ctime>
#include <bitset>
// Contest-template prelude: single translation unit, so a file-scope
// using-directive is tolerated here (it would not be acceptable in a header).
using namespace std;
// ---- Input shorthands: sd* read int, sld* read long long, sf* read double,
// ---- ss reads a whitespace-delimited C string. Return values are not checked.
#define pb push_back
#define sd(n) scanf("%d",&n)
#define sdd(n,m) scanf("%d%d",&n,&m)
#define sddd(n,m,k) scanf("%d%d%d",&n,&m,&k)
#define sld(n) scanf("%lld",&n)
#define sldd(n,m) scanf("%lld%lld",&n,&m)
#define slddd(n,m,k) scanf("%lld%lld%lld",&n,&m,&k)
#define sf(n) scanf("%lf",&n)
#define sff(n,m) scanf("%lf%lf",&n,&m)
#define sfff(n,m,k) scanf("%lf%lf%lf",&n,&m,&k)
#define ss(str) scanf("%s",str)
// ---- Output shorthands: each prints a variable literally named `ans`, which
// ---- must be in scope at the use site (int, long long or double per format).
#define ans() printf("%d",ans)
#define ansn() printf("%d\n",ans)
#define anss() printf("%d ",ans)
#define lans() printf("%lld",ans)
#define lanss() printf("%lld ",ans)
#define lansn() printf("%lld\n",ans)
#define fansn() printf("%.10f\n",ans)
// ---- Loop shorthands: r0 counts 0..n-1, r1 counts 1..e, rn counts e down to 1,
// ---- rsz walks the indices of a container.
#define r0(i,n) for(int i=0;i<(n);++i)
#define r1(i,e) for(int i=1;i<=e;++i)
#define rn(i,e) for(int i=e;i>=1;--i)
#define rsz(i,v) for(int i=0;i<(int)v.size();++i)
#define szz(x) ((int)x.size())
#define mst(abc,bca) memset(abc,bca,sizeof abc)
// Lowest set bit of a (Fenwick-tree helper). The argument is not parenthesized,
// so pass a simple expression.
#define lowbit(a) (a&(-a))
#define all(a) a.begin(),a.end()
#define pii pair<int,int>
#define pli pair<ll,int>
#define mp make_pair
// Heap-style segment-tree children; these expand against a variable named `rt`.
#define lrt rt<<1
#define rrt rt<<1|1
#define X first
#define Y second
#define PI (acos(-1.0))
#define sqr(a) ((a)*(a))
typedef long long ll;
typedef unsigned long long ull;
const ll mod = 1000000000+7;
const double eps=1e-9;
const int inf=0x3f3f3f3f;
const ll infl = 10000000000000000;
// Capacity for the number of points: root[] is sized maxn and indexed 0..n,
// so n must stay below maxn.
const int maxn= 200000+10;
// Node pool for the persistent segment tree. update() allocates one node per
// tree level it visits, i.e. about log2(n)+1 nodes per insertion; 21 per
// value leaves headroom over the ~19 needed when n is about 2e5.
const int maxm = maxn*21+10;
//Pretests passed
/*
 * Fast reader for one signed decimal integer from stdin.
 *
 * Skips any characters that are neither a digit nor '-', then reads an
 * optional '-' followed by digits. The first non-digit after the number is
 * consumed.
 *
 * ret : receives the parsed value (left untouched when -1 is returned).
 * returns 1 on success, -1 if end of input is reached before a number starts.
 */
int in(int &ret)
{
    // Must be int, not char: getchar() returns EOF (-1), which a plain char
    // cannot represent reliably (e.g. where char is unsigned).
    int c = getchar();
    if (c == EOF) return -1;
    // Skip separators; stop at EOF instead of spinning forever on it.
    while (c != '-' && (c < '0' || c > '9'))
    {
        c = getchar();
        if (c == EOF) return -1;
    }
    const int sgn = (c == '-') ? -1 : 1;
    ret = (c == '-') ? 0 : (c - '0');
    while (c = getchar(), c >= '0' && c <= '9') ret = ret * 10 + (c - '0');
    ret *= sgn;
    return 1;
}
// root[i] is the root node index of the tree version built from the first i
// inserted values; root[0] stays 0, i.e. the empty tree.
int root[maxn];

// One node of the persistent segment tree. Index 0 acts as a shared null
// node: as a global it is zero-initialized (cnt = 0, both children = 0).
struct Seg
{
    int cnt;  // how many inserted values fall in this node's range
    int lch;  // pool index of the left child (0 = empty)
    int rch;  // pool index of the right child (0 = empty)
};
Seg seg[maxm];

// Index of the most recently allocated pool node; 0 means none allocated yet.
int tot;
/*
 * Persistent insertion of value x into the tree rooted at `rt`, which covers
 * the value range [l, r].
 *
 * Every node on the root-to-leaf path is copied into a freshly allocated pool
 * slot with its count incremented, and `rt` is redirected to the new root.
 * Nodes off the path are shared with the previous version.
 */
void update(int &rt,int l,int r,int x)
{
    int *link = &rt;  // the slot that must point at the next copied node
    int lo = l;
    int hi = r;
    for (;;)
    {
        const int old_node = *link;
        const int new_node = ++tot;
        seg[new_node] = seg[old_node];
        seg[new_node].cnt += 1;
        *link = new_node;
        if (lo == hi) break;
        const int mid = (lo + hi) >> 1;
        if (x <= mid)
        {
            link = &seg[new_node].lch;
            hi = mid;
        }
        else
        {
            link = &seg[new_node].rch;
            lo = mid + 1;
        }
    }
}
/*
 * Count the values stored in tree version `rt` that lie inside [l, r].
 *
 * rt   : root node index of the version (0 = empty tree).
 * L, R : value range covered by node rt.
 * l, r : requested value range; an empty (l > r) or disjoint request yields 0.
 */
int query(int rt,int L,int R,int l,int r)
{
    // Node 0 is never written (allocations start at ++tot), so it represents
    // an empty subtree. The l > r test also stops the unbounded recursion an
    // empty request previously caused once it reached a leaf.
    if (rt == 0 || l > r || r < L || R < l) return 0;
    if (l <= L && R <= r) return seg[rt].cnt;
    const int m = (L + R) >> 1;
    return query(seg[rt].lch, L, m, l, r) + query(seg[rt].rch, m + 1, R, l, r);
}
int n;  // number of positions; also the value range [1, n] of the tree

/*
 * Number of positions i in [xl, xr] whose value lies in [yl, yr].
 *
 * Computed as the difference of two tree versions, root[xr] - root[xl-1].
 * Both ranges are clamped to [1, n] first. An empty range returns 0, which
 * avoids out-of-bounds root[] access and negative differences.
 */
int query(int xl,int xr,int yl,int yr)
{
    xl = std::max(xl, 1);
    xr = std::min(xr, n);
    yl = std::max(yl, 1);
    yr = std::min(yr, n);
    if (xl > xr || yl > yr) return 0;
    return query(root[xr],1,n,yl,yr) - query(root[xl-1],1,n,yl,yr);
}
// Number of unordered pairs that can be formed from x items: C(x, 2) = x(x-1)/2.
// Evaluated in 64-bit so that x around 2e5 does not overflow.
ll cal(int x)
{
    const ll doubled = static_cast<ll>(x) * (x - 1);  // always even and >= 0
    return doubled / 2;
}
int main()
{
#ifdef LOCAL
freopen("input.txt","r",stdin);
// freopen("output.txt","w",stdout);
#endif // LOCAL
int q;
sdd(n,q);
r1(i,n)
{
int x;
sd(x);
root[i] = root[i-1];
update(root[i],1,n,x);
}
while(q--)
{
int x1,y1,x2,y2;
sdd(x1,y1),sdd(x2,y2);
ll ans = cal(n);
ans -= cal(x1-1) + cal(n-x2) + cal(y1-1) +cal(n-y2);
int lu = 1<x1&&1<y1? query(1,x1-1,1,y1-1) : 0 ;
int ld = 1<x1&&y2<n? query(1,x1-1,y2+1,n) : 0 ;
int ru = x2<n&&y1>1? query(x2+1,n,1,y1-1) : 0 ;
int rd = x2<n&&y2<n? query(x2+1,n,y2+1,n) : 0 ;
ans += cal(lu) + cal(ld) + cal(ru) + cal(rd);
lansn();
}
return 0;
}