题目链接:戳这里
2509: 送分题
Time Limit: 10 Sec Memory Limit: 128 MBSubmit: 62 Solved: 26
[ Submit][ Status][ Discuss]
Description
给出平面上的M条平行于坐标轴的线段,问有多少个正方形。
Input
第1行为两个正整数N,M。
接下来M行,每行4个非负整数x1,y1,x2,y2(0≤x1≤x2≤N,0≤y1≤y2≤N),描述了线段的两个端点。
Output
仅包括一个正整数,为平面上正方形的个数。
Sample Input
3 8
0 0 0 3
1 0 1 3
2 0 2 2
3 0 3 3
0 0 3 0
0 1 3 1
0 2 3 2
0 3 3 3
0 0 0 3
1 0 1 3
2 0 2 2
3 0 3 3
0 0 3 0
0 1 3 1
0 2 3 2
0 3 3 3
Sample Output
11
HINT
【样例说明】
样例对应了如下一张图
其中边长为1的正方形有7个,边长为2的正方形有3个,边长为3的正方形有1个。
所以答案为7+3+1=11。
【数据规模】
对于20%的数据,有N≤30;
对于40%的数据,有N≤100;
对于60%的数据,有N≤800;
对于100%的数据,有N≤1000,M≤400000,并保证任意两条线段没有重合部分。
考场上打了60分暴力结果被自己hack掉,sad。最终爆0
【正文】
这题我好像说不太清...
对于对角线上的点,能构成正方形的个数是在其前加入且在这个点能伸展到范围的点的个数。
因此定义l[x][y],r[x][y],u[x][y],d[x][y],分别统计每个点上下左右延伸的距离,然后枚举每条对角线,用树状数组统计答案。
代码:
#include<cstring>
#include<algorithm>
#include<iostream>
#include<cstdio>
#define maxn 2001
using namespace std;
typedef long long LL;
int read()
{
char c;int sum=0,f=1;c=getchar();
while(c<'0' || c>'9'){if(c=='-')f=-1;c=getchar();}
while(c>='0' && c<='9'){sum=sum*10+c-'0';c=getchar();}
return sum*f;
}
int n,m,l[maxn][maxn],r[maxn][maxn],u[maxn][maxn],d[maxn][maxn],a[maxn][maxn],b[maxn][maxn];
LL ans;
struct BIT{
int c[maxn];
bit(){}
void clr(){memset(c,0,sizeof(c));}
void add(int x,int v)
{
for(;x<=n;x+=x&-x)
c[x]+=v;
}
int sum(int x)
{
int ret=0;
for(;x;x-=x&-x) ret+=c[x];
return ret;
}
}bit;
struct node{
int k,pos1,pos2;
node(){}
node(int k,int pos1,int pos2):k(k),pos1(pos1),pos2(pos2){}
bool operator < (const node & A)const{return pos1==A.pos1?k<A.k:pos1<A.pos1;}
}scan[100010];
void solve(int x,int y)
{
int tot=0;
for(;x<=n && y<=n;x++,y++)
{
scan[++tot]=node(-1,x,x-min(l[x][y],d[x][y]));
scan[++tot]=node(0,x,0);
scan[++tot]=node(1,x+min(u[x][y],r[x][y]),x);
}
sort(scan+1,scan+1+tot);
bit.clr();
for(int i=1;i<=tot;i++)
{
switch(scan[i].k)
{
case -1:ans+=bit.sum(n)-bit.sum(scan[i].pos2-1);break;
case 0: bit.add(scan[i].pos1,1);break;
case 1: bit.add(scan[i].pos2,-1);break;
}
}
}
int main()
{
n=read()+1;m=read();
for(int i=1;i<=m;i++)
{
int x1=read()+1,y1=read()+1,x2=read()+1,y2=read()+1;
if(x1==x2)
for(int j=y1+1;j<=y2;j++)
b[x1][j]=true;
if(y1==y2)
for(int j=x1+1;j<=x2;j++)
a[j][y1]=true;
}
for(int i=1;i<=n;i++)
for(int j=1;j<=n;j++)
{
if(b[i][j]) d[i][j]=d[i][j-1]+1;
if(a[i][j]) l[i][j]=l[i-1][j]+1;
}
for(int i=n;i>=1;i--)
for(int j=n;j>=1;j--)
{
if(b[i][j+1]) u[i][j]=u[i][j+1]+1;
if(a[i+1][j]) r[i][j]=r[i+1][j]+1;
}
solve(1,1);
for(int i=2;i<=n;i++) solve(1,i),solve(i,1);
cout<<ans<<endl;
return 0;
}