看来之前所说的有些地方是错误的:这个算法只能处理800位左右的乘法。之后我做了一个测试,比较用数论变换方法和用普通方法做大整数乘法的时间消耗。下图是程序运行时间随问题规模增大而增长的曲线,其中红色曲线是普通乘法,黑色曲线是基于数论变换的大整数乘法。可以看到,当问题规模增大时,我们的算法在时间性能上的优势就体现出来了。美中不足的是,基于数论变换的方法只能处理800多位的乘数(要求 81n < 65537,其中 n 为较短乘数的位数,81 = 9×9 是单个数位乘积的最大值)。
当然,跟Java库中BigInteger类的乘法相比,我的实现就慢了很多。主要原因是BigInteger不是用十进制保存数字的,而是用了大得多的进制(其内部用int数组保存数值,相当于2^32进制;从Java 8起,对足够大的数还会改用Karatsuba和Toom-Cook乘法),因此它的算法常数因子比较小,在可测的范围之内效率都比我们的算法高。这也可以说是理论与实际之间的差距吧,或者说BigInteger中的乘法是综合各方面考虑后的权衡。
测试的代码如下:
1.Test.java
package NumericTransformation;
import java.awt.Graphics;
import javax.swing.JFrame;
import javax.swing.JPanel;
import javax.swing.SwingUtilities;
/**
 * Swing panel that plots, for operand lengths 10, 15, ..., 795, the time taken by
 * {@code NumericTransformation.testOnce} to benchmark NTT-based multiplication against
 * schoolbook multiplication.
 */
