题目描述
FFT
首先，给第一个序列加 $c$ 等价于给第二个序列加 $-c$，因此可以统一看作只给第二个序列加上 $c$，其中 $c$ 可以为负数。
为了处理旋转，把第二个序列复制一份接在后面（倍长），这样每一种旋转都对应倍长序列中的一段连续区间。
把平方和展开后容易发现：对每个旋转量 $j$（$0 \le j < n$），需要求出
$$\sum_{i=0}^{n-1} a_i b_{i+j}$$
并取其中的最大值（这一项在展开式中带负号，它越大，总差异越小）。
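展开式的其余部分是关于 $c$ 的二次函数 $nc^2 + pc + q$，其中 $p = 2\sum b_i - 2\sum a_i$，$q = \sum a_i^2 + \sum b_i^2 - 2\cdot(\text{上面的最大值})$。取离对称轴 $-p/(2n)$ 最近的整数作为 $c$ 即可（初中数学知识）。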
上面的最大值怎么求？考虑把倍长后的 $b$ 翻转。
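设翻转后的序列为 $b'_k = b_{2n-1-k}$（$0 \le k < 2n$），则
$$\sum_{i=0}^{n-1} a_i b_{i+j} = \sum_{i=0}^{n-1} a_i b'_{2n-1-i-j} = d_{2n-1-j},$$
其中 $d = a * b'$ 是 $a$ 与 $b'$ 的卷积，$j$ 取 $0 \sim n-1$ 时对应 $d_n \sim d_{2n-1}$。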
标准卷积形式!
#include<cstdio>
#include<algorithm>
#include<cmath>
// fo(i,a,b): inclusive ascending loop, i runs from a up to b.
// The loop variable must already be declared by the caller.
#define fo(i,a,b) for (i=a; i<=b; i++)
using namespace std;
// Short aliases for the two numeric types used throughout the file.
typedef long long ll;
typedef double db;
// pi to full double precision.
const db pi = acos(-1);
// Array bounds: maxn sizes the input sequences, maxlen sizes the FFT
// buffers (2097152 = 2^21, plus a little slack).
const int maxn = 500000 + 10;
const int maxlen = 2097152 + 10;
// Minimal complex number used as an FFT coefficient:
// x is the real part, y is the imaginary part.
struct node{
    db x,y;
    // Component-wise sum.
    friend node operator +(node p,node q){
        node r={p.x+q.x,p.y+q.y};
        return r;
    }
    // Component-wise difference.
    friend node operator -(node p,node q){
        node r={p.x-q.x,p.y-q.y};
        return r;
    }
    // Complex product: (p.x*q.x - p.y*q.y) + (p.x*q.y + p.y*q.x) i.
    friend node operator *(node p,node q){
        node r={p.x*q.x-p.y*q.y,p.x*q.y+p.y*q.x};
        return r;
    }
};
// Input sequences: a is read first; b is read into both halves (b[i] and b[i+n]).
ll a[maxn];
ll b[maxn*2];
// FFT operands (A, B) and the scratch buffer tt used inside DFT.
node A[maxlen];
node B[maxlen];
node tt[maxlen];
// Bit-reversal permutation table, filled by prepare().
int rev[maxlen];
// Shared loop counters (used through the fo macro).
int i,j,k,l,t;
// Sizes: n and m from the input, c the chosen shift, len the FFT length.
int n,m,c,len;
// ans: final result; mi: best rotation correlation (taken with max);
// aa, bb, cc: coefficients of the quadratic in c.
ll ans,mi,aa,bb,cc;
// zy: real-valued vertex of the quadratic; ce: log2(len) as a double, read by prepare().
db zy,ce;
// Read the next integer from stdin, skipping every non-digit character
// before it. A '-' seen anywhere among the skipped characters makes the
// result negative. Returns 0 if end of input is reached before any digit
// (the original looped forever at EOF).
int read(){
    int x=0,f=1;
    // int, not char: getchar() returns EOF as an int, which a (possibly
    // unsigned) char cannot reliably distinguish from a data byte.
    int ch=getchar();
    while (ch!=EOF&&(ch<'0'||ch>'9')){
        if (ch=='-') f=-1;
        ch=getchar();
    }
    while (ch>='0'&&ch<='9'){
        x=x*10+ch-'0';
        ch=getchar();
    }
    return x*f;
}
void prepare(){
fo(i,0,len-1){
int p=0;
for (int j=0,tp=i;j<ce;j++,tp/=2) p=(p<<1)+(tp%2);
rev[i]=p;
}
/*w[0].x=1;w[0].y=0;
w[1].x=cos(2*pi/len);w[1].y=sin(2*pi/len);
fo(i,2,len) w[i]=w[i-1]*w[1];*/
}
// Iterative radix-2 FFT of a[0..len-1], result written back into a.
// sig = 1 gives the forward transform; sig = -1 gives the inverse, where only
// the real parts are divided by len (callers read only .x afterwards).
// Relies on rev (filled by prepare) and the global scratch buffer tt.
void DFT(node *a,int sig){
    // Scatter the input into bit-reversed order.
    for (int p=0;p<len;p++) tt[rev[p]]=a[p];
    // Merge blocks of size span from pairs of sub-blocks of size h.
    for (int span=2;span<=len;span*=2){
        const int h=span/2;
        for (int k=0;k<h;k++){
            // Twiddle factor exp(sig * i * pi * k / h).
            const db ang=k*pi*sig/h;
            node w;
            w.x=cos(ang);
            w.y=sin(ang);
            for (int base=k;base<len;base+=span){
                node lo=tt[base];
                node hi=tt[base+h]*w;
                tt[base]=lo+hi;
                tt[base+h]=lo-hi;
            }
        }
    }
    if (sig==-1){
        for (int p=0;p<len;p++) tt[p].x/=len;
    }
    for (int p=0;p<len;p++) a[p]=tt[p];
}
// Multiply the polynomials held in a and b via the FFT. On return a holds
// their cyclic convolution of length len (the real parts carry the result).
// b is overwritten with its forward transform.
void FFT(node *a,node *b){
    // Move both operands to the frequency domain.
    DFT(a,1);
    DFT(b,1);
    // Pointwise product there corresponds to convolution of coefficients.
    for (i=0;i<len;i++) a[i]=a[i]*b[i];
    // Return to the coefficient representation.
    DFT(a,-1);
}
// Read n, m, then sequences a and b. Print the minimum over rotations j and
// integer shifts c of sum_i (a[i] - b[(i+j) mod n] - c)^2.
// Input: gift.in; output: gift.out.
// Assumes n <= 500000 so that the arrays and the FFT length fit
// (maxn / maxlen).
int main(){
    freopen("gift.in","r",stdin);freopen("gift.out","w",stdout);
    n=read();m=read();
    if (n<=0){
        // Nothing to compare; without this guard the vertex below is 0/0.
        printf("0\n");
        return 0;
    }
    fo(i,0,n-1) a[i]=read();
    // Store b twice so that every rotation of b is a contiguous window.
    fo(i,0,n-1) b[i]=b[i+n]=read();
    // Reverse so that the correlation sum_i a[i]*b[i+j] becomes a convolution.
    reverse(b,b+2*n);
    // A has n coefficients and B has 2n, so their linear convolution has
    // 3n-1 terms; the cyclic FFT length must be at least that.
    len=1;
    while (len<3*n-1) len*=2;
    // Exact log2(len). A floating-point log can land just above the integer
    // and break the bit-reversal loop that reads ce.
    int bits=0;
    while ((1<<bits)<len) bits++;
    ce=bits;
    prepare();
    fo(i,0,n-1) A[i].x=a[i];
    fo(i,0,2*n-1) B[i].x=b[i];
    FFT(A,B);
    // After the reversal, A[2n-1-j] = sum_i a[i]*b[(i+j) mod n] for rotation
    // j in [0, n). The original also scanned A[2n], which is an incomplete
    // sum and not a rotation.
    mi=llround(A[2*n-1].x);
    fo(j,1,n-1) mi=max(mi,llround(A[2*n-1-j].x));
    mi*=2;
    // b[0..n-1] after the reversal is still a permutation of the original b
    // values, so these sums are unaffected by it.
    fo(i,0,n-1){
        cc+=a[i]*a[i];
        cc+=b[i]*b[i];
        bb-=2*a[i];
        bb+=2*b[i];
    }
    cc-=mi;
    aa=n;
    // Minimise aa*c^2 + bb*c + cc over integers. The parabola is symmetric
    // and opens upward, so the nearest integer to its vertex is optimal.
    zy=-((db)bb/(2*aa));
    c=round(zy);
    ans=aa*c*c+bb*c+cc;
    printf("%lld\n",ans);
    return 0;
}