一、实验目的
根据文档,在本次实验中需要分别实现Strassen’s算法和普通的矩阵相乘算法,分析两种算法的理论时间复杂度,通过设计实验评估与算法实现相关的时间开销,并确定其结果与理论复杂度计算之间是否具有一致性。
二、实验步骤
实验硬件:Apple MacBook Pro
CPU:M1 Pro
系统:MacOS Ventura 13.5.1
软件:IntelliJ IDEA 2023.1.3 (Ultimate Edition)
Java运行时版本: 17.0.7+10-b829.16 aarch64
VM: OpenJDK 64-Bit Server VM by JetBrains s.r.o.
GC: G1 Young Generation, G1 Old Generation
Java虚拟机内存: 4096MB
核心数: 8核
- 生成测试数据
static int n = 2048;
public static void createTestFile(String filename){
try {
PrintWriter writer = new PrintWriter(new FileWriter(filename+".txt"));
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
// 生成矩阵元素并写入文件
Random random = new Random();
int randomNumber = random.nextInt(1000);
writer.print(randomNumber);
if (j != n - 1) {
writer.print(" "); // 使用空格分隔每个元素
}
}
writer.println(); // 换行
}
writer.close();
System.out.println("矩阵已成功写入到" + filename+".txt" +"文件中。");
} catch (IOException e) {
System.out.println("写入文件时发生错误。错误信息:" + e.getMessage());
}
}
public static int[][] readTestFile(int n,String filename) {
int[][] matrix = new int[n][n]; // 创建二维数组用于存储矩阵
try {
BufferedReader reader = new BufferedReader(new FileReader(filename+".txt"));
String line;
int row = 0;
while ((line = reader.readLine()) != null && row < n) {
String[] elements = line.split(" ");
for (int col = 0; col < n && col < elements.length; col++) {
matrix[row][col] = Integer.parseInt(elements[col]);
}
row++;
}
reader.close();
} catch (IOException e) {
System.out.println("读取文件时发生错误。错误信息:" + e.getMessage());
}
return matrix;
}
- 编写普通矩阵相乘算法
public static int[][] ordinary(int[][] matrix1, int[][] matrix2) {
int n = matrix1.length;
int[][] ans = new int[n][n];
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
for (int k = 0; k < n; k++) {
ans[i][j] += matrix1[i][k] * matrix2[k][j];
}
}
}
return ans;
}
- 编写Strassen’s算法
public static int[][] strassen(int[][] A, int[][] B) {
int n = A.length;
if (n <= 1) {
return ordinary(A, B);
}
// 计算新矩阵的大小并将矩阵分割成四个子矩阵
int newSize = n / 2;
int[][] A11 = new int[newSize][newSize];
int[][] A12 = new int[newSize][newSize];
int[][] A21 = new int[newSize][newSize];
int[][] A22 = new int[newSize][newSize];
int[][] B11 = new int[newSize][newSize];
int[][] B12 = new int[newSize][newSize];
int[][] B21 = new int[newSize][newSize];
int[][] B22 = new int[newSize][newSize];
splitMatrix(A, A11, A12, A21, A22);
splitMatrix(B, B11, B12, B21, B22);
// 计算Strassen算法的七个中间矩阵
int[][] M1 = strassen(addMatrix(A11, A22), addMatrix(B11, B22));
int[][] M2 = strassen(addMatrix(A21, A22), B11);
int[][] M3 = strassen(A11, subtractMatrix(B12, B22));
int[][] M4 = strassen(A22, subtractMatrix(B21, B11));
int[][] M5 = strassen(addMatrix(A11, A12), B22);
int[][] M6 = strassen(subtractMatrix(A21, A11), addMatrix(B11, B12));
int[][] M7 = strassen(subtractMatrix(A12, A22), addMatrix(B21, B22));
// 计算Strassen算法的四个结果矩阵
int[][] C11 = subtractMatrix(addMatrix(addMatrix(M1, M4), M7), M5);
int[][] C12 = addMatrix(M3, M5);
int[][] C21 = addMatrix(M2, M4);
int[][] C22 = subtractMatrix(addMatrix(addMatrix(M1, M3), M6), M2);
// 合并结果矩阵
int[][] C = new int[n][n];
combineMatrix(C, C11, C12, C21, C22);
return C;
}
public static void splitMatrix(int[][] A, int[][] A11, int[][] A12, int[][] A21, int[][] A22) {
int newSize = A.length / 2;
for (int i = 0; i < newSize; i++) {
System.arraycopy(A[i], 0, A11[i], 0, newSize);
System.arraycopy(A[i], newSize, A12[i], 0, newSize);
System.arraycopy(A[i + newSize], 0, A21[i], 0, newSize);
System.arraycopy(A[i + newSize], newSize, A22[i], 0, newSize);
}
}
public static void combineMatrix(int[][] C, int[][] C11, int[][] C12, int[][] C21, int[][] C22) {
int newSize = C.length / 2;
for (int i = 0; i < newSize; i++) {
System.arraycopy(C11[i], 0, C[i], 0, newSize);
System.arraycopy(C12[i], 0, C[i], newSize, newSize);
System.arraycopy(C21[i], 0, C[i + newSize], 0, newSize);
System.arraycopy(C22[i], 0, C[i + newSize], newSize, newSize);
}
}
public static int[][] addMatrix(int[][] A, int[][] B) {
int n = A.length;
int[][] C = new int[n][n];
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
C[i][j] = A[i][j] + B[i][j];
}
}
return C;
}
public static int[][] subtractMatrix(int[][] A, int[][] B) {
int n = A.length;
int[][] C = new int[n][n];
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
C[i][j] = A[i][j] - B[i][j];
}
}
return C;
}
- 编写优化的Strassen’s算法
public static int[][] optimizedStrassen(int[][] A, int[][] B,int k) {
int n = A.length;
if (n <= k) {
return ordinary(A, B);
}
// 计算新矩阵的大小并将矩阵分割成四个子矩阵
int newSize = n / 2;
int[][] A11 = new int[newSize][newSize];
int[][] A12 = new int[newSize][newSize];
int[][] A21 = new int[newSize][newSize];
int[][] A22 = new int[newSize][newSize];
int[][] B11 = new int[newSize][newSize];
int[][] B12 = new int[newSize][newSize];
int[][] B21 = new int[newSize][newSize];
int[][] B22 = new int[newSize][newSize];
splitMatrix(A, A11, A12, A21, A22);
splitMatrix(B, B11, B12, B21, B22);
// 计算Strassen算法的七个中间矩阵
int[][] M1 = optimizedStrassen(addMatrix(A11, A22), addMatrix(B11, B22),k);
int[][] M2 = optimizedStrassen(addMatrix(A21, A22), B11,k);
int[][] M3 = optimizedStrassen(A11, subtractMatrix(B12, B22),k);
int[][] M4 = optimizedStrassen(A22, subtractMatrix(B21, B11),k);
int[][] M5 = optimizedStrassen(addMatrix(A11, A12), B22,k);
int[][] M6 = optimizedStrassen(subtractMatrix(A21, A11), addMatrix(B11, B12),k);
int[][] M7 = optimizedStrassen(subtractMatrix(A12, A22), addMatrix(B21, B22),k);
// 计算Strassen算法的四个结果矩阵
int[][] C11 = subtractMatrix(addMatrix(addMatrix(M1, M4), M7), M5);
int[][] C12 = addMatrix(M3, M5);
int[][] C21 = addMatrix(M2, M4);
int[][] C22 = subtractMatrix(addMatrix(addMatrix(M1, M3), M6), M2);
// 合并结果矩阵
int[][] C = new int[n][n];
combineMatrix(C, C11, C12, C21, C22);
return C;
}
三、数据收集
编写测试代码,获得实验结果
public static void main(String[] args) {
// createTestFile("matrix1");
// createTestFile("matrix2");
try{
PrintWriter writer = new PrintWriter(new FileWriter("result.txt"));
for (int n = 16; n <= 1024; n *= 2){
int[][] matrix1 = readTestFile(n,"matrix1");
int[][] matrix2 = readTestFile(n,"matrix2");
long runtime = 0;
for(int i=0;i<5;i++){ // 进行五次排序并取平均运行时间
long startTime = System.currentTimeMillis();
strassen(matrix1,matrix2);
long endTime = System.currentTimeMillis();
runtime += endTime-startTime;
}
System.out.println("Strassen Algorithm: n=" + n + " runtime:" + runtime/5 + "ms");
writer.println("Strassen Algorithm: n=" + n + " runtime:" + runtime/5 + "ms"); // 将结果写入文件
runtime = 0;
for(int i=0;i<5;i++){ // 进行五次排序并取平均运行时间
long startTime = System.currentTimeMillis();
ordinary(matrix1,matrix2);
long endTime = System.currentTimeMillis();
runtime += endTime-startTime;
}
System.out.println("Ordinary Algorithm: n=" + n + " runtime:" + runtime/5 + "ms");
writer.println("Ordinary Algorithm: n=" + n + " runtime:" + runtime/5 + "ms"); // 将结果写入文件
runtime = 0;
for(int k=4;k<=256;k*=2){
for(int i=0;i<5;i++){ // 进行五次排序并取平均运行时间
long startTime = System.currentTimeMillis();
optimizedStrassen(matrix1,matrix2,k);
long endTime = System.currentTimeMillis();
runtime += endTime-startTime;
}
System.out.println("Optimized Strassen Algorithm: n=" + n +" k=" + k +" runtime:" + runtime/5 + "ms");
writer.println("Optimized Strassen Algorithm: n=" + n + " k=" + k+" runtime:" + runtime/5 + "ms"); // 将结果写入文件
}
}
writer.close(); // 关闭文件写入流
}catch (IOException e) { // 异常处理
System.out.println("Failed to write results to file. Error message: " + e.getMessage());
}
System.out.println("TEST FINISH"); // 输出测试结束信息
}
结果导出到result.txt中,将其转换到result.xlsx中,方便后续分析数据。
首先,在数据规模的选取上,若方阵的大小不为2的n次方,在进行Strassen’s算法需要把方阵补全为2的n次方大小再来进行计算。由于本实验只是探究Strassen’s算法和普通的矩阵相乘算法运行速度之间的关系,不涉及更具体的应用场景,故选取的方阵的大小均为2的n次方,以便进行实验。
经过预测试发现,当数据规模小于n=16数量级时,运行速度太快,耗时均为0ms,无法得到运行速度变化,所以我选取的数据规模n=16开始。
当数据规模达到n=2048数量级规模时,运行速度太慢,所以我选取的数据规模最大到n=1024。
各个规模的计算均进行5次取平均值,减小由于某次排序可能存在的异常情况或错误对结果的影响,提高数据的稳定性和可靠性。
另外,在对Strassen’s算法进行测试时发现,单纯的Strassen’s算法运行效率十分慢。分析原因是递归的堆栈调用和矩阵复制的值传递耗费了太多时间,导致Strassen’s算法比普通的矩阵相乘算法运行速度更慢,算法优化起到了反效果。
因此,我决定对Strassen’s算法进行优化,优化思路同lab1对归并算法的优化,在矩阵规模小于等于k时不再使用Strassen’s算法改为使用普通的矩阵相乘算法。在这个思路下设计实验,并对不同的k值进行测试,观察其能否比普通矩阵相乘算法有更好的表现。
四、结果与讨论
-
理论分析
对于普通的矩阵相乘算法,三重循环每层循环的次数为n,所以时间复杂度显然为
O ( n 3 ) O(n^3) O(n3)
Strassen’s算法的本质上是把一个矩阵划分成四个小的矩阵,每个矩阵的规模变成原来的1/2,通过一系列矩阵构造并进行加减运算得到乘积,在这个过程中减少了一次矩阵的乘法运算,以此来减少计算使用的时间。因此,Strassen’s算法的时间复杂度可以表示为
T ( n ) = 7 T ( n 2 ) + O ( n 2 ) T(n) = 7T\left(\frac{n}{2}\right) + O(n^2) T(n)=7T(2n)+O(n2)
其中T(n)表示计算两个 n×n 矩阵相乘所需的时间,T(n/2)为每个子问题需要的时间,通过Master Method求解递归关系,可以得到Strassen算法的总时间复杂度为
O ( n log 2 7 ) O(n^{\log_2 7}) O(nlog27)
大约等于$ O(n^{2.81})$。这意味着Strassen算法比传统的矩阵乘法算法更高效。然而,在实际应用中,因为有递归的堆栈调用和矩阵复制的值传递,Strassen算法的常数因子较大,因此只有在特殊优化下才能体现出其优势。对于小规模的矩阵乘法,传统的算法可能更快。 -
实验结果
实验结果如下(单位ms):数据规模 (2^n) 普通算法 k=4 k=8 k=16 k=32 k=64 k=128 k=256 Strassen算法 4 0 0 0 0 0 0 0 0 3 5 0 1 0 0 0 0 0 0 14 6 0 4 0 0 0 0 0 0 77 7 1 23 6 2 2 1 1 1 525 8 12 155 41 17 12 10 11 12 3510 9 101 1016 289 125 84 74 81 87 24482 10 2527 7097 2006 832 558 497 551 639 175735
使用该数据画图(其中未优化的Strassen’s算法运行时间太长,未画在图中)
如图,以普通矩阵相乘算法为基准,可以看出在k=4时,由于递归的堆栈调用层数太深及矩阵复制的值传递过多,优化的Strassen’s算法运行速度并不如普通矩阵相乘算法。
但随着k值变大,在k大于8时,优化的Strassen’s算法运行速度均比普通矩阵相乘算法要快,这说明Strassen’s算法减少矩阵乘法次数以减小时间复杂度的方法起了作用,其中在k=64时算法运行速度最快。即使在k=256时(即Strassen’s算法只进行了几次矩阵乘法优化),优化的Strassen’s算法运行速度仍然比普通矩阵相乘算法快。
五、结论
本次实验主要对比了普通的矩阵相乘算法和Strassen’s算法在不同规模下的运行速度,并通过优化的Strassen’s算法来提高效率。通过实验结果分析。
首先,根据理论分析,普通矩阵相乘算法的时间复杂度为O(n3),而Strassen’s算法的时间复杂度约为O(n2.81),即Strassen’s算法理论上更高效。但在实际应用中,由于递归调用和值传递开销较大,Strassen’s算法的常数因子较大,只有在特殊优化下才能体现出其优势。
实验结果显示,k值大于8时,优化的Strassen’s算法开始展现出优势,运行速度明显快于普通矩阵相乘算法。特别是在k=64时,优化的Strassen’s算法达到最佳性能。
这一结果说明,通过减少矩阵乘法次数,Strassen’s算法成功地降低了时间复杂度,尽管Strassen’s算法需要进行递归调用和值传递的开销,但通过合适的优化(如设置合适的k值),仍然可以使计算效率超过普通矩阵相乘算法。
综上所述,本次实验通过实际测试验证了Strassen’s算法在适当的优化下相比普通矩阵相乘算法更高效的结论。同时,也展示了优化策略中k值的重要性,选择合适的k值可以提高优化的Strassen’s算法的运行效率。这些研究结果对于理解和应用矩阵乘法算法具有一定的指导意义,为实际问题中的矩阵运算提供了重要的参考依据。