数字信号处理课逼我学FFT呀
#include<bits/stdc++.h>
using namespace std;
#define rep(i,a,b) for(int i=a;i<=b;i++)
#define per(i,a,b) for(int i=a;i>=b;i--)
#define what_is(x) cerr<<#x<<" is "<<x<<endl;
const double PI=acos(-1.0);
struct Complex{
double x,y;
Complex(double _x=0.0,double _y=0.0){
x=_x;
y=_y;
}
Complex operator+(const Complex &b)const{
return Complex(x+b.x,y+b.y);
}
Complex operator-(const Complex &b)const{
return Complex(x-b.x,y-b.y);
}
Complex operator*(const Complex &b)const{
return Complex(x*b.x-y*b.y,x*b.y+y*b.x);
}
};
void change(Complex y[],int len){
int i,j,k;
for(i=1,j=len/2;i<len-1;i++){
if(i<j)swap(y[i],y[j]);
k=len/2;
while(j>=k){
j-=k;
k/=2;
}
if(j<k)j+=k;
}
}
void fft(Complex y[],int len,int on){
change(y,len);
for(int h=2;h<=len;h<<=1){
Complex wn(cos(-on*2*PI/h),sin(-on*2*PI/h));
for(int j=0;j<len;j+=h){
Complex w(1,0);
for(int k=j;k<j+h/2;k++){
Complex u=y[k];
Complex t=w*y[k+h/2];
y[k]=u+t;
y[k+h/2]=u-t;
w=w*wn;
}
}
}
if(on==-1){
rep(i,0,len-1){
y[i].x/=len;
}
}
}
const int MAXN=2e5+10;
char str1[MAXN],str2[MAXN];
Complex x[MAXN],y[MAXN];
int ans[MAXN];
int main(){
while(~scanf("%s%s",str1,str2)){
int len1=strlen(str1);
int len2=strlen(str2);
int len=1;
for(;len<len1+len2;len<<=1);
rep(i,0,len1-1){
x[i]=Complex(str1[len1-1-i]-'0',0);
}
rep(i,len1,len-1){
x[i]=Complex(0,0);
}
rep(i,0,len2-1){
y[i]=Complex(str2[len2-1-i]-'0',0);
}
rep(i,len2,len-1){
y[i]=Complex(0,0);
}
fft(x,len,1);
fft(y,len,1);
rep(i,0,len-1){
x[i]=x[i]*y[i];
}
fft(x,len,-1);
ans[0]=0;
rep(i,0,len-1){
ans[i]+=(int)(x[i].x+0.5);
ans[i+1]=ans[i]/10;
ans[i]%=10;
}
while(ans[len]==0&&len!=0)len--;
per(i,len,0)putchar(char('0'+ans[i]));
puts("");
}
return 0;
}
题意
#include<bits/stdc++.h>
using namespace std;
#define rep(i,a,b) for(int i=a;i<=b;i++)
#define what_is(x) cerr<<#x<<" is "<<x<<endl;
typedef long long ll;
typedef double ld;
const int MAXN=5e5+233;
const ld PI=-acos(-1.0);
struct cp{
ld x,y;
cp(ld _x=0.0,ld _y=0.0){
x=_x;
y=_y;
}
cp operator+(const cp &b)const{
return cp(x+b.x,y+b.y);
}
cp operator-(const cp &b)const{
return cp(x-b.x,y-b.y);
}
cp operator*(const cp &b)const{
return cp(x*b.x-y*b.y,x*b.y+y*b.x);
}
};
void change(cp y[],int len){
int i,j,k;
for(i=1,j=len/2;i<len-1;i++){
if(i<j)swap(y[i],y[j]);
k=len/2;
while(j>=k){
j-=k;
k/=2;
}
if(j<k)j+=k;
}
}
void fft(cp y[],int len,int on){
change(y,len);
for(int h=2;h<=len;h<<=1){
cp wn(cos(-on*2*PI/h),sin(-on*2*PI/h));
for(int j=0;j<len;j+=h){
cp w(1,0);
for(int k=j;k<j+h/2;k++){
cp u=y[k];
cp t=w*y[k+h/2];
y[k]=u+t;
y[k+h/2]=u-t;
w=w*wn;
}
}
}
if(on==-1){
rep(i,0,len-1){
y[i].x/=len;
}
}
}
int a[MAXN];
ll cnt[MAXN];
cp x[MAXN];
int main(){
int T;
scanf("%d",&T);
rep(kase,1,T){
int n;
scanf("%d",&n);
memset(cnt,0,sizeof(cnt));
rep(i,1,n){
scanf("%d",&a[i]);
cnt[a[i]]++;
}
sort(a+1,a+1+n);
int m=2*a[n];
int len=1;
for(;len<m+2;len<<=1);
rep(i,0,m){
x[i]=cp(cnt[i],0);
}
rep(i,m+1,len-1){
x[i]=cp(0,0);
}
fft(x,len,1);
rep(i,0,len-1){
x[i]=x[i]*x[i];
}
fft(x,len,-1);
rep(i,0,m){
cnt[i]=(ll)(x[i].x+0.5);
}
rep(i,1,n){
cnt[a[i]+a[i]]--;
}
rep(i,1,m){
cnt[i]/=2;
cnt[i]+=cnt[i-1];
}
ll w=0;
rep(i,1,n){
w+=cnt[m]-cnt[a[i]];
w-=1LL*(i-1)*(n-i);
w-=1LL*(n-i)*(n-i-1)/2;
w-=n-1;
}
ll tot=1LL*n*(n-1)*(n-2)/6;
printf("%.7lf\n",(ld)(w)/tot);
}
return 0;
}
#include<bits/stdc++.h>
using namespace std;
#define rep(i,a,b) for(int i=a;i<=b;i++)
typedef long long ll;
const ll MOD=998244353;
const int MAXN=5e6+10;
const ll g=3;
ll fp(ll a,ll b){
if(b<0){
a=fp(a,MOD-2);
b=-b;
}
ll res=1;
while(b){
if(b&1)res=res*a%MOD;
a=a*a%MOD;
b>>=1;
}
return res;
}
void change(ll y[],int len){
int i,j,k;
for(i=1,j=len/2;i<len-1;i++){
if(i<j)swap(y[i],y[j]);
k=len/2;
while(j>=k){
j-=k;
k/=2;
}
if(j<k)j+=k;
}
}
void ntt(ll y[],int len,int on){
change(y,len);
for(int h=2;h<=len;h<<=1){
ll wn=fp(g,-on*(MOD-1)/h);
for(int j=0;j<len;j+=h){
ll w=1;
for(int k=j;k<j+h/2;k++){
ll u=y[k];
ll t=w*y[k+h/2]%MOD;
y[k]=(u+t)%MOD;
y[k+h/2]=(u-t+MOD)%MOD;
w=w*wn%MOD;
}
}
}
if(on==-1){
ll t=fp(len,-1);
rep(i,0,len-1){
y[i]=y[i]*t%MOD;
}
}
}
ll x[MAXN];
int main(){
int n,k;
scanf("%d%d",&n,&k);
rep(i,1,k){
int v;
scanf("%d",&v);
x[v]=1;
}
int len=1;
for(;len<10*n;len<<=1);
ntt(x,len,1);
rep(i,0,len-1){
x[i]=fp(x[i],n/2);
}
ntt(x,len,-1);
ll res=0;
rep(i,0,len-1){
res=(res+x[i]*x[i]%MOD)%MOD;
}
printf("%I64d",res);
return 0;
}