题目:两个序列,求一个差值平方和最小 min{(a1-b1)²+...+(an-bn)²,(a1-b2)²+...+(an-b1)²,...,(a1-bn)²+...+(an-b1)²}
思路:原式变形后就是sigma(a[i]^2)+sigma(b[i]^2)-2*sigma(a[i]*b[(i+k)%n])的最大值,也就是sigma(a[i]*b[(i+k)%n])的最小值。
这个可以用FFT来求了,可以用题目给的第二组简单的样例模拟一下,需要将b数组反转,最后还要注意精度(找到k后重新按式子跑一遍)。直接硬套fft模板了。。
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const double PI=acos(-1.0);
struct Complex{
double x,y;
Complex(double _x=0.0,double _y=0.0){
x=_x;
y=_y;
}
Complex operator -(const Complex &b)const{
return Complex(x-b.x,y-b.y);
}
Complex operator +(const Complex &b)const{
return Complex(x+b.x,y+b.y);
}
Complex operator *(const Complex &b)const{
return Complex(x*b.x-y*b.y,x*b.y+y*b.x);
}
};
void change(Complex y[],int len)
{
int i,j,k;
for(i=1,j=len/2;i<len-1;i++)
{
if(i<j) swap(y[i],y[j]);
k=len/2;
while(j>=k)
{
j-=k;
k/=2;
}
if(j<k)j+=k;
}
}
void fft(Complex y[],int len,int on)
{
change(y,len);
for(int h=2;h<=len;h<<=1)
{
Complex wn(cos(-on*2*PI/h),sin(-on*2*PI/h));
for(int j=0;j<len;j+=h)
{
Complex w(1,0);
for(int k=j;k<j+h/2;k++)
{
Complex u=y[k];
Complex t=w*y[k+h/2];
y[k]=u+t;
y[k+h/2]=u-t;
w=w*wn;
}
}
}
if(on==-1)
for(int i=0;i<len;i++)
y[i].x/=len;
}
const int MAXN=400010;
Complex x1[MAXN],x2[MAXN];
ll str1[MAXN/2],str2[MAXN/2];
ll sum[MAXN];
int t;
int main()
{
scanf("%d",&t);
while(t--)
{
int len1,len2;
scanf("%d",&len1);
len2=len1;
ll ans=0;
for(int i=0;i<len1;i++)
{
scanf("%lld",&str1[i]);
ans=ans+str1[i]*str1[i];
}
for(int i=0;i<len2;i++)
{
scanf("%lld",&str2[i]);
ans=ans+str2[i]*str2[i];
}
//for(int i=len2-1;i>=0;i--)
// cout<<str2[i]<<" ";
//cout<<endl;
int len=1;
while(len<len1*2||len<len2*2) len<<=1;
for(int i=0;i<len1;i++)
x1[i]=Complex(str1[i],0);
for(int i=len1;i<len;i++)
x1[i]=Complex(0,0);
for(int i=0;i<len2;i++)
x2[i]=Complex(str2[len2-1-i],0);
for(int i=len2;i<len;i++)
x2[i]=Complex(0,0);
fft(x1,len,1);
fft(x2,len,1);
for(int i=0;i<len;i++)
x1[i]=x1[i]*x2[i];
fft(x1,len,-1);
for(int i=0;i<len;i++)
sum[i]=(ll)(x1[i].x+0.5);
ll ma=0;
int k=0;
for(int i=0;i<len1;i++)
if(sum[i]+sum[i+len1]>ma)
{
ma=sum[i]+sum[i+len1];
k=len1-i-1;
}
//cout<<k<<endl;
ma=0;
for(int i=0;i<len1;i++)
ma+=str1[i]*str2[(i+k)%len1];
ans-=ma*2;
printf("%lld\n",ans);
}
return 0;
}