// 整数乘法---FFT 的递归实现 (Integer multiplication: recursive FFT implementation)

 package fft;

/**
 * Input data format:
 *  a number array such as 1,2,3,4 represents, in base U = 2^ENTRYSIZE, the integer
 *  1*U^0 + 2*U^1 + 3*U^2 + 4*U^3 (little-endian digits).
 *
 * The output number array uses the same format as the input number array.
 * Every coefficient of the convolution must stay below the prime P,
 * so 2*ENTRYSIZE + log2(max(mag.length)) < log2(P).
 */

import java.lang.*;
import java.math.*;
import java.util.*;
// 支持利用递归FFT算法计算乘法的大整数类 (big-integer class whose multiply method uses the recursive FFT algorithm)

/**
 * Thrown by {@code BigInt.checkBound} when a digit array cannot be multiplied safely
 * (too long, or containing a digit outside the allowed range).
 */
class NumberOverRangeException extends Exception {

    private static final long serialVersionUID = 1L;

    /** Snapshot of the offending digit array; never null. */
    private final int[] number;

    /**
     * @param number the digit array that failed validation. It is copied so that later
     *               mutation by the caller does not change what {@link #printNumber()}
     *               reports. {@code null} is treated as an empty array.
     */
    NumberOverRangeException(int[] number) {
        super("number array out of range (length " + (number == null ? 0 : number.length) + ")");
        this.number = (number == null) ? new int[0] : number.clone();
    }

    /** Prints every entry of the offending array to stdout, each followed by a space, then a newline. */
    public void printNumber() {
        StringBuilder line = new StringBuilder();
        for (int value : number) {
            line.append(value).append(' ');
        }
        System.out.println(line);
    }
}

public class BigInt {

    protected int signum = 0;           //neg = -1,0 = 0,pos = 1

    protected int[] mag;                //magnitude in little-endian format
    // assume that datas have been formated to use a base 2^(ENTRYSIZE)

    public final static int MAXN = 134217728; //Maximum value for n 2^27

    public final static int ENTRYSIZE = 10;     //Bits per entry in mag, base = 2^ENTRYSIZE

    protected final static long P = 2013265921; // The prime 15*2^{27}+1

    protected final static int OMEGA = 440564289;   //Root of unity 31^{15}mod P

    protected final static int TWOINV = 1006632961; //2^{-1}mod P


    public BigInt(int signum, int[] mag) {
        this.signum = signum;
        this.mag = new int[mag.length];
        for (int i = 0; i < mag.length; i++) {
            this.mag[i] = mag[i];
        }
    }

    public static void main(String[] args) {
        // TODO code application logic here
        int k = (int) Math.pow(2, ENTRYSIZE) - 1;
        System.out.println(k);
        int len = 512;
        int[] mag = new int[len];
        for (int i = 0; i < len; i++) {
            mag[i] = k;
        }
        BigInt mag1 = new BigInt(1, mag);
        BigInt mag2 = new BigInt(1, mag);
        try {
            checkBound(mag1.mag);
            checkBound(mag2.mag);
        } catch (NumberOverRangeException ex) {
            System.out.println("NumberOverRangeException!");
            System.out.print("The invalid Array is ");
            ex.printNumber();
            return;
        }
        BigInt mag3 = mag1.multiply(mag2);
        for (int i = 0; i < mag3.mag.length; i++) {
            System.out.print(mag3.mag[i] + " ");
        }
        System.out.println();
        compareUnit(mag1.mag, mag2.mag, mag3.mag);
    }

    public static void checkBound(int[] A) throws NumberOverRangeException {

//数据越界:2*ENTRYSIZE+logn位>P的位数

        if (A.length > MAXN / 2) {
            throw new NumberOverRangeException(A);
        }
        if (ENTRYSIZE > 15) {
            throw new NumberOverRangeException(A);
        }
        if (P / (int) Math.pow(Math.pow(2, ENTRYSIZE) - 1, 2) < A.length) {
            throw new NumberOverRangeException(A);
        }
        int ck = 0;
        for (int i = 0; i < A.length; i++) {
            if (A[i] >= Math.pow(2, ENTRYSIZE)) {
                ck = 1;
                break;
            }
        }
        if (ck == 1) {
            throw new NumberOverRangeException(A);
        }
    }

    public static void compareUnit(int[] mag1, int[] mag2, int[] mag3) {//验证结果个位

        int pu = (int) Math.pow(2, ENTRYSIZE) % 10;
        int sum1 = 0, sum2 = 0, sum3 = 0;
        int pi = 1;
        for (int i = 0; i < mag1.length; i++) {
            sum1 = (sum1 + (mag1[i] % 10) * pi) % 10;
            pi = (pi * pu) % 10;
        }
        pi = 1;
        for (int i = 0; i < mag2.length; i++) {
            sum2 = (sum2 + (mag2[i] % 10) * pi) % 10;
            pi = (pi * pu) % 10;
        }
        System.out.println((sum1 * sum2) % 10);
        pi = 1;
        for (int i = 0; i < mag3.length; i++) {
            int tt = mag3[i] % 10;
            sum3 = (sum3 + tt * pi) % 10;
            pi = (pi * pu) % 10;
        }
        System.out.println(sum3);
    }

