题目描述
Pty继续着他的疯狂奔跑,终于渐渐体力不支,在一个应该拐弯的地方没有刹住车,掉入了深深的沼泽中,“啊~~~~~·”pty惊恐的大叫,突然从梦中惊醒了。哪里还有什么奇怪的金字塔,沼泽地,大树。。。只是一个梦而已呀。看了看自己熟悉的房间,pty定了定神。
好不容易恢复了过来,pty突然想到还有集训队的互测题没有出!!,如果没有出完的话,后果= =。。啧啧。。pty宁愿再回到金字塔去。于是pty想啊想,找啊找,找到了一道傻逼题:
给定一个矩阵A:n行m列,一个矩阵B:h行w列,在B矩阵中有一个特殊的位置为(x,y)。现在可以从A矩阵中选出一个大小和B相等的区域,设选出的矩阵为C,那么花费的代价是
∑hi=1∑wj=1(C[i,j]−C[x,y]−B[i,j])2
现在pty想知道在A矩阵中选出的所有C矩阵中前K小的代价分别是多少。
化简
我们假设已经选出了矩阵C,那么我们来化简式子,我们可以拆成六项:
c[i,j]2,−2∗c[i,j]∗c[x,y],−2∗c[i,j]∗b[i,j],c[x,y]2,2∗c[x,y]∗b[i,j],b[i,j]2
对于第五项和第六项我们可以求出B的所有元素和与平方和。
第一项要求得到子矩阵平方和,第二项要求得到子矩阵和,我们可以维护二维前缀(平方)和。
第四项很容易得到。
关键在于第三项,这该如何突破呢?
往FFT方面进军
把原本所有矩阵下标减一来方便运算,即所有矩阵以(0,0)位左上角而不是(1,1)。
我们设ans[i,j]表示以(i,j)为左上角的h*w的矩阵的第三项代价,那么
ans[i,j]=∑h−1i′=0∑w−1j′=0a[i+i′][j+j′]∗b[i′,j′]
我们令b矩阵进行如下映射b[i,j]=b[h-i-1,h-j-1]
那么
ans[i,j]=∑h−1i′=0∑w−1j′=0a[i+i′][j+j′]∗b[h−i′−1,w−j′−1]
令
A[i∗m+j]=a[i,j],B[i∗m+j]=b[i,j]
ans[i,j]=∑h−1i′=0∑w−1j′=0A[(i+i′)∗m+j+j′]∗B[(h−i′−1)∗m+w−j′−1]
注意到
(i+i′)∗m+j+j′+(h−i′−1)∗m+w−j′−1=i∗m+j+hm+w−m−1
那么我们可以得到标准卷积形式,令C=A*B
则ans[i,j]=C[i*m+j+hm+w-m-1]
有了ans,算出其他东西,排个序就好。
参考程序
#include<cstdio>
#include<algorithm>
#include<cmath>
#define fo(i,a,b) for(i=a;i<=b;i++)
using namespace std;
typedef double db;
struct node{
db x,y;
node friend operator +(node a,node b){
node c;
c.x=a.x+b.x;c.y=a.y+b.y;
return c;
}
node friend operator -(node a,node b){
node c;
c.x=a.x-b.x;c.y=a.y-b.y;
return c;
}
node friend operator *(node a,node b){
node c;
c.x=a.x*b.x-a.y*b.y;c.y=a.x*b.y+a.y*b.x;
return c;
}
};
struct dong{
int x,y,z;
bool friend operator <(dong a,dong b){
if (a.z<b.z) return 1;
else if (a.z==b.z&&a.x<b.x) return 1;
else if (a.z==b.z&&a.x==b.x&&a.y<b.y) return 1;
else return 0;
}
};
const int maxl=666*666*5;
const db pi=acos(-1);
db ce;
node A[maxl],B[maxl],C[maxl],e[maxl],f[maxl],tt[maxl],wt;
dong d[666*666*2];
int a[666+10][666+10],b[666+10][666+10],ans[666+10][666+10],sum[666+10][666+10],num[666+10][666+10];
int i,j,k,l,t,n,m,h,w,xx,yy,len,top,ll,ttt;
void DFT(node *a,int sig){
fo(i,0,len-1){
int p=0;
for(int j=0,tp=i;j<ce;j++,tp/=2) p=(p<<1)+(tp%2);
tt[p]=a[i];
}
for(int m=2;m<=len;m*=2){
int half=m/2;
fo(i,0,half-1){
node w;
w.x=cos(i*sig*pi/half),w.y=sin(i*sig*pi/half);
for(int j=i;j<len;j+=m){
node u=tt[j],v=tt[j+half]*w;
tt[j]=u+v;
tt[j+half]=u-v;
}
}
}
if (sig==-1)
fo(i,0,len-1) tt[i].x/=len;
fo(i,0,len-1) a[i]=tt[i];
}
void FFT(node *a,node *b,node *c){
int i;
fo(i,0,len-1) e[i]=a[i],f[i]=b[i];
DFT(e,1);DFT(f,1);
fo(i,0,len-1) e[i]=e[i]*f[i];
DFT(e,-1);
fo(i,0,len-1) c[i]=e[i];
}
int getsqr(int x,int y){
if (!x&&!y) return sum[x+h-1][y+w-1];
else if (x&&!y) return sum[x+h-1][y+w-1]-sum[x-1][y+w-1];
else if (!x&&y) return sum[x+h-1][y+w-1]-sum[x+h-1][y-1];
else return sum[x+h-1][y+w-1]-sum[x-1][y+w-1]-sum[x+h-1][y-1]+sum[x-1][y-1];
}
int get(int x,int y){
if (!x&&!y) return num[x+h-1][y+w-1];
else if (x&&!y) return num[x+h-1][y+w-1]-num[x-1][y+w-1];
else if (!x&&y) return num[x+h-1][y+w-1]-num[x+h-1][y-1];
else return num[x+h-1][y+w-1]-num[x-1][y+w-1]-num[x+h-1][y-1]+num[x-1][y-1];
}
int main(){
scanf("%d%d",&n,&m);
fo(i,0,n-1)
fo(j,0,m-1){
scanf("%d",&a[i][j]);
if (!i&&!j) sum[i][j]=a[i][j]*a[i][j],num[i][j]=a[i][j];
else if (i&&!j) sum[i][j]=sum[i-1][j]+a[i][j]*a[i][j],num[i][j]=num[i-1][j]+a[i][j];
else if (!i&&j) sum[i][j]=sum[i][j-1]+a[i][j]*a[i][j],num[i][j]=num[i][j-1]+a[i][j];
else sum[i][j]=sum[i-1][j]+sum[i][j-1]-sum[i-1][j-1]+a[i][j]*a[i][j],num[i][j]=num[i-1][j]+num[i][j-1]-num[i-1][j-1]+a[i][j];
}
scanf("%d%d",&h,&w);
fo(i,0,h-1)
fo(j,0,w-1){
scanf("%d",&b[i][j]);
ll+=b[i][j]*b[i][j];ttt+=b[i][j];
}
scanf("%d%d",&xx,&yy);
xx--,yy--;
fo(i,0,n-1)
fo(j,0,m-1)
A[i*m+j].x=a[i][j];
fo(i,0,h-1)
fo(j,0,w-1)
B[i*m+j].x=b[h-i-1][w-j-1];
len=1;
while (len<n*m*2) len*=2;
ce=db(log(len)/log(2));
FFT(A,B,C);
fo(i,0,n-h)
fo(j,0,m-w){
d[++top].z=ans[i][j]=-2*round(C[i*m+j+h*m+w-m-1].x)+getsqr(i,j)+ll+2*a[i+xx][j+yy]*ttt-2*a[i+xx][j+yy]*get(i,j)+a[i+xx][j+yy]*a[i+xx][j+yy]*h*w;
d[top].x=i;d[top].y=j;
}
sort(d+1,d+top+1);
scanf("%d",&k);
fo(i,1,k) printf("%d %d %d\n",d[i].x+1,d[i].y+1,d[i].z);
}