zz:http://www.strongczq.com/2012/03/srm536-div1-3-binarypolynomialdivone.html
题目原文:http://community.topcoder.com/stat?c=problem_statement&pm=11798&rd=14728
题目大意:
定义二进制多项式如下所示:
P(x) = a[0] * x0 + a[1] * x1 + ... + a[n] * xn
其中a[i]的取值只能是0或者1,并且最高此项的系数a[n]必须是1。
二进制多项式的乘法:与普通多项式相同,但是得到的多项式各系数必须要取2的模。也就是说二进制多项式相乘的结果也是二进制多项式。
按照这种乘法方式,二进制多项式也可以进行指数为正整数的幂运算。
先给定P(x)的系数a[], 求对P(x)取m次幂之后的二进制多项式中,第k项是0还是1。
数据规模:n为[0,49], m为[1, 10^16], k为[0, n*m]
思路:
由于m最大值可以达到10^16,直接做m次乘法肯定不靠谱。对于这种情况,由于二进制多项式的乘法运算比较特殊,包含了对2取模的过程,应该先找个简单的例子试试二进制多项式的乘法运算,看看有没有什么规律。由于题目中主要是考虑二进制多项式的幂次,所以我们可以试试计算一下二进制多项式的平方。考虑P(x)的平方P2(x),P2(x)的第c项的系数可以表示为:
a[0] * a[c] + a[1] * a[c - 1] + ... + a[c/2]*a[c - c/2] + ... + a[c - 1]*a[0]
可以看出以上式子左右存在对称性。如果c为奇数的话,所有项都存在对称项,那么最终值肯定是0;如果c为偶数的话,最终值为a[c/2]^2, 所以有:
P2(x)=a[0] * x0 + a[1] * x2 + ... + a[n] * x2n=P(x2)
根据以上的规律,可以得到:P2^b(x)= P(x2^b), 所以可以将Pm(x)表示成多个P(x2^b)的乘积。
考虑m的二进制表示,如果第b1、b2...、bj位(从右到左)是1,那么Pm(x)可以表示成
P(x2^b1) * P(x2^b2)*...* P(x2^bj)
所以问题可以转化为,满足以下条件的组合(i1, i2, ..., ij)的数量是奇数还是偶数:
a[i1]=a[i2]=...=a[ij]=1
i1*2^b1 + i2*2^b2 + ...+ ij*2^bj=k
该问题类似于背包算法,可以考虑使用dp算法。本人的解法使用了迭代实现的dp算法。用S(p)表示所有 k - i1*2^b1 - ... - ip*2^bp 的可行值的集合。那么,针对集合S(p)的每一个k'值,(ip+1, ..., ij)的组合数之和的奇偶性与本题最终答案的奇偶性相同。
S(p)的生成方式为,假设第p-1步产生的集合为S(p-1),则第p步根据S(p-1)中的每一个k',考虑a[ip]的所有可能取值下的k-a[ip]生成S(p)集合。如果简单的这么生成,那么用不了几步,这个集合大小就会超过内存的承受范围,所以还需要加两个限制条件:
a. 如果一个元素偶数次塞入S(p),那么该元素对最终的奇偶性不会产生影响,不需要考虑该元素
b. 显然只有k - i1*2^b1 - ... - ip*2^bp 是 2^bp+1 的整数倍的情况下才有可能存在(ip+1, ..., ij)组合,所以如果不是的话就不需要考虑该元素。由于k - i1*2^b1 - ... - ip*2^bp 的取 值范围必然小于n*2^bp+1 ,所以这样的元素个数必然不超过n个,完全符合内存要求。
第j步之后,根据以上方式生成的S(j)集合,判断0是否属于该集合,如果是则说明答案为1,否则为0。
Java代码:
public class BinaryPolynomialDivOne {
public int findCoefficient(int[] a, long m, long k) {
Set<Long> dp = new HashSet<Long>();
dp.add(Long.valueOf(k));
for (int i = 0; (1L << i) <= m; ++i) {
Set<Long> tmp = new HashSet<Long>();
for (Long v : dp) {
if (((1L << i) & m) == 0) {
if (v.longValue() % 2 == 0) {
tmp.add(Long.valueOf(v.longValue() / 2));
}
} else {
for (int j = 0; j < a.length; ++j) {
if (a[j] == 0) {
continue;
}
long nv = v.longValue() - j;
if (nv % 2 == 0 && nv >= 0) {
nv /= 2;
Long NV = Long.valueOf(nv);
if (tmp.contains(NV)) {
tmp.remove(NV);
} else {
tmp.add(NV);
}
}
}
}
}
if (tmp.size() == 0) {
return 0;
}
dp = tmp;
}
return dp.contains(Long.valueOf(0)) ? 1 : 0;
}
}