关于母函数网上有一大堆解析,不过解析具体代码的好像不多,这里就简要介绍一下好了。仅供初学者参考,老鸟请路过-.-
这篇博客虽然写的是hdu 1398,不过我们还是从hdu 1028开始比较好,因为这道题是最赤裸裸的母函数模板。
母函数说白了就是计算几个式子的乘积,如
(1 + x + x^1 + x^2)(1 + x^2) = 1 + x + 2x^2 + x^3 + x^4
通过母函数我们要得出的就是右式。我们设两个变量c1和c2,c2是临时数组,先不管它,我们将c1的下标设为指数,值设为这个指数的系数,如5x^3就是c1[3] = 5。
先看一看代码
#include<iostream>
using namespace std;
#define maxn 125
int c1[maxn], c2[maxn];
int main()
{
int n;
while(~scanf("%d", &n))
{
for(int i = 0; i <= n; i++) // --- 第一步
{
c1[i] = 1;
c2[i] = 0;
}
for(int i = 2; i <= n; i++) // --- 第二步
{
for(int j = 0; j <= n; j++) // --- 第三步
for(int k = 0; k+j <= n;k+=i) // --- 第四步
c2[k+j] += c1[j]; // --- 第五步
for(int i = 0; i <= n; i++) // 第六步
{
c1[i] = c2[i];
c2[i] = 0;
}
}
printf("%d\n", c1[n]);
}
return 0;
}
hdu 1028要计算的是(1 + x + x^2 + x^3 + ...) * (1 + x ^2 + x^4 + ...) * (1 + x^3 + x^6 + ...)
与手动计算一样,先是第一个括号乘以第二个括号,得到结果后再继续与第三个括号相乘,如此下去。
那么我们首先初始化第一个括号的参数,也就是上述代码中的第一步。然后从第二个括号开始乘,那么就是上述的第二步。接下来的j依次指向第一个括号中的每一个数。
第四步是关键,k指向的是第二个括号中的每一个数,说到这里可能还是比较模糊,我们就来手动执行一下。我们假设hdu1028的n现在是2
(1 + x + x^2) * (1 + x^2)
首先j指向第一个括号中的第一个数,也就是1,然后依次与第二个括号中的数相乘,那么k循环一遍后我们得到了1 + x^2,接着我们把它记录到c2这个临时数组中
所以是c2[j+k] += c1[j]。j+k指的是相乘后的指数,c1[j]指的是【第一个括号中,指数为j的数的系数】。
为什么是这样写的呢?个人所见,可以从两方面理解:一是因为,第二个括号中的系数都是1,因为个数的含义体现在了指数上,这个很好理解。二是因为,这个表达式前两个括号相加后,是合并到第一个括号中的,然后继续与"第二个“括号相乘,这里的"第二个"括号实际上是第三个括号。
例如这个表达式:(1 + 2x + 3x^2) * (1 + x^2),我们会得到如下值:
(1 + 2x + 3x^2) * (1 + x^2) = 1 + x^2 + 2x + 2x^3 + 3x^2 + 2x^4,然后是合并同指数的系数,得到1 + 2x + 4x^2 + 2x^3 + 3x^4
所以是c2[j+k] += c1[j]。从这里可以看出,母函数就是经过不断计算,将指数一层层叠加到第一个括号中。
现在c2数组存了更新后的"第一个括号的系数",所以要放回c1中,供下次相乘使用。
现在回过来看看hdu 1398,(1 + x + x^2 + x^3 + ...) * (1 + x^2 + x^4 + ...) * (1 + x^9 + x^18 + ...)
同样的,初始化0~n都是1,然后从第二个括号开始计算,因为第一个括号永远都是n+1项,所以j是从0到n。
第i个括号中的指数增量是i*i,所以代码就成型了。
#include<iostream>
using namespace std;
#define maxn 310
int c1[maxn], c2[maxn];
int main()
{
int n;
while(~scanf("%d", &n) && n)
{
memset(c1, 0, sizeof(c1));
memset(c2, 0, sizeof(c2));
for(int i = 0; i <= n; i++)
c1[i] = 1;
for(int i = 2; i*i <= n; i++)
{
for(int j = 0; j <= n; j++)
for(int k = 0; j + k <= n; k += i*i)
c2[j+k] += c1[j];
for(int i = 0; i <= n; i++)
{
c1[i] = c2[i];
c2[i] = 0;
}
}
printf("%d\n", c1[n]);
}
return 0;
}