改进Strassen’s算法 真正发挥优化矩阵乘法的优势

一、实验目的

根据文档,在本次实验中需要分别实现Strassen’s算法和普通的矩阵相乘算法,分析两种算法的理论时间复杂度,通过设计实验评估与算法实现相关的时间开销,并确定其结果与理论复杂度计算之间是否具有一致性。

二、实验步骤

实验硬件:Apple MacBook Pro
CPU:M1 Pro
系统:MacOS Ventura 13.5.1
软件:IntelliJ IDEA 2023.1.3 (Ultimate Edition)
Java运行时版本: 17.0.7+10-b829.16 aarch64
VM: OpenJDK 64-Bit Server VM by JetBrains s.r.o.
GC: G1 Young Generation, G1 Old Generation
Java虚拟机内存: 4096MB
核心数: 8核

  1. 生成测试数据
 static int n = 2048;
    public static void createTestFile(String filename){
        try {
            PrintWriter writer = new PrintWriter(new FileWriter(filename+".txt"));

            for (int i = 0; i < n; i++) {
                for (int j = 0; j < n; j++) {
                    // 生成矩阵元素并写入文件
                    Random random = new Random();
                    int randomNumber = random.nextInt(1000);
                    writer.print(randomNumber);

                    if (j != n - 1) {
                        writer.print(" "); // 使用空格分隔每个元素
                    }
                }
                writer.println(); // 换行
            }

            writer.close();
            System.out.println("矩阵已成功写入到" + filename+".txt" +"文件中。");
        } catch (IOException e) {
            System.out.println("写入文件时发生错误。错误信息:" + e.getMessage());
        }
    }

    public static int[][] readTestFile(int n,String filename) {
        int[][] matrix = new int[n][n]; // 创建二维数组用于存储矩阵

        try {
            BufferedReader reader = new BufferedReader(new FileReader(filename+".txt"));

            String line;
            int row = 0;
            while ((line = reader.readLine()) != null && row < n) {
                String[] elements = line.split(" ");
                for (int col = 0; col < n && col < elements.length; col++) {
                    matrix[row][col] = Integer.parseInt(elements[col]);
                }
                row++;
            }
            reader.close();
        } catch (IOException e) {
            System.out.println("读取文件时发生错误。错误信息:" + e.getMessage());
        }
        return matrix;
    }
  1. 编写普通矩阵相乘算法
public static int[][] ordinary(int[][] matrix1, int[][] matrix2) {
        int n = matrix1.length;
        int[][] ans = new int[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                for (int k = 0; k < n; k++) {
                    ans[i][j] += matrix1[i][k] * matrix2[k][j];
                }
            }
        }
        return ans;
    }
  1. 编写Strassen’s算法
public static int[][] strassen(int[][] A, int[][] B) {
        int n = A.length;

        if (n <= 1) {
            return ordinary(A, B);
        }

        // 计算新矩阵的大小并将矩阵分割成四个子矩阵
        int newSize = n / 2;
        int[][] A11 = new int[newSize][newSize];
        int[][] A12 = new int[newSize][newSize];
        int[][] A21 = new int[newSize][newSize];
        int[][] A22 = new int[newSize][newSize];

        int[][] B11 = new int[newSize][newSize];
        int[][] B12 = new int[newSize][newSize];
        int[][] B21 = new int[newSize][newSize];
        int[][] B22 = new int[newSize][newSize];

        splitMatrix(A, A11, A12, A21, A22);
        splitMatrix(B, B11, B12, B21, B22);

        // 计算Strassen算法的七个中间矩阵
        int[][] M1 = strassen(addMatrix(A11, A22), addMatrix(B11, B22));
        int[][] M2 = strassen(addMatrix(A21, A22), B11);
        int[][] M3 = strassen(A11, subtractMatrix(B12, B22));
        int[][] M4 = strassen(A22, subtractMatrix(B21, B11));
        int[][] M5 = strassen(addMatrix(A11, A12), B22);
        int[][] M6 = strassen(subtractMatrix(A21, A11), addMatrix(B11, B12));
        int[][] M7 = strassen(subtractMatrix(A12, A22), addMatrix(B21, B22));

        // 计算Strassen算法的四个结果矩阵
        int[][] C11 = subtractMatrix(addMatrix(addMatrix(M1, M4), M7), M5);
        int[][] C12 = addMatrix(M3, M5);
        int[][] C21 = addMatrix(M2, M4);
        int[][] C22 = subtractMatrix(addMatrix(addMatrix(M1, M3), M6), M2);

        // 合并结果矩阵
        int[][] C = new int[n][n];
        combineMatrix(C, C11, C12, C21, C22);

        return C;
    }

