大致题意
给n个数ai,问不重复的选取3个数能组成的所有的数各有多少种。(n<=40000,|ai|<=20000);
思路
不考虑重复当然直接 A(x)*A(x)*A(x),然后可能有两个相同加一个不同的情况
A(x^2)B(x),排列数是3,然后3个都相同C(x ^3)排列数是1。所以容斥一下,总的剪掉2个元素相同情况的时候,多剪了2个 (3个元素相同的情况)
最总答案是ans= [ A(x) * A(x) * A(x)-3A(x ^2)*B(x) + C(x ^3)]/6;
代码
再贴一下匡斌的板子
#include<bits/stdc++.h>
using namespace std;
#define maxn 300005
#define maxm 1000006
#define ll long long int
#define INF 0x3f3f3f3f
#define inc(i,l,r) for(int i=l;i<=r;i++)
#define dec(i,r,l) for(int i=r;i>=l;i--)
#define mem(a) memset(a,0,sizeof(a))
#define sqr(x) (x*x)
#define inf (ll)2e18+1
#define PI acos(-1)
int read(){
int x=0,f=1;char ch=getchar();
while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
while(isdigit(ch))x=x*10+ch-'0',ch=getchar();
return f*x;
}
struct Complex
{
double x, y; // 实部和虚部 x + yi
Complex(double _x = 0.0, double _y = 0.0){x = _x;y = _y;}
Complex operator - (const Complex &b) const{return Complex(x - b.x, y - b.y);}
Complex operator + (const Complex &b) const{return Complex(x + b.x, y + b.y);}
Complex operator * (const Complex &b) const{return Complex(x * b.x - y * b.y, x * b.y + y * b.x);}
};
void change(Complex y[], int len)
{
int i,j,k;
for(i=1,j=len/2;i<len-1;i++){
if(i<j)swap(y[i],y[j]);
k=len/2;
while (j>=k){j-=k;k/=2;}
if (j<k)j+=k;
}
return ;
}
void fft(Complex y[], int len, int on)
{
change(y,len);
for(int h=2;h<=len;h<<=1){
Complex wn(cos(-on*2*PI/h),sin(-on*2*PI/h));
for(int j=0;j<len;j+=h){
Complex w(1,0);
for(int k=j;k<j+h/2;k++){
Complex u=y[k];
Complex t=w*y[k+h/2];
y[k]=u+t;
y[k+h/2]=u-t;
w=w*wn;
}
}
}
if(on==-1)for(int i=0;i<len;i++)y[i].x/=len;
}
ll cnt1[maxn],cnt2[maxn],cnt3[maxn];
int n;
Complex v1[maxn],v2[maxn];
int main()
{
n=read();
int x,ma=-1;
inc(i,1,n){
x=read();
x+=20000;
ma=max(ma,x);
cnt1[x]++;cnt2[x*2]++;cnt3[x*3]++;
}
int len=1;
while(len<6*ma)len<<=1;
inc(i,0,ma)v1[i]=Complex(cnt1[i],0);
inc(i,ma+1,len-1)v1[i]=Complex(0,0);
inc(i,0,ma*2)v2[i]=Complex(cnt2[i],0);
inc(i,ma*2+1,len-1)v2[i]=Complex(0,0);
fft(v1,len,1);fft(v2,len,1);
inc(i,0,len-1)v1[i]=v1[i]*(v1[i]*v1[i]-Complex(3.0,0.0)*v2[i]);
fft(v1,len,-1);
inc(i,0,len-1)cnt1[i]=(ll)(v1[i].x+0.5);
inc(i,0,len-1)cnt1[i]=(cnt1[i]+2*cnt3[i])/6;
inc(i,0,len-1)if(cnt1[i]!=0)printf("%d : %lld\n",i-60000,cnt1[i]);
return 0;
}