public class Test extends JPanel {
    /**
     * Opens an 800x800 window containing the benchmark plot.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        // Swing components must be created and shown on the event dispatch thread.
        SwingUtilities.invokeLater(new Runnable() {
            @Override
            public void run() {
                JFrame jFrame = new JFrame();
                jFrame.setBounds(0, 0, 800, 800);
                // Configure the frame completely before making it visible.
                jFrame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
                jFrame.add(new Test());
                jFrame.setVisible(true);
            }
        });
    }

    /**
     * Paints the panel background, then runs one benchmark per operand length and lets
     * {@code testOnce} draw its data point at x = length.
     */
    @Override
    protected void paintComponent(Graphics g) {
        super.paintComponent(g);
        // NOTE(review): the whole benchmark re-runs on every repaint and blocks the EDT while it
        // does; caching the measurements would require testOnce to return them instead of drawing.
        NumericTransformation benchmark = new NumericTransformation();
        for (int length = 10; length < 800; length += 5) {
            try {
                benchmark.testOnce(length, length, g);
            } catch (Exception e) {
                // Keep plotting the remaining lengths; report the failure for this one.
                e.printStackTrace();
            }
        }
    }
}
2.NumericTransformation.java
package NumericTransformation;
import java.awt.Color;
import java.awt.Graphics;
import java.math.BigInteger;
import java.util.Random;
import java.util.Scanner;
/**
* 实现了基于数论变换的大整数乘法
* @author wwf
*
*/
public class NumericTransformation {
public static final int M=65537;
public static final int R=3;//3 是65537的一个本原根,也可以说是生成元
public static final int RP=21846;
public static final int MAX_POW=65536;
/**
* @param args
* @throws Exception
*/
public static void main(String[] args) throws Exception {
// TODO Auto-generated method stub
Scanner in=new Scanner(System.in);
/*System.out.println("input two integer numbers : ");
String numa=in.next();
String numb=in.next();
System.out.println("the result of multiplication is : "+new NumericTransformation().multiply(numa, numb));
*/
}
void testOnce(int length,int position,Graphics g) throws Exception{
long time1=0,time2=0;
System.out.println("calculate");
for(int i=0;i<500;i++){
Random rand=new Random();
StringBuffer sba=new StringBuffer();
StringBuffer sbb=new StringBuffer();
for(int j=0;j<length;j++){
sba.append((rand.nextInt(10)+20)%10);
sbb.append((rand.nextInt(10)+20)%10);
}
String sa=sba.toString();
String sb=sbb.toString();
long original=System.currentTimeMillis();
multiply(sa, sb);
time1+=System.currentTimeMillis()-original;
original=System.currentTimeMillis();
int[] result=new int[sba.length()+sbb.length()+1];
for(int j=0;j<sba.length();j++){
for(int k=0;k<sbb.length();k++){
result[j+k]+=(sba.charAt(j)-'0')*(sbb.charAt(k)-'0');
}
}
for(int j=0;j<result.length-1;j++){
result[j+1]+=result[j]/10;
result[j]%=10;
}
time2+=System.currentTimeMillis()-original;
}
g.setColor(Color.BLACK);
g.fillRect(position, (int) (500-time1/10), 5, 5);
g.setColor(Color.red);
g.fillOval(position, (int) (500-time2/10), 5, 5);
}
/**
* 这个是程序对外的接口
* @param a
* @param b
* @return
* @throws Exception
*/
public String multiply(String a, String b) throws Exception{
if(a.startsWith("-")&&b.startsWith("-")){
return multiply(a.substring(1), b.substring(1));
}else if(((!a.startsWith("-")&&(!b.startsWith("-"))))){
long [] pa=new long[a.length()];
long [] pb=new long[b.length()];
for(int i=0;i<a.length();i++){
char c=a.charAt(i);
if(!Character.isDigit(c)){
throw new Exception("number format error!!!");
}
pa[i]=c-'0';
}
for(int i=0;i<b.length();i++){
char c=b.charAt(i);
if(!Character.isDigit(c)){
throw new Exception("number format error!!!");
}
pb[i]=c-'0';
}
long [] presult=this.numMultiply(pa, pb);
StringBuffer result=new StringBuffer();
for(long i:presult){
result.append(i);
}
while(result.charAt(0)=='0'&&result.length()>1){
result.delete(0, 1);
}
return result.toString();
}else if(a.startsWith("-")){
return "-"+multiply(a.substring(1), b);
}else{
return "-"+multiply(a, b.substring(1));
}
}
/**
* a和b是两个乘数的十进制表示形式(两个乘数都是非负整数)
* @param a
* @param b
* @return
* @throws Exception
*/
private long[] numMultiply(long []a,long[] b) throws Exception{
if(a.length*81>M&&b.length*81>M){
throw new Exception("number too large to perform FFT based Multiplication");
}
long[] result=polyMultiply(a, b);
/*for(int i=0;i<result.length;i++){
System.out.print(result[i]+" ");
}System.out.println();*/
for(int i=result.length-1;i>=0;i--){
if(result[i]>10){
result[i-1]+=result[i]/10;
result[i]%=10;
}
}
/*for(int i=0;i<result.length;i++){
System.out.print(result[i]+" ");
}System.out.println();*/
return result;
}
/**
* 多项式乘法
* @param pa
* @param pb
* @return
*/
private long[] polyMultiply(long [] pa,long[]pb){
/*
* expand pa and pb
*/
int dim=1;
while(dim<pa.length+pb.length){
dim*=2;
}
long [] tmp=new long[dim];
for(int i=0;i<pa.length;i++){
tmp[tmp.length-i-1]=pa[pa.length-i-1];
}
pa=tmp;
tmp=new long[dim];
for(int i=0;i<pb.length;i++){
tmp[tmp.length-i-1]=pb[pb.length-i-1];
}
pb=tmp;
//System.out.println("dim = " + dim);
long r=1;
for(long i=0;i<MAX_POW/dim;i++){
r=(r*R)%M;
}
//System.out.println("in polyMultiply : r = "+r);
long[] psa=transform(pa,r);
long[] psb=transform(pb,r);
long[] psp=new long[psa.length];
for(int i=0;i<psp.length;i++){
psp[i]=(psa[i]*psb[i])%M;
}
long rr=1;
for(int i=0;i<MAX_POW/dim;i++){
rr=(rr*RP)%M;
}
long[] result=transform(psp,rr);
for(int i=0;i<result.length;i++){
result[i]/=dim;
}
return result;
}
/**
* 变换
* 假设pa的长度是n=2^t
* @param pa
* @return
*/
private long[] transform(long[] pa,long r){
if(pa.length==1){
long[] result=new long[1];
result[0]=(r*pa[0])%M;
//System.out.println("in transform : ");
//System.out.println("r = "+r);
//System.out.println("pa : "+pa[0]);
//System.out.println("result : "+result[0]);
return result;
}
//System.out.prlongln("r = "+r);
long []pa1=new long[pa.length/2];
long []pa2=new long[pa.length/2];
for(int i=0;i<pa.length/2;i++){
pa1[i]=pa[2*i];
pa2[i]=pa[2*i+1];
}
//System.out.println("r*r%M = "+(long)((long)r*r)%M);
long[] ps1=transform(pa1,(long)(((long)r)*r)%M);
long[] ps2=transform(pa2,(long)(((long)r)*r)%M);
long[]result=new long[pa.length];
long tmpr=r;
for(int i=0;i<result.length/2;i++){
result[i]=(tmpr*ps1[i]+ps2[i])%M;
result[i+result.length/2]=((-tmpr*ps1[i]+ps2[i])%M+M)%M;
tmpr=(tmpr*r)%M;
}
//System.out.println("in transform");
//System.out.println("r = "+r);
//System.out.println("pa : ");
/*for(int i=0;i<pa.length;i++){
System.out.print(pa[i]+" ");
}System.out.println();
System.out.println("result : ");
for(int i=0;i<result.length;i++){
System.out.print(result[i]+" ");
}System.out.println();*/
return result;
}
}