public static void splitMatrix(int[][] A, int[][] A11, int[][] A12, int[][] A21, int[][] A22) {
        int newSize = A.length / 2;
        for (int i = 0; i < newSize; i++) {
            System.arraycopy(A[i], 0, A11[i], 0, newSize);
            System.arraycopy(A[i], newSize, A12[i], 0, newSize);
            System.arraycopy(A[i + newSize], 0, A21[i], 0, newSize);
            System.arraycopy(A[i + newSize], newSize, A22[i], 0, newSize);
        }
    }

public static void combineMatrix(int[][] C, int[][] C11, int[][] C12, int[][] C21, int[][] C22) {
        int newSize = C.length / 2;
        for (int i = 0; i < newSize; i++) {
            System.arraycopy(C11[i], 0, C[i], 0, newSize);
            System.arraycopy(C12[i], 0, C[i], newSize, newSize);
            System.arraycopy(C21[i], 0, C[i + newSize], 0, newSize);
            System.arraycopy(C22[i], 0, C[i + newSize], newSize, newSize);
        }
    }

public static int[][] addMatrix(int[][] A, int[][] B) {
        int n = A.length;
        int[][] C = new int[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                C[i][j] = A[i][j] + B[i][j];
            }
        }
        return C;
    }

public static int[][] subtractMatrix(int[][] A, int[][] B) {
        int n = A.length;
        int[][] C = new int[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                C[i][j] = A[i][j] - B[i][j];
            }
        }
        return C;
    }
  1. 编写优化的Strassen’s算法
public static int[][] optimizedStrassen(int[][] A, int[][] B,int k) {
        int n = A.length;

        if (n <= k) {
            return ordinary(A, B);
        }

        // 计算新矩阵的大小并将矩阵分割成四个子矩阵
        int newSize = n / 2;
        int[][] A11 = new int[newSize][newSize];
        int[][] A12 = new int[newSize][newSize];
        int[][] A21 = new int[newSize][newSize];
        int[][] A22 = new int[newSize][newSize];

        int[][] B11 = new int[newSize][newSize];
        int[][] B12 = new int[newSize][newSize];
        int[][] B21 = new int[newSize][newSize];
        int[][] B22 = new int[newSize][newSize];

        splitMatrix(A, A11, A12, A21, A22);
        splitMatrix(B, B11, B12, B21, B22);

        // 计算Strassen算法的七个中间矩阵
        int[][] M1 = optimizedStrassen(addMatrix(A11, A22), addMatrix(B11, B22),k);
        int[][] M2 = optimizedStrassen(addMatrix(A21, A22), B11,k);
        int[][] M3 = optimizedStrassen(A11, subtractMatrix(B12, B22),k);
        int[][] M4 = optimizedStrassen(A22, subtractMatrix(B21, B11),k);
        int[][] M5 = optimizedStrassen(addMatrix(A11, A12), B22,k);
        int[][] M6 = optimizedStrassen(subtractMatrix(A21, A11), addMatrix(B11, B12),k);
        int[][] M7 = optimizedStrassen(subtractMatrix(A12, A22), addMatrix(B21, B22),k);

        // 计算Strassen算法的四个结果矩阵
        int[][] C11 = subtractMatrix(addMatrix(addMatrix(M1, M4), M7), M5);
        int[][] C12 = addMatrix(M3, M5);
        int[][] C21 = addMatrix(M2, M4);
        int[][] C22 = subtractMatrix(addMatrix(addMatrix(M1, M3), M6), M2);

        // 合并结果矩阵
        int[][] C = new int[n][n];
        combineMatrix(C, C11, C12, C21, C22);

        return C;
    }

