FFT+容斥原理
Hdu 4093 ( Xavier is Learning to Count )
题意:
给定一堆牌,牌上有两两不同的数字,要求从中选出p张牌(1<=p<=5),使得这p张牌两两不同,且从大到小排列之后它们的数字之和为n,然后求出对于每个n(1<=n
讲解:
首先我们不考虑选出牌的顺序以及每次选出的p张牌可以重复,那么这就是一个组合问题,选出p张牌使得数字之和为一固定值,这里就是使用生成函数来求解,对于每个牌的数字有一个可选集合,用多项式来表示就是其系数为选出这张牌的方案数,所以该多项式就是在x^a[i]的地方系数为1,其他地方为0。
比如我们得到的多项式是F[1]=Σx^(a[i])
其次我们还是忽略掉选出的牌有序的要求,而把每张牌不重复的要求考虑上,这里要怎么把重复的选牌方案删掉呢,我们考虑用容斥原理来计算,我们知道容斥原理的一般形式是求出不具有任一一个性质的集合的数目,即|~A^~B^~C|,这里我们要计算的就是任意两个位置的牌的集合的大小。
所以这个集合可以表示为|^~AiAj(1《=i《j《=p)|
这里AiAj代表第i张牌和第j张牌数字相同,即选择了同一张牌,~AiAj代表第i张牌和第j张牌数字不同,^~AiAj(1《=i《j《=p)即为所求的集合,任意两张牌都不相同,
然后用容斥原理来计算,这个是标准形式,
只要枚举A1A2,A1A3,…A1Ap
A2A3,…,A2Ap
……..
Ap-1Ap
这些集合的任意组合,计算一下系数就可以了
所以现在就是要考虑有若干个位置的牌数字相同怎么求,还是生成函数,就是把这些牌绑定到一起考虑,
这里以两张牌数字相同为例,就是说他们必须选择相同的牌,所以他们数字之和是一张牌的两倍,但是这样选择方案还是只有一种,所以这个生成多项式就是F[2]=Σx^(2*a[i])
同理p张牌数字相同就是F[p]=Σx^(p*a[i])
然后用容斥原理组合起来就可以了
然后我判断那两些数是相同的用了一个并查集,比较naive
其实是看不懂大神的代码,只好用了一个复杂度高的写法了,
这里我的代码参考了大神BZOJ2498 : Xavier is Learning to Count
代码:
#include<cstring>
#include<algorithm>
#include<complex>
#include<iostream>
#define debug //
using namespace std;
const double PI=acos(-1.0);
typedef complex<double> CD;
const int maxn=100000;
void BitReverse(CD a[],int n)
{
for(int i=0,j=0;i<n;i++){
if(j>i) swap(a[i],a[j]);
// debug("%2x %2x\n",i,j);
int k=n;
while(j&(k>>=1)) j&=~k;
j|=k;
}
}
void FFT(CD a[],int n,bool reverse)
{
BitReverse(a,n);
double pi=reverse?-PI:PI;//这里PI和-PI可以交换顺序
for(int step=1;step<n;step<<=1){
//这里是觉得每步两个子问题的规模,所以原问题的规模为step<<1
//然后每次合并需要进行step操作
//同样的这里sin里面可以带上一个负号
CD wn=CD(cos(2*pi/(step<<1)),sin((2*pi/(step<<1))));
// 这里可以写成double alpha=pi/step;
for(int k=0;k<step;k++){
//这里的循环和内层循环可以交换顺序
//这里保证wnk正确即可,所以可以预处理出wnk来,不过比较耗费空间
CD wnk=CD(cos(2*pi*k/(step<<1)),sin(2*pi*k/(step<<1)));
//可以优化为wnk=CD(cos(pi*k/step),sin(pi*k/step));
//然后这里可以又alpha来得到wnk=exp(CD(alpha*k));
for(int i=k;i<n;i+=(step<<1)){
int j=i+step;
CD tmp=a[j]*wnk;
a[j]=a[i]-tmp;
a[i]=a[i]+tmp;
}
}
}
if(reverse){
for(int i=0;i<n;i++){
a[i]/=n;
}
}
}
int juanji(double A[],double B[],double C[],int M1,int M2)
{
int len=M1+M2-1;
int n=1;
while(n<len) n<<=1;
static CD X[maxn],Y[maxn];
for(int i=0;i<M1;i++){
X[i]=CD(A[i],0.0);
}
for(int i=M1;i<n;i++){
X[i]=CD(0.0,0.0);
}
FFT(X,n,false);
for(int i=0;i<M1;i++){
Y[i]=CD(B[i],0.0);
}
for(int i=M1;i<n;i++){
Y[i]=CD(0.0,0.0);
}
FFT(Y,n,false);
for(int i=0;i<n;i++){
X[i]*=Y[i];
}
FFT(X,n,true);
for(int i=0;i<n;i++){
C[i]=X[i].real();
}
return n;
}
int p;
int m;
int a[maxn];
int T;
int kase;
int n;
const int maxm=7;
CD f[maxm][maxn];
int pa[maxm];
int setsize[maxm];
int len=1;
void init()
{
for(int i=0;i<maxm;i++){
pa[i]=i;
setsize[i]=1;
}
}
int findset(int x)
{
if(pa[x]==x){
return x;
}
else{
pa[x]=findset(pa[x]);
return pa[x];
}
}
bool sameset(int x,int y)
{
return findset(x)==findset(y);
}
void unionset(int x,int y)
{
int fx=findset(x),fy=findset(y);
pa[fx]=fy;
setsize[fy]+=setsize[fx];
setsize[fx]=0;
}
struct Node{
int x,y;
Node(){}
Node(int x,int y):x(x),y(y){}
};
Node nodes[100];
void getxishu()
{
int cnt=0;
for(int i=1;i<=p;i++){
for(int j=i+1;j<=p;j++){
nodes[cnt++]=Node(i,j);
}
}
int sz=(1<<cnt);
for(int subset=0;subset<sz;subset++){
init();
int setcnt=0;
for(int i=0;i<cnt;i++){
if(subset&(1<<i)){
if(!sameset(nodes[i].x,nodes[i].y)) unionset(nodes[i].x,nodes[i].y);
setcnt++;
}
}
int sgn=setcnt&1?-1:1;
for(int i=0;i<len;i++){
f[maxm-1][i]=CD(1.0,0.0);
}
debug("subset%d sgn:%d\n",subset,sgn);
for(int i=1;i<=p;i++){
if(setsize[i]){
debug("sz%d\n",setsize[i]);
for(int j=0;j<len;j++)
f[maxm-1][j]=f[maxm-1][j]*f[setsize[i]][j];
}
}
for(int i=0;i<len;i++){
f[0][i]=f[0][i]+CD(sgn,0.0)*f[maxm-1][i];
}
}
}
int jiecheng[maxm];
void solve()
{
n=0;
scanf("%d %d",&m,&p);
for(int i=1;i<=m;i++){
scanf("%d",&a[i]);
if(n<a[i]) n=a[i];
}
n*=p;
len=1;
while(len<=n) len<<=1;
for(int i=0;i<7;i++){
for(int j=0;j<maxn;j++){
f[i][j]=CD(0.0,0.0);
}
}
for(int i=1;i<=m;i++){
for(int j=1;j<=p;j++){
f[j][j*a[i]]+=CD(1.0,0.0);
}
}
for(int i=1;i<=p;i++){
FFT(f[i],len,false);
}
getxishu();
FFT(f[0],len,true);
printf("Case #%d:\n",kase);
for(int i=0;i<len;i++){
double tmp=f[0][i].real()/jiecheng[p];
if(tmp>0.5){
printf("%d: %0.0lf\n",i,tmp);
}
}
puts("");
}
int main()
{
jiecheng[0]=1;
for(int i=1;i<maxm;i++){
jiecheng[i]=jiecheng[i-1]*i;
}
scanf("%d",&T);
for(kase=1;kase<=T;kase++){
solve();
}
return 0;
}