Problem
Solution
我们不妨设旋转了j,亮度增加了c,当然为了防止b数组越界,在后面复制一遍。
则我们的答案会变成
∑i=1n(a[i]−b[i+j]+c)2
∑
i
=
1
n
(
a
[
i
]
−
b
[
i
+
j
]
+
c
)
2
我们将其展开,就可以得到如下式子:
∑a[i]2+∑b[i]2−2∗∑(a[i]∗b[i+j])+∑c2+2∗(∑a[i]−∑b[i])∗c
∑
a
[
i
]
2
+
∑
b
[
i
]
2
−
2
∗
∑
(
a
[
i
]
∗
b
[
i
+
j
]
)
+
∑
c
2
+
2
∗
(
∑
a
[
i
]
−
∑
b
[
i
]
)
∗
c
当然,其中 ∑c2=n∗c2 ∑ c 2 = n ∗ c 2 。
容易发现,前两项是常数,而后面的两项式子其实就是与c有关的二次函数,n是正整数,所以求个最小值即可。
那么我们就只需要关心中间那一项什么时候最大。按照套路,我们将b翻转一下:
∑(a[i]∗b[i+j])=∑(a[i]∗b[n−i−j+1])
∑
(
a
[
i
]
∗
b
[
i
+
j
]
)
=
∑
(
a
[
i
]
∗
b
[
n
−
i
−
j
+
1
]
)
同样变成了卷积的形式。那么我们就可以利用FFT,然后就可以直接扫出最优的j并得到相应答案了。
时间复杂度 O(nlogn) O ( n l o g n ) 。
Code
#include <algorithm>
#include <cstdio>
#include <cmath>
using namespace std;
typedef long long ll;
const int maxn=50010;
const double pi=acos(-1.0);
struct cpx{
double r,i;
cpx(){}
cpx(double _r,double _i){r=_r;i=_i;}
cpx operator + (const cpx &x)const{return cpx(r+x.r,i+x.i);}
cpx operator - (const cpx &x)const{return cpx(r-x.r,i-x.i);}
cpx operator * (const cpx &x)const{return cpx(r*x.r-i*x.i,r*x.i+i*x.r);}
cpx operator *= (const cpx &x){return *this=*this*x;}
}a[maxn<<2],b[maxn<<2],c[maxn<<2];
int n,m,fr,k,l,r[maxn<<2];
ll sa,sb,dt,ans;
double maxx=0.0;
inline ll fac(double x){return floor(x+0.5)>floor(x)?ceil(x):floor(x);}
void input()
{
int x;
scanf("%d%d",&n,&m);fr=n;
for(int i=1;i<=n;i++)
{
scanf("%d",&x);sa+=x;ans+=(ll)x*x;
a[i]=cpx(1.0*x,0.0);
}
for(int i=n;i>=1;i--)
{
scanf("%d",&x);sb+=x;ans+=(ll)x*x;
b[i]=cpx(1.0*x,0.0);b[i+n]=b[i];
}
}
void fft(cpx *a,int f)
{
for(int i=0;i<n;i++)
if(i<r[i]) swap(a[i],a[r[i]]);
for(int i=1;i<n;i<<=1)
{
cpx wn(cos(pi/i),f*sin(pi/i));
for(int j=0;j<n;j+=(i<<1))
{
cpx w(1,0);
for(int k=0;k<i;k++,w*=wn)
{
cpx x=a[j+k],y=w*a[j+k+i];
a[j+k]=x+y;a[j+k+i]=x-y;
}
}
}
if(f==-1)
for(int i=0;i<n;i++) a[i].r/=n;
}
int main()
{
#ifndef ONLINE_JUDGE
freopen("in.txt","r",stdin);
#endif
input();
k=fac(1.0*(sb-sa)/n);dt=(ll)n*k*k+2ll*(sa-sb)*k;
m=n<<1;
for(n=1;n<=m;n<<=1) l++;
for(int i=0;i<n;i++) r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
fft(a,1);fft(b,1);
for(int i=0;i<n;i++) c[i]=a[i]*b[i];
fft(c,-1);
for(int i=fr+1;i<=fr*2;i++)
if(c[i].r>maxx)
maxx=c[i].r;
ans+=dt-fac(maxx*2.0);
printf("%lld\n",ans);
return 0;
}