三、数据收集

编写测试代码,获得实验结果

public static void main(String[] args) {
//        createTestFile("matrix1");
//        createTestFile("matrix2");
        try{
            PrintWriter writer = new PrintWriter(new FileWriter("result.txt"));
            for (int n = 16; n <= 1024; n *= 2){
                int[][] matrix1 = readTestFile(n,"matrix1");
                int[][] matrix2 = readTestFile(n,"matrix2");
                long runtime = 0;
                for(int i=0;i<5;i++){ // 进行五次排序并取平均运行时间
                    long startTime = System.currentTimeMillis();
                    strassen(matrix1,matrix2);
                    long endTime = System.currentTimeMillis();
                    runtime += endTime-startTime;
                }
                System.out.println("Strassen Algorithm: n=" + n + "  runtime:" + runtime/5 + "ms");
                writer.println("Strassen Algorithm: n=" + n + "  runtime:" + runtime/5 + "ms"); // 将结果写入文件
                runtime = 0;
                for(int i=0;i<5;i++){ // 进行五次排序并取平均运行时间
                    long startTime = System.currentTimeMillis();
                    ordinary(matrix1,matrix2);
                    long endTime = System.currentTimeMillis();
                    runtime += endTime-startTime;
                }
                System.out.println("Ordinary Algorithm: n=" + n + "  runtime:" + runtime/5 + "ms");
                writer.println("Ordinary Algorithm: n=" + n + "  runtime:" + runtime/5 + "ms"); // 将结果写入文件
                runtime = 0;
                for(int k=4;k<=256;k*=2){
                    for(int i=0;i<5;i++){ // 进行五次排序并取平均运行时间
                        long startTime = System.currentTimeMillis();
                        optimizedStrassen(matrix1,matrix2,k);
                        long endTime = System.currentTimeMillis();
                        runtime += endTime-startTime;
                    }
                    System.out.println("Optimized Strassen Algorithm: n=" + n +" k=" + k +"  runtime:" + runtime/5 + "ms");
                    writer.println("Optimized Strassen Algorithm: n=" + n + " k=" + k+"  runtime:" + runtime/5 + "ms"); // 将结果写入文件
                }

            }
            writer.close(); // 关闭文件写入流
        }catch (IOException e) { // 异常处理
            System.out.println("Failed to write results to file. Error message: " + e.getMessage());
        }
        System.out.println("TEST FINISH"); // 输出测试结束信息
    }

结果导出到result.txt中,将其转换到result.xlsx中,方便后续分析数据。

首先,在数据规模的选取上,若方阵的大小不为2的n次方,在进行Strassen’s算法需要把方阵补全为2的n次方大小再来进行计算。由于本实验只是探究Strassen’s算法和普通的矩阵相乘算法运行速度之间的关系,不涉及更具体的应用场景,故选取的方阵的大小均为2的n次方,以便进行实验。

经过预测试发现,当数据规模小于n=16数量级时,运行速度太快,耗时均为0ms,无法得到运行速度变化,所以我选取的数据规模n=16开始。
当数据规模达到n=2048数量级规模时,运行速度太慢,所以我选取的数据规模最大到n=1024。
各个规模的计算均进行5次取平均值,减小由于某次排序可能存在的异常情况或错误对结果的影响,提高数据的稳定性和可靠性。

另外,在对Strassen’s算法进行测试时发现,单纯的Strassen’s算法运行效率十分慢。分析原因是递归的堆栈调用和矩阵复制的值传递耗费了太多时间,导致Strassen’s算法比普通的矩阵相乘算法运行速度更慢,算法优化起到了反效果。

因此,我决定对Strassen’s算法进行优化,优化思路同lab1对归并算法的优化,在矩阵规模小于等于k时不再使用Strassen’s算法改为使用普通的矩阵相乘算法。在这个思路下设计实验,并对不同的k值进行测试,观察其能否比普通矩阵相乘算法有更好的表现。

