package com.work.home_2;
import java.util.ArrayList;
import java.util.List;
/**
* @author jun
* 说明:strassen算法,矩阵为偶数阶方阵,先将偶数阶方阵拆分成多个2阶方阵,
* 2阶方阵相乘时用strassen(2阶)方法,分块矩阵用传统方法
*/
public class SplitMatrixTwo {
/**
* 阶数是偶数时的strassen算法:矩阵A为n阶方阵,矩阵B为n阶方阵
*
* @param n:阶数
* @param alist:矩阵A拆分成多个2阶方阵后存放在集合alist中,按行存放
* @param blist:矩阵B拆分成多个2阶方阵后存放在集合alist中,按行存放
* @return 返回矩阵 A*B的结果
*/
public static int[][] strassenEven(int n, int[][] a, int[][] b) {
// 创建集合alist,用于接收拆分成2阶方阵的小矩阵
List<int[][]> alist = new ArrayList<>();
alist = SplitMatrixTwo.splitMatrix(n, a);
// 创建集合alist,用于接收拆分成2阶方阵的小矩阵
List<int[][]> blist = new ArrayList<>();
blist = SplitMatrixTwo.splitMatrix(n, b);
// 创建集合clist用以接收小矩阵相乘后的值
List<int[][]> clist = new ArrayList<>();
// 创建数组arr为最终矩阵
int[][] arr = new int[n][n];
// System.out.println("集合alist的长度为:"+alist.size());
// 遍历集合alist和集合blist中的元素,因为两个集合中的元素都是2阶矩阵,因此调用strassen(2阶)算法,把乘积放在集合clist中
for (int k = 0; k < alist.size(); k += (n / 2)) {
for (int g = 0; g < (n / 2); g++) {
int[][] c = new int[2][2];
for (int i = k, j = g; i < ((n / 2) + k); i++, j += (n / 2)) {
c = matrixAddition(2, c, strassen(2, alist.get(i), blist.get(j)));
}
clist.add(c);
}
}
// 把集合clist中的元素放到矩阵arr中
for (int x = 0, g = 0; x < n; x += 2, g += (n / 2)) {
for (int y = 0, k = g; y < n; y += 2, k++) {
for (int i = x; i < (2 + x); i++) {
for (int j = y; j < (2 + y); j++) {
arr[i][j] = clist.get(k)[i - x][j - y];
}
}
}
}
// 输出矩阵arr
System.out.println("strassen算法(阶数为偶数):俩矩阵相乘 A * B = ");
for (int i = 0; i < arr.length; i++) {
for (int j = 0; j < arr[i].length; j++) {
System.out.print(arr[i][j] + ",\t");
if (j == (n - 1))
System.out.println();
}
}
return arr;
}
/**
* @param n:矩阵阶数
* @param a:n阶方正
* @return 返回一个集合,集合中的元素皆为2阶方阵
*/
public static List<int[][]> splitMatrix(int n, int[][] a) {
List<int[][]> arrList = new ArrayList<>();
for (int x = 0; x < n; x += 2) {
for (int y = 0; y < n; y += 2) {
int[][] arr = new int[2][2];
for (int i = x; i < (2 + x); i++) {
for (int j = y; j < (2 + y); j++) {
arr[i - x][j - y] = a[i][j];
}
}
arrList.add(arr);
}
}
return arrList;
}
/**
* strassen 算法(2阶)
*/
public static int[][] strassen(int n, int[][] a, int[][] b) {
int[][] c = new int[n][n];
if (n != 2) {
System.out.println("非2阶,不可以调用");
return c;
} else {
// 运用strassen思想计算2阶矩阵相乘
int m1 = (a[0][0] + a[1][1]) * (b[0][0] + b[1][1]);
int m2 = (a[1][0] + a[1][1]) * b[0][0];
int m3 = a[0][0] * (b[0][1] - b[1][1]);
int m4 = a[1][1] * (b[1][0] - b[0][0]);
int m5 = (a[0][0] + a[0][1]) * b[1][1];
int m6 = (a[1][0] - a[0][0]) * (b[0][0] + b[0][1]);
int m7 = (a[0][1] - a[1][1]) * (b[1][0] + b[1][1]);
c[0][0] = m1 + m4 - m5 + m7;
c[0][1] = m3 + m5;
c[1][0] = m2 + m4;
c[1][1] = m1 + m3 - m2 + m6;
return c;
}
}
/**
* 两个矩阵相加
*/
public static int[][] matrixAddition(int n, int[][] a, int[][] b) {
// c 矩阵作为结果矩阵返回
int[][] c = new int[n][n];
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
c[i][j] = a[i][j] + b[i][j];
}
}
return c;
}
/**
* 传统矩阵相乘,亦可用来验证strassen算法
*/
public static int[][] traditionMu(int n, int[][] a, int[][] b) {
int[][] c = new int[n][n];
for (int i = 0; i < a.length; i++) {
for (int j = 0; j < b.length; j++) {
for (int k = 0; k < n; k++) {
c[i][j] += a[i][k] * b[k][j];
}
}
}
// 输出结果
System.out.println("传统算法:俩矩阵相乘 A * B = ");
for (int i = 0; i < c.length; i++) {
for (int j = 0; j < c[i].length; j++) {
System.out.print(c[i][j] + ",\t ");
if (j == (n - 1))
System.out.println();
}
}
return c;
}
}
package com.work.home_2;
import java.util.Scanner;
/**
* 测试函数
* @author jun
*/
public class TestEven {
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
System.out.println("请输入矩阵的阶数(n为偶数)n = ");
int n = sc.nextInt();
sc.close();
// 测试矩阵 A
int[][] a = new int[n][n];
// 给矩阵 A 赋值
System.out.println("矩阵 A 阶数为:" + n + ",值为:");
for (int i = 0; i < a.length; i++) {
for (int j = 0; j < a[i].length; j++) {
a[i][j] = (int) (Math.random() * 10);
System.out.print(a[i][j] + ",\t ");
if (j == (n - 1))
System.out.println();
}
}
// 给矩阵 B 赋值
int[][] b = new int[n][n];
// 给矩阵 B 赋值
System.out.println("矩阵 B 阶数为:" + n + ",值为:");
for (int i = 0; i < b.length; i++) {
for (int j = 0; j < b[i].length; j++) {
b[i][j] = (int) (Math.random() * 10);
System.out.print(b[i][j] + ",\t ");
if (j == (n - 1))
System.out.println();
}
}
// 传统算法测试
SplitMatrixTwo.traditionMu(n, a, b);
// 阶数为偶数的strassen算法测试
SplitMatrixTwo.strassenEven(n, a, b);
}
}
import java.util.ArrayList;
import java.util.List;
/**
* @author jun
* 说明:strassen算法,矩阵为偶数阶方阵,先将偶数阶方阵拆分成多个2阶方阵,
* 2阶方阵相乘时用strassen(2阶)方法,分块矩阵用传统方法
*/
public class SplitMatrixTwo {
/**
* 阶数是偶数时的strassen算法:矩阵A为n阶方阵,矩阵B为n阶方阵
*
* @param n:阶数
* @param alist:矩阵A拆分成多个2阶方阵后存放在集合alist中,按行存放
* @param blist:矩阵B拆分成多个2阶方阵后存放在集合alist中,按行存放
* @return 返回矩阵 A*B的结果
*/
public static int[][] strassenEven(int n, int[][] a, int[][] b) {
// 创建集合alist,用于接收拆分成2阶方阵的小矩阵
List<int[][]> alist = new ArrayList<>();
alist = SplitMatrixTwo.splitMatrix(n, a);
// 创建集合alist,用于接收拆分成2阶方阵的小矩阵
List<int[][]> blist = new ArrayList<>();
blist = SplitMatrixTwo.splitMatrix(n, b);
// 创建集合clist用以接收小矩阵相乘后的值
List<int[][]> clist = new ArrayList<>();
// 创建数组arr为最终矩阵
int[][] arr = new int[n][n];
// System.out.println("集合alist的长度为:"+alist.size());
// 遍历集合alist和集合blist中的元素,因为两个集合中的元素都是2阶矩阵,因此调用strassen(2阶)算法,把乘积放在集合clist中
for (int k = 0; k < alist.size(); k += (n / 2)) {
for (int g = 0; g < (n / 2); g++) {
int[][] c = new int[2][2];
for (int i = k, j = g; i < ((n / 2) + k); i++, j += (n / 2)) {
c = matrixAddition(2, c, strassen(2, alist.get(i), blist.get(j)));
}
clist.add(c);
}
}
// 把集合clist中的元素放到矩阵arr中
for (int x = 0, g = 0; x < n; x += 2, g += (n / 2)) {
for (int y = 0, k = g; y < n; y += 2, k++) {
for (int i = x; i < (2 + x); i++) {
for (int j = y; j < (2 + y); j++) {
arr[i][j] = clist.get(k)[i - x][j - y];
}
}
}
}
// 输出矩阵arr
System.out.println("strassen算法(阶数为偶数):俩矩阵相乘 A * B = ");
for (int i = 0; i < arr.length; i++) {
for (int j = 0; j < arr[i].length; j++) {
System.out.print(arr[i][j] + ",\t");
if (j == (n - 1))
System.out.println();
}
}
return arr;
}
/**
* @param n:矩阵阶数
* @param a:n阶方正
* @return 返回一个集合,集合中的元素皆为2阶方阵
*/
public static List<int[][]> splitMatrix(int n, int[][] a) {
List<int[][]> arrList = new ArrayList<>();
for (int x = 0; x < n; x += 2) {
for (int y = 0; y < n; y += 2) {
int[][] arr = new int[2][2];
for (int i = x; i < (2 + x); i++) {
for (int j = y; j < (2 + y); j++) {
arr[i - x][j - y] = a[i][j];
}
}
arrList.add(arr);
}
}
return arrList;
}
/**
* strassen 算法(2阶)
*/
public static int[][] strassen(int n, int[][] a, int[][] b) {
int[][] c = new int[n][n];
if (n != 2) {
System.out.println("非2阶,不可以调用");
return c;
} else {
// 运用strassen思想计算2阶矩阵相乘
int m1 = (a[0][0] + a[1][1]) * (b[0][0] + b[1][1]);
int m2 = (a[1][0] + a[1][1]) * b[0][0];
int m3 = a[0][0] * (b[0][1] - b[1][1]);
int m4 = a[1][1] * (b[1][0] - b[0][0]);
int m5 = (a[0][0] + a[0][1]) * b[1][1];
int m6 = (a[1][0] - a[0][0]) * (b[0][0] + b[0][1]);
int m7 = (a[0][1] - a[1][1]) * (b[1][0] + b[1][1]);
c[0][0] = m1 + m4 - m5 + m7;
c[0][1] = m3 + m5;
c[1][0] = m2 + m4;
c[1][1] = m1 + m3 - m2 + m6;
return c;
}
}
/**
* 两个矩阵相加
*/
public static int[][] matrixAddition(int n, int[][] a, int[][] b) {
// c 矩阵作为结果矩阵返回
int[][] c = new int[n][n];
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
c[i][j] = a[i][j] + b[i][j];
}
}
return c;
}
/**
* 传统矩阵相乘,亦可用来验证strassen算法
*/
public static int[][] traditionMu(int n, int[][] a, int[][] b) {
int[][] c = new int[n][n];
for (int i = 0; i < a.length; i++) {
for (int j = 0; j < b.length; j++) {
for (int k = 0; k < n; k++) {
c[i][j] += a[i][k] * b[k][j];
}
}
}
// 输出结果
System.out.println("传统算法:俩矩阵相乘 A * B = ");
for (int i = 0; i < c.length; i++) {
for (int j = 0; j < c[i].length; j++) {
System.out.print(c[i][j] + ",\t ");
if (j == (n - 1))
System.out.println();
}
}
return c;
}
}
package com.work.home_2;
import java.util.Scanner;
/**
* 测试函数
* @author jun
*/
public class TestEven {
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
System.out.println("请输入矩阵的阶数(n为偶数)n = ");
int n = sc.nextInt();
sc.close();
// 测试矩阵 A
int[][] a = new int[n][n];
// 给矩阵 A 赋值
System.out.println("矩阵 A 阶数为:" + n + ",值为:");
for (int i = 0; i < a.length; i++) {
for (int j = 0; j < a[i].length; j++) {
a[i][j] = (int) (Math.random() * 10);
System.out.print(a[i][j] + ",\t ");
if (j == (n - 1))
System.out.println();
}
}
// 给矩阵 B 赋值
int[][] b = new int[n][n];
// 给矩阵 B 赋值
System.out.println("矩阵 B 阶数为:" + n + ",值为:");
for (int i = 0; i < b.length; i++) {
for (int j = 0; j < b[i].length; j++) {
b[i][j] = (int) (Math.random() * 10);
System.out.print(b[i][j] + ",\t ");
if (j == (n - 1))
System.out.println();
}
}
// 传统算法测试
SplitMatrixTwo.traditionMu(n, a, b);
// 阶数为偶数的strassen算法测试
SplitMatrixTwo.strassenEven(n, a, b);
}
}