Java实现Strassen

Strassen算法也是一种基于分治思想的算法。首先我们用普通的分治方法来实现矩阵的乘法。

这里我是用行下标和列下标来拆分矩阵的,并没有去复制矩阵中的元素,和算法导论中给出的思路是一样的。

***注意:**以下的程序只适用NN的矩阵。

首先,我定义一个关于矩阵的对象,Java代码如下:

package hanxl.insist.beans;

/**
 * 矩阵类
 * 之所以设置起始和结束下标是因为会去用下标partition矩阵
 */
public class Matrix {
	/**
	 * 矩阵的行起始下标
	 * 用于partition矩阵
	 */
	private int rowStartIndex;

	/**
	 * 矩阵的行结束下标
	 * 用于partition矩阵
	 */
	private int rowEndIndex;

	/**
	 * 矩阵的列起始下标
	 * 用于partition矩阵
	 */
	private int columnStartIndex;

	/**
	 * 矩阵的列结束下标
	 * 用于partition矩阵
	 */
	private int columnEndIndex;

	/**
	 * 矩阵中的数元素
	 */
	private int[][] elements;
	
	/**
	 * 根据给定的二维数组构建一个矩阵
	 * @param elements
	 */
	public Matrix(int[][] elements) {
		this(0, elements.length - 1, 0, elements[0].length - 1, elements);
	}

	/**
	 * 根据给定的4个下标来拆分给定的数组元素,并构建矩阵,矩阵中的元素与数组中的元素不一定一致
	 * @param rowStartIndex
	 * @param rowEndIndex
	 * @param columnStartIndex
	 * @param columnEndIndex
	 * @param elements
	 */
	public Matrix(int rowStartIndex, int rowEndIndex, int columnStartIndex, int columnEndIndex, int[][] elements) {
		this.rowStartIndex = rowStartIndex;
		this.rowEndIndex = rowEndIndex;
		this.columnStartIndex = columnStartIndex;
		this.columnEndIndex = columnEndIndex;
		this.elements = elements;
	}

	/**
	 * 构造一个指定行和列的空矩阵
	 * @param row
	 * @param column
	 */
	public Matrix(int row, int column) {
		this(new int[row][column]);
	}
	
	
	public static Matrix add(Matrix a, Matrix b) {
		Matrix matrix = new Matrix(a.getRows(),a.getColumns());
		int[][] resultElements = matrix.getElements();
		int[][] aelements = a.getElements();
		int[][] belements = b.getElements();
		
		for (int i = 0; i < belements.length; i++)
			for (int j = 0; j < belements.length; j++)
				resultElements[i][j] = aelements[i][j] + belements[i][j];
		
		return matrix;
	}
	
	/**
	 * 根据当前矩阵的4个下标打印矩阵
	 */
	public void printMatrix() {
		int[][] e = this.getElements();
		for (int i = this.getRowStartIndex(); i <= this.getRowEndIndex(); i++) {
			for (int j = this.getColumnStartIndex(); j <= this.getColumnEndIndex(); j++) {
				System.out.print(e[i][j] + "  ");
			}
			System.out.println();
		}
	}

	/**
	 * 获取矩阵的行
	 */
	public int getRows() {
		return rowEndIndex - rowStartIndex + 1;
	}
	
	/**
	 * 获取矩阵的列
	 */
	public int getColumns() {
		return columnEndIndex - columnStartIndex + 1;
	}
	
	public int getRowStartIndex() {
		return rowStartIndex;
	}

	public void setRowStartIndex(int rowStartIndex) {
		this.rowStartIndex = rowStartIndex;
	}

	public int getRowEndIndex() {
		return rowEndIndex;
	}

	public void setRowEndIndex(int rowEndIndex) {
		this.rowEndIndex = rowEndIndex;
	}

	public int getColumnStartIndex() {
		return columnStartIndex;
	}

	public void setColumnStartIndex(int columnStartIndex) {
		this.columnStartIndex = columnStartIndex;
	}

	public int getColumnEndIndex() {
		return columnEndIndex;
	}

	public void setColumnEndIndex(int columnEndIndex) {
		this.columnEndIndex = columnEndIndex;
	}

	public int[][] getElements() {
		return elements;
	}

	public void setElements(int[][] elements) {
		this.elements = elements;
	}
}

关于分治思想的代码如下:

package hanxl.insist.fourchapter;

import hanxl.insist.beans.Matrix;

public class MatrixMultiply {
	