四、结果与讨论

  1. 理论分析
    对于普通的矩阵相乘算法,三重循环每层循环的次数为n,所以时间复杂度显然为
    O ( n 3 ) O(n^3) O(n3)
    Strassen’s算法的本质上是把一个矩阵划分成四个小的矩阵,每个矩阵的规模变成原来的1/2,通过一系列矩阵构造并进行加减运算得到乘积,在这个过程中减少了一次矩阵的乘法运算,以此来减少计算使用的时间。因此,Strassen’s算法的时间复杂度可以表示为
    T ( n ) = 7 T ( n 2 ) + O ( n 2 ) T(n) = 7T\left(\frac{n}{2}\right) + O(n^2) T(n)=7T(2n)+O(n2)
    其中T(n)表示计算两个 n×n 矩阵相乘所需的时间,T(n/2)为每个子问题需要的时间,通过Master Method求解递归关系,可以得到Strassen算法的总时间复杂度为
    O ( n log ⁡ 2 7 ) O(n^{\log_2 7}) O(nlog27)
    大约等于$ O(n^{2.81})$。这意味着Strassen算法比传统的矩阵乘法算法更高效。然而,在实际应用中,因为有递归的堆栈调用和矩阵复制的值传递,Strassen算法的常数因子较大,因此只有在特殊优化下才能体现出其优势。对于小规模的矩阵乘法,传统的算法可能更快。

  2. 实验结果
    实验结果如下(单位ms):

    数据规模 (2^n)普通算法k=4k=8k=16k=32k=64k=128k=256Strassen算法
    4000000003
    50100000014
    60400000077
    7123622111525
    8121554117121011123510
    910110162891258474818724482
    10252770972006832558497551639175735

使用该数据画图(其中未优化的Strassen’s算法运行时间太长,未画在图中)

在这里插入图片描述

如图,以普通矩阵相乘算法为基准,可以看出在k=4时,由于递归的堆栈调用层数太深及矩阵复制的值传递过多,优化的Strassen’s算法运行速度并不如普通矩阵相乘算法。
但随着k值变大,在k大于8时,优化的Strassen’s算法运行速度均比普通矩阵相乘算法要快,这说明Strassen’s算法减少矩阵乘法次数以减小时间复杂度的方法起了作用,其中在k=64时算法运行速度最快。即使在k=256时(即Strassen’s算法只进行了几次矩阵乘法优化),优化的Strassen’s算法运行速度仍然比普通矩阵相乘算法快。

五、结论

本次实验主要对比了普通的矩阵相乘算法和Strassen’s算法在不同规模下的运行速度,并通过优化的Strassen’s算法来提高效率。通过实验结果分析。

首先,根据理论分析,普通矩阵相乘算法的时间复杂度为O(n3),而Strassen’s算法的时间复杂度约为O(n2.81),即Strassen’s算法理论上更高效。但在实际应用中,由于递归调用和值传递开销较大,Strassen’s算法的常数因子较大,只有在特殊优化下才能体现出其优势。
实验结果显示,k值大于8时,优化的Strassen’s算法开始展现出优势,运行速度明显快于普通矩阵相乘算法。特别是在k=64时,优化的Strassen’s算法达到最佳性能。
这一结果说明,通过减少矩阵乘法次数,Strassen’s算法成功地降低了时间复杂度,尽管Strassen’s算法需要进行递归调用和值传递的开销,但通过合适的优化(如设置合适的k值),仍然可以使计算效率超过普通矩阵相乘算法。

综上所述,本次实验通过实际测试验证了Strassen’s算法在适当的优化下相比普通矩阵相乘算法更高效的结论。同时,也展示了优化策略中k值的重要性,选择合适的k值可以提高优化的Strassen’s算法的运行效率。这些研究结果对于理解和应用矩阵乘法算法具有一定的指导意义,为实际问题中的矩阵运算提供了重要的参考依据。

  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值