    public BigInt multiply(BigInt val) {
        int n = makePowerOfTwo(Math.max(mag.length, val.mag.length)) * 2;
        int signResult = signum * val.signum;
        int[] A = padWithZeros(mag, n);              //copies mag into A padded w/0's

        int[] B = padWithZeros(val.mag, n);          // copies val.mag into B padded w/0's

        int[] root = rootsOfUnity(n);               // creates all n roots of unity

        int[] C = new int[n];                       // result array for A*B

        int[] AF = new int[n];                       // result array for FFT of A

        int[] BF = new int[n];                      // result array for FFT of B

        FFT(A, root, n, 0, AF);
        FFT(B, root, n, 0, BF);
        for (int i = 0; i < n; i++) {
            AF[i] = (int) (((long) AF[i] * (long) BF[i]) % P); //Component multiply

        }
        reverseRoots(root);                             // Reverse roots to create inverse roots

        inverseFFT(AF, root, n, 0, C);                      // leaves inverse FFT result in C

        propagateCarries(C);                            // Convert C to right no. bits per entry

        return new BigInt(signResult, C);
    }

    protected static int makePowerOfTwo(int length) {
        int i;
        for (i = 1; i < length; i *= 2);
        return i;
    }

    protected static int[] padWithZeros(int[] mag, int n) {
        int[] tmp = new int[n];
        for (int i = 0; i < mag.length; i++) {
            tmp[i] = mag[i];
        }
        for (int i = mag.length; i < n; i++) {
            tmp[i] = 0;
        }
        return tmp;
    }
    //FFT算法的递归实现

    public static void FFT(int[] A, int[] root, int n, int base, int[] Y) {
        int prod;
        if (n == 1) {
            Y[base] = A[base];
            return;
        }
        inverseShuffle(A, n, base);   //inverse shuffle to separete evens and odds
        FFT(A, root, n / 2, base, Y);     //results in Y[base] to Y[base+n/2-1]
        FFT(A, root, n / 2, base + n / 2, Y); //results in Y[base+n/2] to Y[base+n-1]
        int j = A.length / n;
        for (int i = 0; i < n / 2; i++) {
            prod = (int) (((long) root[i * j] * Y[base + n / 2 + i]) % P);
            Y[base + n / 2 + i] = (int) (((long) Y[base + i] + P - prod) % P);
            Y[base + i] = (int) (((long) Y[base + i] + prod) % P);
        }
    }

    public static void inverseFFT(int[] A, int[] root, int n, int base, int[] Y) {
        int inverseN = modInverse(n);   //n^(-1)

        FFT(A, root, n, base, Y);
        for (int i = 0; i < n; i++) {
            Y[i] = (int) (((long) Y[i] * inverseN) % P);
        }
    }
    //递归FFT的支持方法

    protected static int modInverse(int n) { //assume n is power of two

        int result = 1;
        for (long twoPower = 1; twoPower < n; twoPower *= 2) { //n = 2^t

            result = (int) (((long) result * TWOINV) % P);
        }
        return result;
    }
    /*
     * 逆混洗:对每个子数组A,把A 中的偶数下标元素移到A的低半部分,把A中的奇数下标元素移到A的高半部分
     * */
    protected static void inverseShuffle(int[] A, int n, int base) {
        int shift;
        int[] sp = new int[n];
        for (int i = 0; i < n / 2; i++) {//Unshullfe A into the scratch space

            shift = base + 2 * i;
            sp[i] = A[shift];       //an even index

            sp[i + n / 2] = A[shift + 1]; //an odd index

        }
        for (int i = 0; i < n; i++) {
            A[base + i] = sp[i];//copy back to A

        }
    }

    protected static int[] rootsOfUnity(int n) {   //assumes n is power of 2

        int nthroot = OMEGA;
        for (int t = MAXN; t > n; t /= 2) //Find prim. nth root of unity       
        {
            nthroot = (int) (((long) nthroot * nthroot) % P);
        }
        int[] roots = new int[n];
        int r = 1;                                  //r will run through all nth roots of unity

        for (int i = 0; i < n; i++) {
            roots[i] = r;
            r = (int) (((long) r * nthroot) % P);
        }
        return roots;
    }

    protected static void reverseRoots(int[] root) {
        int temp;
        for (int i = 1; i < (root.length + 1) / 2; i++) {
            temp = root[i];
            root[i] = root[root.length - i];
            root[root.length - i] = temp;
        }
    }

    protected static void propagateCarries(int[] A) {
        int carry;
        carry = 0;
        for (int i = 0; i < A.length; i++) {
            System.out.print(A[i] + " ");
        }
        System.out.println();
        for (int i = 0; i < A.length; i++) {
            A[i] = A[i] + carry;
            //System.out.println(A[i]);
            carry = A[i] >>> ENTRYSIZE;             //逻辑右移

            A[i] = A[i] - (carry << ENTRYSIZE);     //算术左移

        }
    //System.out.println("carry "+carry);
    }
}

/*
 * FFT(a, w)
 * 输入 (input): 系数向量 [a0,a1,...,a(n-1)], n 次本原单位根 w (其各次幂为 w^0,w^1,...,w^(n-1)); n 为 2 的幂
 * 输出 (output): 向量 Y
 * if (n == 1) return Y = A;  // 大写表示向量 (upper case denotes a vector)
 * x = w^0 = 1
 * aeven = [a0,a2,...,a(n-2)]
 * aodd  = [a1,a3,...,a(n-1)]
 * {递归调用: 由归约性质可知, w^2 为 n/2 次本原单位根}
 * yeven = FFT(aeven, w^2)
 * yodd  = FFT(aodd, w^2)
 * {组合步骤, 利用 x = w^i}
 * for i = 0 to n/2-1 do
 *     y(i)     = yeven(i) + x*yodd(i)
 *     y(i+n/2) = yeven(i) - x*yodd(i)   (反射性质 / reflection property)
 *     x = x*w
 * return Y;
 */

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值