Brief Description:
一个01方阵中找出四条边全都是1的正方形的个数,对于正方形内部则没有要求。
Analysis:
摘自题解:
一个直观的想法是首先用N^2的时间预处理出每一个是1的点向上下左右四个方向能够延伸的1的最大长度,记为四个数组l, r, u, d。然后我们观察到正方形有一个特征是同一对角线上的两个顶点在原方阵的同一条对角线上。于是我们可以想到枚举原来方阵的每条对角线,然后我们对于每条对角线枚举对角线上所有是1的点i,那么我们可以发现可能和i构成正方形的点应该在该对角线的 [i, i + min(r[i], d[i]) – 1] 闭区间内, 而在这个区间内的点 j 只要满足 j – i + 1 <= min(l[j], u[j]) 也就是满足j – min(l[j], u[j]) + 1 <= i,这样的 (i, j) 就能构成一个正方形。也就是说对于每条对角线,我们可以构造一个数组 a, 使得a[i] = i – min(l[i], u[i]) + 1 然后对这个数组有若干次查询,每次查询的是区间 [i, i + min(r[i], d[i]) – 1]内有多少个数满足 a[j] <= i,所有这些问题答案的和就是该问题的结果。对于这个问题,我们可以通过离线算法,先保存所有查询的区间端点,并对所有端点排序。然后使用扫描线算法,如果扫描到的是第i次查询的左端点,就让当前结果减去当前扫描过的数中 <= i的个数,如果扫描到的是第i次查询的有短点,则让当前结果加上当前扫描过的数中 <= i的个数,最后所有结果相加即可。 维护当前数出现的个数可以使用树状数组。这样对于每条对角线求结果的复杂度为O(nlogn),算法总的复杂度为O(n^2logn)。具体见以下代码:(注意扫描线的细节)#include <iostream> #include <cstdio> #include <cstdlib> #include <cstring> #include <cmath> #include <algorithm> using namespace std; const int maxn = 1010; struct node{ int id,num,flag; //id为当前下标,num为属于第几个区间,flag=0为左端点,flag=1为右端点 node(){} node(int _id,int _num,int _flag){ id = _id; num = _num; flag = _flag; } }cc[maxn*2]; // remember*2!! int n,ta=1; int a[maxn][maxn]; int l[maxn][maxn],r[maxn][maxn],u[maxn][maxn],d[maxn][maxn],aa[maxn]; int ans,c[maxn]; bool cmp(const node &p,const node &q){ if(p.id != q.id) return p.id < q.id; return p.flag < q.flag; } inline int lowbit(int x){ return x & (-x); } // 单点加1 void add(int x){ while(x <= n){ c[x] += 1; x += lowbit(x); } } // 区间求和 int Sum(int x){ int s = 0; while(x){ s += c[x]; x -= lowbit(x); } return s; } void init(){ int i,j; scanf("%d",&n); memset(a,0,sizeof(a)); for(i=1; i<=n; i++) for(j=1; j<=n; j++){ scanf("%d",&a[i][j]); // ans += a[i][j]; } for(i=1; i<=n; i++) for(j=1; j<=n; j++){ u[i][j] = a[i][j]?u[i-1][j]+1:0; l[i][j] = a[i][j]?l[i][j-1]+1:0; } for(i=1; i<=n; i++) d[n+1][i] = r[i][n+1] = 0; for(i=n; i>=1; i--) for(j=n; j>=1; j--){ d[i][j] = a[i][j]?d[i+1][j]+1:0; r[i][j] = a[i][j]?r[i][j+1]+1:0; } } int ccSum; void Work(){ int i,j,k; ans = 0; for(k=n-1; k>=-n+1; k--){ int st = k>=0?1:1-k; int ed = k>=0?n-k:n; // 枚举每个对角线 memset(c,0,sizeof(c)); memset(aa,0,sizeof(aa)); ccSum = 0; for(i=st; i<=ed; i++){ if(a[i+k][i]){ cc[ccSum++] = node(i,i,0); // 左端点 cc[ccSum++] = node(i+min(r[i+k][i],d[i+k][i])-1,i,1); // 右端点 aa[i] = i - min(l[i+k][i],u[i+k][i]) + 1; } } sort(cc,cc+ccSum,cmp); j = 0; // 扫描线 for(i=0; i<ccSum; i++){ if(cc[i].flag == 0){ while(j < cc[i].id){ if(aa[j]) add(aa[j]); j++; } ans -= Sum(cc[i].num-1); }else{ while(j <= cc[i].id){ if(aa[j]) add(aa[j]); j++; } ans += Sum(cc[i].num); } } } printf("Case %d: %d\n",ta++,ans); } int main() { int cas; scanf("%d",&cas); for(int i=0; i<maxn; i++) u[0][i] = l[i][0] = 0; while(cas--){ init(); Work(); } return 0; }