题意
解法
首先考虑一个
O
(
n
2
)
O(n^2)
O(n2)的解法,我们先考虑求出不含任何星星的矩形个数,这个可以考虑单调栈做法,具体而言,对于每个右下角,它的左上角形成的图形是一个台阶的形状,然后就枚举每一行,然后首先计算每个位置出现星星的最晚时间,用单调栈维护每个位置向后第一个出现星星时间比当前时间晚的位置,转移的时候对一些区间加等差数列.
然后考虑更快的做法:将单调栈换成线段树,然后维护同样的问题.我第一次写的时候想一颗线段树直接把单调栈维护出来,但是定义的很有问题,导致其不单调了.所以最后还是用一棵线段树维护区间min,另一棵线段树维护当前行的情况.
最后用总数减去这些不合法的就好了.
#include<bits/stdc++.h>
using namespace std;
#define int long long
const int maxn=7e4+5;
inline char get_char(){//超级快读
static char buf[1000001],*p1=buf,*p2=buf;
return p1==p2&&(p2=(p1=buf)+fread(buf,1,1000000,stdin),p1==p2)?EOF:*p1++;
}
inline int read(){
int num=0;
char c;
while(isspace(c=get_char()));
while(num=num*10+c-48,isdigit(c=get_char()));
return num;
}
int n,m,q;
vector<int> a[maxn];
struct node{
long long sum,tag,tag3;
}t[maxn<<2];
int tr[maxn<<2],tr2[maxn];
#define ls rt<<1
#define rs rt<<1|1
inline void add(int rt,int l,int r,int x,int y){
if(x<=l&&r<=y){
t[rt].sum=t[rt].sum+((l-x+1)+(r-x+1))*(r-l+1)/2;
t[rt].tag++;t[rt].tag3-=x;
return ;
}
if(t[rt].tag){
int mid=(l+r)>>1;
t[ls].sum=t[ls].sum+((l*t[rt].tag+t[rt].tag)+(mid*t[rt].tag+t[rt].tag))*(mid-l+1)/2;
t[ls].tag+=t[rt].tag;
t[rs].sum=t[rs].sum+(((mid+1)*t[rt].tag+t[rt].tag)+(r*t[rt].tag+t[rt].tag))*(r-mid)/2;
t[rs].tag+=t[rt].tag;
t[rt].tag=0;
}
if(t[rt].tag3){
int mid=(l+r)>>1;
t[ls].sum=t[ls].sum+t[rt].tag3*(mid-l+1);t[ls].tag3+=t[rt].tag3;
t[rs].sum=t[rs].sum+t[rt].tag3*(r-mid);t[rs].tag3+=t[rt].tag3;
t[rt].tag3=0;
}
int mid=(l+r)>>1;
if(x<=mid)add(ls,l,mid,x,y);
if(y>mid)add(rs,mid+1,r,x,y);
t[rt].sum=t[ls].sum+t[rs].sum;
}
inline long long query(int rt,int l,int r,int x,int y){
if(x<=l&&r<=y){
return t[rt].sum;
}
if(t[rt].tag){
int mid=(l+r)>>1;
t[ls].sum=t[ls].sum+((l*t[rt].tag+t[rt].tag)+(mid*t[rt].tag+t[rt].tag))*(mid-l+1)/2;
t[ls].tag+=t[rt].tag;
t[rs].sum=t[rs].sum+(((mid+1)*t[rt].tag+t[rt].tag)+(r*t[rt].tag+t[rt].tag))*(r-mid)/2;
t[rs].tag+=t[rt].tag;
t[rt].tag=0;
}
if(t[rt].tag3){
int mid=(l+r)>>1;
t[ls].sum=t[ls].sum+t[rt].tag3*(mid-l+1);t[ls].tag3+=t[rt].tag3;
t[rs].sum=t[rs].sum+t[rt].tag3*(r-mid);t[rs].tag3+=t[rt].tag3;
t[rt].tag3=0;
}
int mid=(l+r)>>1;
long long ans=0;
if(x<=mid)ans=ans+query(ls,l,mid,x,y);
if(y>mid)ans=ans+query(rs,mid+1,r,x,y);
return ans;
}
inline void modify(int rt,int l,int r,int x,int y,long long val){
if(x<=l&&r<=y){
t[rt].sum=t[rt].sum+val*(r-l+1);
t[rt].tag3+=val;
return ;
}
if(t[rt].tag){
int mid=(l+r)>>1;
t[ls].sum=t[ls].sum+((l*t[rt].tag+t[rt].tag)+(mid*t[rt].tag+t[rt].tag))*(mid-l+1)/2;
t[ls].tag+=t[rt].tag;
t[rs].sum=t[rs].sum+(((mid+1)*t[rt].tag+t[rt].tag)+(r*t[rt].tag+t[rt].tag))*(r-mid)/2;
t[rs].tag+=t[rt].tag;
t[rt].tag=0;
}
if(t[rt].tag3){
int mid=(l+r)>>1;
t[ls].sum=t[ls].sum+t[rt].tag3*(mid-l+1);t[ls].tag3+=t[rt].tag3;
t[rs].sum=t[rs].sum+t[rt].tag3*(r-mid);t[rs].tag3+=t[rt].tag3;
t[rt].tag3=0;
}
int mid=(l+r)>>1;
if(x<=mid)modify(ls,l,mid,x,y,val);
if(y>mid) modify(rs,mid+1,r,x,y,val);
t[rt].sum=t[ls].sum+t[rs].sum;
}
void build(int rt,int l,int r){
if(l==r){
tr[rt]=l;
return ;
}
int mid=(l+r)>>1;
build(ls,l,mid);build(rs,mid+1,r);
if(tr2[tr[ls]]>tr2[tr[rs]])tr[rt]=tr[ls];
else tr[rt]=tr[rs];
}
inline int quert(int rt,int l,int r,int ql,int qr,int x){
if(tr2[tr[rt]]<=x){return -1;}
if(l==r){
return l;
}
int mid=(l+r)>>1;
int tmp=-1;
if(tr2[tr[ls]]>x&&ql<=mid)tmp=quert(ls,l,mid,ql,qr,x);
if(tmp!=-1)return tmp;
if(tr2[tr[rs]]>x&&qr>=mid+1)tmp=quert(rs,mid+1,r,ql,qr,x);
return tmp;
}
inline void modi(int rt,int l,int r,int x){
if(l==r){
return ;
}
int mid=(l+r)>>1;
if(x<=mid)modi(ls,l,mid,x);
else modi(rs,mid+1,r,x);
if(tr2[tr[ls]]>tr2[tr[rs]])tr[rt]=tr[ls];
else tr[rt]=tr[rs];
}
void change(int rt,int l,int r,int x){
if(l==r){
t[rt].sum=0;
t[rt].tag=t[rt].tag3=0;
return ;
}
if(t[rt].tag){
int mid=(l+r)>>1;
t[ls].sum=t[ls].sum+((l*t[rt].tag+t[rt].tag)+(mid*t[rt].tag+t[rt].tag))*(mid-l+1)/2;
t[ls].tag+=t[rt].tag;
t[rs].sum=t[rs].sum+(((mid+1)*t[rt].tag+t[rt].tag)+(r*t[rt].tag+t[rt].tag))*(r-mid)/2;
t[rs].tag+=t[rt].tag;
t[rt].tag=0;
}
if(t[rt].tag3){
int mid=(l+r)>>1;
t[ls].sum=t[ls].sum+t[rt].tag3*(mid-l+1);t[ls].tag3+=t[rt].tag3;
t[rs].sum=t[rs].sum+t[rt].tag3*(r-mid);t[rs].tag3+=t[rt].tag3;
t[rt].tag3=0;
}
int mid=(l+r)>>1;
if(x<=mid)change(ls,l,mid,x);
else change(rs,mid+1,r,x);
t[rt].sum=t[ls].sum+t[rs].sum;
}
signed main(){
//freopen("1.in","r",stdin);
//freopen("1.out","w",stdout);
n=read(),m=read();q=read();
for(int i=1;i<=q;i++){
int x=read(),y=read();
a[x].push_back(y);
}
for(int i=1;i<=n;i++){a[i].push_back(m+1);a[i].push_back(0);sort(a[i].begin(),a[i].end());}
long long ans=0;tr2[m+1]=1e9;
build(1,1,m+1);
for(int i=1;i<=n;i++){
for(int j=1;j<a[i].size();j++){
int pos=a[i][j-1]+1;
while(pos<a[i][j]){
int lst=quert(1,1,m+1,pos+1,m+1,tr2[pos]),ht=tr2[pos];
int tmp=query(1,1,m,pos,pos);
if(tmp-(pos-a[i][j-1])*(i-ht-1)){
modify(1,1,m,pos,min(lst-1,a[i][j]-1),-(tmp-(pos-a[i][j-1])*(i-ht-1)));}
pos=lst;
}
if(a[i][j-1]+1<=a[i][j]-1){
add(1,1,m,a[i][j-1]+1,a[i][j]-1);
}
}
for(int j=1;j<a[i].size()-1;j++){int pos=a[i][j];
tr2[pos]=i;change(1,1,m,pos);
modi(1,1,m+1,pos);
}
ans=ans+t[1].sum;
}
long long tot1=n*(n+1)/2,tot2=m*(m+1)/2;
printf("%lld\n",tot1*tot2-ans);
return 0;
}