	public static void main(String[] args) {
		int[][] aelements = {{1, 3,4,5}, {7,2,6, 5},{2,4,2, 4}, {6,4,8,2}};
		Matrix a = new Matrix(aelements); 
		a.printMatrix();
		System.out.println("-------------------");
		
		int[][] belements = {{6,4,5 ,8}, {4,2,3 ,2}, {1,1,1,3}, {6,9,2,7}};
		Matrix b = new Matrix(belements);
		b.printMatrix();
		System.out.println("-------------------");
		
		Matrix r = recursiveMultiply(a, b); 
		r.printMatrix(); 
	}
	
	public static Matrix recursiveMultiply(Matrix a, Matrix b) {
		Matrix c = new Matrix(a.getRows(), a.getColumns());  //根据矩阵a的起始下标......等创建的,并不是用下标约束,而是一个真实的矩阵
		
		if ( c.getRows() == 1 )
			c.getElements()[0][0] = a.getElements()[a.getRowStartIndex()][a.getColumnStartIndex()] * b.getElements()[b.getRowStartIndex()][b.getColumnStartIndex()];  // base case
		else {
			Matrix[] amatrixs = partition(a);
			Matrix a11 = amatrixs[0];
			Matrix a12 = amatrixs[1];
			Matrix a21 = amatrixs[2];
			Matrix a22 = amatrixs[3];
			
			Matrix[] bmatrixs = partition(b);
			Matrix b11 = bmatrixs[0];
			Matrix b12 = bmatrixs[1];
			Matrix b21 = bmatrixs[2];
			Matrix b22 = bmatrixs[3];
			
			Matrix[] cmatrixs = partition(c);  //这些小矩阵的elements对象是与c一样的,只不过用下标将其限制住了,它们和c相比也不是同一个对象
			Matrix c11 = cmatrixs[0];
			Matrix c12 = cmatrixs[1];
			Matrix c21 = cmatrixs[2];
			Matrix c22 = cmatrixs[3];
			
			c11 = Matrix.add(recursiveMultiply(a11, b11),recursiveMultiply(a12, b21));
			c12 = Matrix.add(recursiveMultiply(a11, b12),recursiveMultiply(a12, b22));
			c21 = Matrix.add(recursiveMultiply(a21, b11),recursiveMultiply(a22, b21));
			c22 = Matrix.add(recursiveMultiply(a21, b12),recursiveMultiply(a22, b22));
			
			c = merge(c11, c12, c21, c22);
		}
		
		return c;
	}

	/**
	 * 把4个小矩阵合并成一个大矩阵
	 * @param c11
	 * @param c12
	 * @param c21
	 * @param c22
	 * @return
	 */
	public static Matrix merge(Matrix c11, Matrix c12, Matrix c21, Matrix c22) {
		Matrix matrix = new Matrix(c11.getRows() * 2, c11.getColumns() * 2);
		int[][] elements = matrix.getElements();
		int length = c11.getElements().length;
		
		for (int i = 0; i < length; i++) {
			for (int j = 0; j < length; j++) {
				elements[i][j] = c11.getElements()[i][j];
				elements[i][j + length] = c12.getElements()[i][j];
				elements[i + length][j] = c21.getElements()[i][j];
				elements[i + length][j + length] = c22.getElements()[i][j];
			}
		}
		
		return matrix;
	}

	/**
	 * 把一个大矩阵切分成四个小矩阵封装到数组之中
	 * @param matrix
	 * @return
	 */
	public static Matrix[] partition( Matrix matrix ) {
		Matrix[] matrixs = new Matrix[4];
		
		int rowStart = matrix.getRowStartIndex();
		int rowEnd = matrix.getRowEndIndex();
		int rowMid = ( rowStart + rowEnd ) / 2;
		
		int[][] elements = matrix.getElements();
		
		int columnStart = matrix.getColumnStartIndex();
		int columnEnd = matrix.getColumnEndIndex();
		int columnMid = ( columnStart + columnEnd ) / 2;
		
		matrixs[0] = new Matrix(rowStart, rowMid, columnStart, columnMid, elements);
		matrixs[1] = new Matrix(rowStart, rowMid, columnMid + 1, columnEnd, elements);
		matrixs[2] = new Matrix(rowMid + 1, rowEnd, columnStart, columnMid, elements);
		matrixs[3] = new Matrix(rowMid + 1, rowEnd, columnMid + 1, columnEnd, elements);
		
		return matrixs;
	}
}

