思路是kuangbin大神博客里淘来的
其实题目是给了n条线段。问随机取三个,可以组成三角形的概率。
其实就是要求n条线段,选3条组成三角形的选法有多少种。
首先题目给了a数组,
如样例一:
4
1 3 3 4
把这个数组转化成num数组,num[i]表示长度为i的有num[i]条。
样例一就是
num = {0 1 0 2 1}
代表长度0的有0根,长度为1的有1根,长度为2的有0根,长度为3的有两根,长度为4的有1根。
使用FFT解决的问题就是num数组和num数组卷积。
num数组和num数组卷积的解决,其实就是从{1 3 3 4}取一个数,从{1 3 3 4}再取一个数,他们的和每个值各有多少个
例如{0 1 0 2 1}*{0 1 0 2 1} 卷积的结果应该是{0 0 1 0 4 2 4 4 1 }
长度为n的数组和长度为m的数组卷积,结果是长度为n+m-1的数组。
{0 1 0 2 1}*{0 1 0 2 1} 卷积的结果应该是{0 0 1 0 4 2 4 4 1 }。
这个结果的意义如下:
从{1 3 3 4}取一个数,从{1 3 3 4}再取一个数
取两个数和为 2 的取法是一种:1+1
和为 4 的取法有四种:1+3, 1+3 ,3+1 ,3+1
和为 5 的取法有两种:1+4 ,4+1;
和为 6的取法有四种:3+3,3+3,3+3,3+3,3+3
和为 7 的取法有四种: 3+4,3+4,4+3,4+3
和为 8 的取法有 一种:4+4
利用FFT可以快速求取循环卷积,具体求解过程不解释了,就是DFT和FFT的基本理论了。
总之FFT就是快速求到了num和num卷积的结果。只要长度满足>=n+m+1.那么就可以用循环卷积得到线性卷积了。
弄完FFT得到一个num数组,这个数组的含义在上面解释过了。
while( len < 2*len1 )len <<= 1; for(int i = 0;i < len1;i++) x1[i] = complex(num[i],0); for(int i = len1;i < len;i++) x1[i] = complex(0,0); fft(x1,len,1); for(int i = 0;i < len;i++) x1[i] = x1[i]*x1[i]; fft(x1,len,-1); for(int i = 0;i < len;i++) num[i] = (longlong)(x1[i].r+0.5);
这里代码中的num数组就是卷积后的结果,表示两两组合。
但是题目中本身和本身组合是不行的,所有把取同一个的组合的情况删掉。
//减掉取两个相同的组合for(int i = 0;i < n;i++) num[a[i]+a[i]]--;
还有,这个问题求组合,所以第一个选t1,第二个选t2,和第一个选t2,第二个选t1,我们认为是一样的。
所有num数组整体除于2
//选择的无序,除以2for(int i = 1;i <= len;i++) { num[i]/=2; }
然后对num数组求前缀和
sum[0] = 0; for(int i = 1;i <= len;i++) sum[i] = sum[i-1]+num[i];
之后就开始O(n)找可以形成三角形的组合了。
a数组从小到大排好序。
对于a[i]. 我们假设a[i]是形成的三角形中最长的。这样就是在其余中选择两个和>a[i],而且长度不能大于a[i]的。(注意这里所谓的大于小于,不是 说长度的大于小于,其实是排好序以后的,位置关系,这样就可以不用管长度相等的情况,排在a[i]前的就是小于的,后面的就是大于的)。
根据前面求得的结果。
长度和大于a[i]的取两个的取法是sum[len]-sum[a[i]].
但是这里面有不符合的。
一个是包含了取一大一小的
cnt -= (long long)(n-1-i)*i;
一个是包含了取一个本身i,然后取其它的
cnt -= (n-1);
还有就是取两个都大于的了
cnt -= (long long)(n-1-i)*(n-i-2)/2;
这样把i从0~n-1累加,就答案了。
longlong cnt = 0; for(int i = 0;i < n;i++) { cnt += sum[len]-sum[a[i]]; //减掉一个取大,一个取小的 cnt -= (longlong)(n-1-i)*i; //减掉一个取本身,另外一个取其它 cnt -= (n-1); //减掉大于它的取两个的组合 cnt -= (longlong)(n-1-i)*(n-i-2)/2; }#include <iostream> #include <string.h> #include <stdio.h> #include <math.h> #include<algorithm> using namespace std; const int N = 500005; const double PI = acos(-1.0); int n; struct Virt { double r, i; Virt(double r = 0.0,double i = 0.0) { this->r = r; this->i = i; } Virt operator + (const Virt &x) { return Virt(r + x.r, i + x.i); } Virt operator - (const Virt &x) { return Virt(r - x.r, i - x.i); } Virt operator * (const Virt &x) { return Virt(r * x.r - i * x.i, i * x.r + r * x.i); } }; //雷德算法--倒位序 void Rader(Virt F[], int len) { int j = len >> 1; for(int i=1; i<len-1; i++) { if(i < j) swap(F[i], F[j]); int k = len >> 1; while(j >= k) { j -= k; k >>= 1; } if(j < k) j += k; } } //FFT实现 void FFT(Virt F[], int len, int on) { Rader(F, len); for(int h=2; h<=len; h<<=1) //分治后计算长度为h的DFT { Virt wn(cos(-on*2*PI/h), sin(-on*2*PI/h)); //单位复根e^(2*PI/m)用欧拉公式展开 for(int j=0; j<len; j+=h) { Virt w(1,0); //旋转因子 for(int k=j; k<j+h/2; k++) { Virt u = F[k]; Virt t = w * F[k + h / 2]; F[k] = u + t; //蝴蝶合并操作 F[k + h / 2] = u - t; w = w * wn; //更新旋转因子 } } } if(on == -1) for(int i=0; i<len; i++) F[i].r /= len; } //求卷积 void Conv(Virt a[],int len) { FFT(a,len,1); // FFT(b,len,1); for(int i=0; i<len; i++) a[i] = a[i]*a[i]; FFT(a,len,-1); } Virt va[N],vb[N]; long long result[N]; long long len; long long sum[N]; long long a[N]; long long num[N]; void Init() { memset(num,0,sizeof(num)); memset(sum,0,sizeof(sum)); for(int i=0;i<n;i++){ cin>>a[i]; num[a[i]]++; } sort(a,a+n); int len1=a[n-1]+1; len = 1; while(len < 2*len1) len<<= 1; for(int i=0; i<len1; i++) { double v=num[i]; va[i].r = v; va[i].i = 0.0; } for(int i=len1;i<len;i++){ va[i].r = va[i].i = 0.0; } } void Work() { Conv(va,len); for(int i=0; i<len; i++) result[i] =(long long)(va[i].r+0.5); } int main() { int T; scanf("%d",&T); while(T--) { scanf("%d",&n); Init(); Work(); //将两条边的长度去重 for(int i=0;i<n;i++){ result[(a[i]+a[i])]--; } len=a[n-1]*2; for(int i=1;i<=len;i++){ result[i]/=2; } for(int i=1;i<=len;i++){ sum[i]=sum[i-1]+result[i]; } long long cnt=0; for(int i=0;i<n;i++){ cnt+=sum[len]-sum[a[i]]; cnt-=(long long)(n-1);//去掉选自身的那种 cnt-=(long long)(n-i-1)*(n-i-2)/2; cnt-=(long long)i*(long long)(n-i-1); } long long tot = (long long)n*(n-1)*(n-2)/6; printf("%.7lf\n",(double)cnt/tot); } return 0; }