Strassen算法Java代码如下:

package hanxl.insist.fourchapter;

import hanxl.insist.beans.Matrix;

public class Strassen {

	public static void main(String[] args) {
		int[][] aelements = { { 1, 3, 4, 5 }, { 7, 2, 6, 5 }, { 2, 4, 2, 4 }, { 6, 4, 8, 2 } }; //
		Matrix a = new Matrix(aelements); //

		int[][] belements = { { 6, 4, 5, 8 }, { 4, 2, 3, 2 }, { 1, 1, 1, 3 }, { 6, 9, 2, 7 } }; //
		Matrix b = new Matrix(belements); //

		Matrix r = recursiveMultiply(a, b); //
		System.out.println("----这是结果----");
		r.printMatrix(); //
	}

	public static Matrix recursiveMultiply(Matrix a, Matrix b) {
		Matrix c = new Matrix(a.getRows(), a.getColumns());

		if (c.getRows() == 1)
			c.getElements()[0][0] = a.getElements()[a.getRowStartIndex()][a.getColumnStartIndex()]
					* b.getElements()[b.getRowStartIndex()][b.getColumnStartIndex()]; // base
																						// case
		else {
			Matrix[] amatrixs = MatrixMultiply.partition(a);
			Matrix a11 = amatrixs[0];
			Matrix a12 = amatrixs[1];
			Matrix a21 = amatrixs[2];
			Matrix a22 = amatrixs[3];

			Matrix[] bmatrixs = MatrixMultiply.partition(b);
			Matrix b11 = bmatrixs[0];
			Matrix b12 = bmatrixs[1];
			Matrix b21 = bmatrixs[2];
			Matrix b22 = bmatrixs[3];

			Matrix s1 = calculate(b12, b22, "-"); // s1为堂堂正正的一个矩阵,并没有用下标限制
			Matrix s2 = calculate(a11, a12, "+");
			Matrix s3 = calculate(a21, a22, "+");
			Matrix s4 = calculate(b21, b11, "-");
			Matrix s5 = calculate(a11, a22, "+");
			Matrix s6 = calculate(b11, b22, "+");
			Matrix s7 = calculate(a12, a22, "-");
			Matrix s8 = calculate(b21, b22, "+");
			Matrix s9 = calculate(a11, a21, "-");
			Matrix s10 = calculate(b11, b12, "+");

			Matrix p1 = recursiveMultiply(a11, s1);
			Matrix p2 = recursiveMultiply(s2, b22);
			Matrix p3 = recursiveMultiply(s3, b11);
			Matrix p4 = recursiveMultiply(a22, s4);
			Matrix p5 = recursiveMultiply(s5, s6);
			Matrix p6 = recursiveMultiply(s7, s8);
			Matrix p7 = recursiveMultiply(s9, s10);

			Matrix c11 = calculate(calculate(p5, p4, "+"), calculate(p6, p2, "-"), "+");
			Matrix c12 = calculate(p1, p2, "+");
			Matrix c21 = calculate(p3,p4, "+");
			Matrix c22 = calculate(calculate(p5, p1, "+"), calculate(p7, p3, "+"), "-");

			c = MatrixMultiply.merge(c11, c12, c21, c22);
		}
		return c;

	}

	private static Matrix calculate(Matrix b12, Matrix b22, String operator) {
		Matrix matrix = new Matrix(b12.getRows(), b12.getColumns());
		int[][] resultElements = matrix.getElements();
		int rp = 0;
		int cp = 0;

		int[][] aelements = b12.getElements();

		int[][] belements = b22.getElements();
		int brp = b22.getRowStartIndex();
		int bcp = b22.getColumnStartIndex();

		for (int i = b12.getRowStartIndex(); i <= b12.getRowEndIndex(); i++) {
			for (int j = b12.getColumnStartIndex(); j <= b12.getColumnEndIndex(); j++) {
				if ("-".equals(operator))
					resultElements[rp][cp] = aelements[i][j] - belements[brp][bcp];
				else
					resultElements[rp][cp] = aelements[i][j] + belements[brp][bcp];
				bcp++;
				cp++;
			}
			cp = 0;
			bcp = b22.getColumnStartIndex();
			brp++;
			rp++;
		}

		return matrix;
	}
}

这个算法中的calculate方法之所以不用Matrix对象中的add方法,是因为这个方法的参数是被用下标限制的矩阵,而那个add方法并没有做出任何限制,就是和二维数组是一样的。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值