AI数据结构-矩阵

   最近打算学习机器学习,发现其中的算法大量用到矩阵,于是先学习了一下矩阵,在这里分享一下学习心得。

  基本结构

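  • m x n 的矩阵一般写成:$A_{m\times n} = \begin{pmatrix} a_{11} & a_{12} & \cdots & a_{1n} \\ a_{21} & a_{22} & \cdots & a_{2n} \\ \vdots & \vdots & \ddots & \vdots \\ a_{m1} & a_{m2} & \cdots & a_{mn} \end{pmatrix}$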
  • 矩阵与向量相乘:可以看做矩阵与只有一列(或只有一行)的矩阵相乘。例如:矩阵 A[n][m] x B[m][1] = AB[n][1]。
  • 矩阵与矩阵的乘积:A[m][n] x B[n][k] = AB[m][k],要求 A 的列数等于 B 的行数。计算步骤:
  1. 将矩阵 B 看做由 k 个列向量 b[0] … b[k-1] 组成;
  2. 将矩阵 A 分别与每个列向量 b[i] 相乘,结果依次作为矩阵 AB 的第 i 列。

  矩阵接口实现代码


/**
 * 矩阵接口
 * QQ群:528344775
 *
 * @author chengchaochao
 * @create 2017-12-29 10:39
 **/
public interface Matrix {
    /**
     * 返回矩阵是否为方阵
     * <ul>
     * <li>true</li>
     * <li>false</li>
     * </ul>
     */
    boolean isSquare();

    /**
     * 返回矩阵行大小
     */
    int getRowDimension();

    /**
     * 返回矩阵列大小
     */
    int getColDimension();

    double [] [] getData();

    Matrix createMatrix(int rowSize, int colSize);

    /**
     * 创建默认值为0 的指定大小的Array2DMatrix
     */
    static Matrix createArray2DMatrix(int rowSize, int colSize){
        return new Array2DMatrix(rowSize,colSize,0.0D);
    }

    /**
     * 创建指定默认值 的指定大小的Array2DMatrix
     */
    static Matrix createArray2DMatrix(int rowSize, int colSize ,double defaultValue){
        return new Array2DMatrix(rowSize,colSize,defaultValue);
    }

    /**
     *完全复制
     */
    Matrix copy();
    /**
     *两个矩阵 相加(每个元素相加)
     */
    Matrix add(Matrix ohter);
    /**
     *两个矩阵 相加(每个元素与某数相加)
     */
    Matrix add(double num);

    /**
     *两个矩阵 相减(每个元素相减)
     */
    Matrix subtract(Matrix ohter);

    /**
     *两个矩阵 相减(每个元素与某数相减)
     */
    Matrix subtract(double num);

    /**
     *两个矩阵 相乘(每个元素相乘)
     */
    Matrix valueMultiply(Matrix ohter);
    /**
     *两个矩阵 相乘
     * 矩阵 A ,B 必须 A 的行 等于  B 的列
     */
    Matrix matrixMultiply(Matrix ohter);
    /**
     *两个矩阵 相乘(每个元素与某数相乘)
     */
    Matrix multiply(double num);

    double setValue(int rowIndex ,int colIndex, double value);

    double getValue(int rowIndex ,int colIndex);

    /***
     * 追加列
     */
    void addCol(double value);

    /***
     * 追加多列
     */
    void addCol(double [] values);

    /**
     * 追加行
     */
    void addRow(double value);
    /**
     * 追加多行
     */
    void addRow(double [] values);

    /**
     * 矩阵转置
     */
    Matrix transposition();


    default  Stream<DoubleStream> stream(){
        return Arrays.stream(getData()).map(DoubleStream::of);
    }

    String toStringByStand();

      矩阵抽象类实现代码

/**
 * Skeletal implementation of {@link Matrix}. Element-wise arithmetic, the matrix product,
 * transposition, row/column appending, equality and string formatting are all expressed in
 * terms of the abstract accessors ({@link #getValue}, {@link #setValue}, {@link #getData},
 * {@link #setData}) that concrete subclasses provide.
 * QQ group: 528344775
 *
 * @author chengchaochao
 * @create 2018-01-04 8:54
 **/
public abstract class AbstractMatrix implements Matrix {

    /** Element-wise addition: {@code (a, b) -> a + b}. */
    protected static final BiFunction<Double, Double, Double> ADD_OPERATER = (a, b) -> a + b;

    /** Element-wise subtraction: {@code (a, b) -> a - b}. */
    protected static final BiFunction<Double, Double, Double> SUBTRACT_OPERATER = (a, b) -> a - b;

    /** Element-wise multiplication: {@code (a, b) -> a * b}. */
    protected static final BiFunction<Double, Double, Double> MULTIPLY_OPERATER = (a, b) -> a * b;

    /**
     * Validates the requested dimensions; allocating storage is left to the subclass.
     *
     * @throws BizRunTimeException if either dimension is smaller than 1
     */
    protected AbstractMatrix(int rowDimension, int columnDimension) {
        if (rowDimension < 1) {
            throw new BizRunTimeException("行尺寸错误,不能小于1");
        }
        if (columnDimension < 1) {
            throw new BizRunTimeException("列尺寸错误,不能小于1");
        }
    }

    /**
     * Creates a matrix without validating any dimensions; the subclass is responsible for
     * initialising a consistent state.
     */
    public AbstractMatrix() {
    }

    @Override
    public boolean isSquare() {
        return this.getColDimension() == this.getRowDimension();
    }

    @Override
    public abstract int getRowDimension();

    @Override
    public abstract int getColDimension();

    @Override
    public abstract Matrix createMatrix(int rowSize, int colSize);

    /**
     * Creates a new matrix of the subclass's kind initialised from {@code data}.
     */
    public abstract Matrix createMatrix(double[][] data);

    @Override
    public abstract Matrix copy();

    /**
     * Replaces the backing storage with {@code data}. Used by the in-place operations
     * ({@link #transposition()}, {@link #matrixMultiply(Matrix)}, {@link #addCol(double[])},
     * {@link #addRow(double[])}).
     */
    public abstract void setData(double[][] data);

    @Override
    public abstract double setValue(int rowIndex, int colIndex, double value);

    @Override
    public abstract double getValue(int rowIndex, int colIndex);

    /**
     * Combines this matrix and {@code other} element by element.
     *
     * @param other    matrix with exactly the same dimensions as this one
     * @param operater function applied as {@code operater(thisElement, otherElement)}
     * @return a new matrix (allocated by {@link #createMatrix(int, int)}); neither operand is modified
     * @throws BizRunTimeException if the dimensions differ
     */
    protected Matrix operation(Matrix other, BiFunction<Double, Double, Double> operater) {
        checkMatrixRowAndColEquase(other);

        int rowCount = this.getRowDimension();
        int columnCount = this.getColDimension();
        Matrix result = this.createMatrix(rowCount, columnCount);

        for (int row = 0; row < rowCount; ++row) {
            for (int col = 0; col < columnCount; ++col) {
                result.setValue(row, col, operater.apply(this.getValue(row, col), other.getValue(row, col)));
            }
        }
        return result;
    }

    /**
     * Combines every element of this matrix with the scalar {@code num}.
     *
     * @param operater function applied as {@code operater(element, num)}
     * @param num      scalar operand
     * @return a new matrix (allocated by {@link #createMatrix(int, int)}); this matrix is not modified
     */
    protected Matrix operationByScalar(BiFunction<Double, Double, Double> operater, double num) {
        // A scalar is compatible with any shape, so no dimension check is needed.
        int rowCount = this.getRowDimension();
        int columnCount = this.getColDimension();
        Matrix result = this.createMatrix(rowCount, columnCount);

        for (int row = 0; row < rowCount; ++row) {
            for (int col = 0; col < columnCount; ++col) {
                result.setValue(row, col, operater.apply(this.getValue(row, col), num));
            }
        }
        return result;
    }

    @Override
    public Matrix add(Matrix other) {
        return this.operation(other, ADD_OPERATER);
    }

    @Override
    public Matrix add(double num) {
        return this.operationByScalar(ADD_OPERATER, num);
    }

    @Override
    public Matrix subtract(Matrix other) {
        return this.operation(other, SUBTRACT_OPERATER);
    }

    @Override
    public Matrix subtract(double num) {
        return this.operationByScalar(SUBTRACT_OPERATER, num);
    }

    @Override
    public Matrix valueMultiply(Matrix other) {
        return this.operation(other, MULTIPLY_OPERATER);
    }

    @Override
    public Matrix multiply(double num) {
        return this.operationByScalar(MULTIPLY_OPERATER, num);
    }

    /**
     * Transposes this matrix: element (r, c) moves to (c, r), so an m x n matrix becomes n x m.
     * <p>Unlike {@link #add(Matrix)}, this operates in place: the backing data of {@code this}
     * is replaced via {@link #setData(double[][])} and {@code this} is returned.
     *
     * @see <a href="https://en.wikipedia.org/wiki/Transpose">Transpose (Wikipedia)</a>
     */
    @Override
    public Matrix transposition() {
        int nRows = this.getRowDimension();
        int nCols = this.getColDimension();
        double[][] source = this.getData();

        double[][] transposed = new double[nCols][nRows];
        for (int row = 0; row < nRows; row++) {
            for (int col = 0; col < nCols; col++) {
                transposed[col][row] = source[row][col];
            }
        }
        this.setData(transposed);
        return this;
    }

    /**
     * Matrix product {@code this x other}.
     * <pre>
     *     A and B can be multiplied when A's column count equals B's row count.
     *     C = A x B has A's row count and B's column count.
     *     C[m][n] is the sum of products of A's m-th row with B's n-th column.
     * </pre>
     * Operates in place: the product replaces this matrix's data via
     * {@link #setData(double[][])} and {@code this} is returned.
     *
     * @throws BizRunTimeException if {@code this.getColDimension() != other.getRowDimension()}
     * @Author: chengchaochao
     * @Date: 2018/1/5 9:16
     */
    @Override
    public Matrix matrixMultiply(Matrix other) {
        checkMatrixRowAndColForMultiply(other.getRowDimension());

        int nRows = this.getRowDimension();
        int sharedDim = this.getColDimension();
        int newCols = other.getColDimension();

        // Fetch each operand's data once. getData() may copy the whole array (Array2DMatrix
        // does), so calling it inside the inner loop multiplied the cost by a full copy per term.
        double[][] left = this.getData();
        double[][] right = other.getData();

        // Result has A's rows and B's columns.
        double[][] out = new double[nRows][newCols];
        for (int row = 0; row < nRows; row++) {
            for (int col = 0; col < newCols; col++) {
                double sum = 0.0D;
                for (int k = 0; k < sharedDim; k++) {
                    sum += left[row][k] * right[k][col];
                }
                out[row][col] = sum;
            }
        }
        this.setData(out);
        return this;
    }

    @Override
    public void addCol(double value) {
        this.addCol(new double[]{value});
    }

    /**
     * Appends {@code values.length} columns to every row; the j-th new column is filled with
     * {@code values[j]}. The widened data replaces this matrix's storage via {@link #setData(double[][])}.
     */
    @Override
    public void addCol(double[] values) {
        double[][] elementData = this.getData();
        for (int i = 0; i < elementData.length; i++) {
            int oldLength = elementData[i].length;
            double[] widened = Arrays.copyOf(elementData[i], oldLength + values.length);
            System.arraycopy(values, 0, widened, oldLength, values.length);
            elementData[i] = widened;
        }
        this.setData(elementData);
    }

    @Override
    public void addRow(double value) {
        this.addRow(new double[]{value});
    }

    /**
     * Appends {@code values.length} rows; the i-th new row is filled entirely with
     * {@code values[i]}. The extended data replaces this matrix's storage via {@link #setData(double[][])}.
     */
    @Override
    public void addRow(double[] values) {
        int nRows = this.getRowDimension();
        int nCols = this.getColDimension();

        double[][] elementData = Arrays.copyOf(this.getData(), nRows + values.length);
        for (int i = 0; i < values.length; i++) {
            double[] newRow = new double[nCols];
            Arrays.fill(newRow, values[i]);
            elementData[nRows + i] = newRow;
        }
        this.setData(elementData);
    }

    /**
     * Row-major hash over all elements, consistent with {@link #equals(Object)}.
     */
    @Override
    public int hashCode() {
        int hashCode = 1;
        int nRows = this.getRowDimension();
        int nCols = this.getColDimension();
        for (int row = 0; row < nRows; ++row) {
            for (int col = 0; col < nCols; ++col) {
                hashCode = 31 * hashCode + Double.hashCode(this.getValue(row, col));
            }
        }
        return hashCode;
    }

    /**
     * Two matrices are equal when they have the same dimensions and every pair of corresponding
     * elements is equal under {@link Double#compare(double, double)}. Using {@code Double.compare}
     * rather than {@code ==} keeps this consistent with {@link #hashCode()}: NaN equals NaN, and
     * 0.0 differs from -0.0, exactly as with {@link Double#equals(Object)}.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof Matrix)) {
            return false;
        }
        Matrix otherMatrix = (Matrix) obj;
        int rowSize = this.getRowDimension();
        int colSize = this.getColDimension();
        if (otherMatrix.getRowDimension() != rowSize || otherMatrix.getColDimension() != colSize) {
            return false;
        }
        for (int row = 0; row < rowSize; ++row) {
            for (int col = 0; col < colSize; ++col) {
                if (Double.compare(this.getValue(row, col), otherMatrix.getValue(row, col)) != 0) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Checks that this matrix's column count equals {@code bRow}, the row count of the
     * right-hand operand of a matrix product.
     *
     * @throws BizRunTimeException if they differ
     */
    protected void checkMatrixRowAndColForMultiply(int bRow) {
        if (this.getColDimension() != bRow) {
            throw new BizRunTimeException("乘法运算,列与行不相等");
        }
    }

    /**
     * Checks that {@code matrix} has exactly the same dimensions as this matrix.
     *
     * @throws BizRunTimeException if the column or row counts differ
     */
    protected void checkMatrixRowAndColEquase(Matrix matrix) {
        checkMatrixColSize(matrix.getColDimension());
        checkMatrixRowSize(matrix.getRowDimension());
    }

    private void checkMatrixRowSize(int row) {
        if (this.getRowDimension() != row) {
            throw new BizRunTimeException("数组运算行不相等");
        }
    }

    private void checkMatrixColSize(int col) {
        if (this.getColDimension() != col) {
            throw new BizRunTimeException("数组运算列不相等");
        }
    }

    /**
     * Checks that ({@code row}, {@code col}) addresses an existing element.
     *
     * @throws BizRunTimeException if either index is negative or not smaller than its dimension
     */
    protected void checkMatrixIndex(int row, int col) {
        checkColIndex(col);
        checkRowIndex(row);
    }

    private void checkRowIndex(int row) {
        // Valid indices are 0 .. dimension-1; index == dimension is out of bounds.
        if (row < 0 || row >= this.getRowDimension()) {
            throw new BizRunTimeException("行下标越界");
        }
    }

    private void checkColIndex(int col) {
        if (col < 0 || col >= this.getColDimension()) {
            throw new BizRunTimeException("列下标越界");
        }
    }

    /**
     * Single-line form, e.g. {@code [{1.0,2.0},{3.0,4.0}]}.
     */
    @Override
    public String toString() {
        return baseToString(false);
    }

    /**
     * Multi-line form with one row per line.
     */
    @Override
    public String toStringByStand() {
        return baseToString(true);
    }

    /**
     * Formats rows as {@code {a,b,...}} joined by {@code ","} (or {@code ",\n"} when
     * {@code standFlag} is true) and wrapped in {@code [ ]}; returns "" for a matrix with no rows.
     */
    protected String baseToString(boolean standFlag) {
        int rowSize = this.getRowDimension();
        if (rowSize < 1) {
            return "";
        }
        String rowDelimiter = standFlag ? ",\n" : ",";
        return this.stream()
                .map(doubleStream -> doubleStream.boxed().map(String::valueOf)
                        .collect(Collectors.joining(",", "{", "}"))
                ).collect(Collectors.joining(rowDelimiter, "[", "]"));
    }
}

  矩阵实现类代码

/**
 * {@link Matrix} implementation backed by a rectangular {@code double[row][col]} array.
 * QQ group: 528344775
 *
 * @author chengchaochao
 * @create 2018-01-04 12:44
 **/
public class Array2DMatrix extends AbstractMatrix {

    /** Element storage, indexed {@code data[row][col]}. */
    private double [][] data;

    /**
     * Creates a matrix holding a deep copy of {@code data}; later changes to the argument are
     * not reflected in the matrix.
     *
     * @param data non-empty rectangular array (every row the same length)
     * @throws BizRunTimeException if {@code data} has no rows or no columns, or its rows differ in length
     */
    public Array2DMatrix(double[][] data) {
        // Guard data[0] so that an empty array is rejected by the dimension check in
        // AbstractMatrix instead of failing with ArrayIndexOutOfBoundsException.
        this(data.length, data.length == 0 ? 0 : data[0].length);
        this.copyDataIn(data);
    }

    /**
     * Creates a matrix from a vector.
     *
     * @param data vector elements (must be non-empty)
     * @param flag {@code true} to lay the vector out as a single column (n x 1),
     *             {@code false} to lay it out as a single row (1 x n)
     * @throws BizRunTimeException if {@code data} is empty
     */
    public Array2DMatrix(double[] data ,boolean flag) {
        // Route through the validating constructor so that an empty vector is rejected up front
        // instead of producing a matrix whose dimensions cannot be queried.
        this(flag ? data.length : 1, flag ? 1 : data.length);
        for (int i = 0; i < data.length; i++) {
            if (flag) {
                this.data[i][0] = data[i];
            } else {
                this.data[0][i] = data[i];
            }
        }
    }

    /**
     * Creates a single-row (1 x n) matrix from {@code data}.
     */
    public Array2DMatrix(double[] data ){
        this(data,false);
    }

    /**
     * Creates a square {@code rowDimension x rowDimension} matrix of zeros.
     */
    public Array2DMatrix(int rowDimension) {
        this(rowDimension,rowDimension);
    }

    /**
     * Creates a zero-filled matrix of the given size.
     *
     * @throws BizRunTimeException if either dimension is smaller than 1
     */
    public Array2DMatrix(int rowDimension, int columnDimension) {
        super(rowDimension, columnDimension);
        this.data = new double[rowDimension][columnDimension];
    }

    /**
     * Creates a matrix of the given size with every element set to {@code defaultValue}.
     *
     * @throws BizRunTimeException if either dimension is smaller than 1
     */
    public Array2DMatrix(int rowDimension, int columnDimension, double defaultValue) {
        this(rowDimension,columnDimension);
        // Fill directly rather than calling the overridable add(double) from a constructor,
        // which also allocated two throw-away matrices.
        for (double[] row : this.data) {
            for (int col = 0; col < row.length; col++) {
                row[col] = defaultValue;
            }
        }
    }

    @Override
    public int getRowDimension() {
        return this.data == null ? 0 : this.data.length;
    }

    @Override
    public int getColDimension() {
        // A matrix with no rows has no columns; guard data[0] for that case.
        if (this.data == null || this.data.length == 0 || this.data[0] == null) {
            return 0;
        }
        return this.data[0].length;
    }

    /**
     * Returns a deep copy of the elements; modifying it does not change this matrix.
     */
    @Override
    public double[][] getData() {
        return this.copyDataOut();
    }

    @Override
    public Matrix createMatrix(int rowSize, int colSize) {
        return new Array2DMatrix(rowSize, colSize);
    }

    @Override
    public Matrix copy() {
        return new Array2DMatrix(this.copyDataOut());
    }

    @Override
    public double setValue(int rowIndex, int colIndex, double value) {
        checkMatrixIndex(rowIndex, colIndex);
        return this.data[rowIndex][colIndex] = value;
    }

    @Override
    public double getValue(int rowIndex, int colIndex) {
        checkMatrixIndex(rowIndex, colIndex);
        return this.data[rowIndex][colIndex];
    }

    /**
     * Adds {@code value} to the element at ({@code rowIndex}, {@code colIndex}) in place.
     */
    public void addToValue(int rowIndex, int colIndex, double value) {
        checkMatrixIndex(rowIndex, colIndex);
        this.data[rowIndex][colIndex] += value;
    }

    /**
     * Subtracts {@code value} from the element at ({@code rowIndex}, {@code colIndex}) in place.
     */
    public void subtractToValue(int rowIndex, int colIndex, double value) {
        checkMatrixIndex(rowIndex, colIndex);
        this.data[rowIndex][colIndex] -= value;
    }

    /**
     * Multiplies the element at ({@code rowIndex}, {@code colIndex}) by {@code value} in place.
     */
    public void multiplyToValue(int rowIndex, int colIndex, double value) {
        checkMatrixIndex(rowIndex, colIndex);
        this.data[rowIndex][colIndex] *= value;
    }

    /**
     * Returns a deep copy of {@link #data} (each row cloned).
     */
    private double[][] copyDataOut() {
        int rowSize = this.getRowDimension();
        double[][] out = new double[rowSize][];
        for (int i = 0; i < rowSize; ++i) {
            out[i] = this.data[i].clone();
        }
        return out;
    }

    /**
     * Replaces the backing array with {@code data} without copying it, so later writes to the
     * argument are visible through this matrix.
     */
    @Override
    public void setData(double[][] data) {
        this.data = data;
    }

    /**
     * Deep-copies {@code source} into {@link #data}, which the constructor has already sized
     * to {@code source.length x source[0].length}.
     *
     * @throws BizRunTimeException if any row's length differs from the first row's
     */
    private void copyDataIn(double[][] source) {
        int nCols = this.getColDimension();
        for (int row = 0; row < source.length; row++) {
            if (source[row].length != nCols) {
                throw new BizRunTimeException("二维数组每行长度必须相同");
            }
            this.data[row] = source[row].clone();
        }
    }

    @Override
    public Matrix createMatrix(double[][] data) {
        return new Array2DMatrix(data);
    }
}
矩阵乘以向量
例子:计算 $\begin{pmatrix} 2 & 3 & 4 \\ 5 & 8 & 2 \end{pmatrix} \begin{pmatrix} 2 \\ 1 \\ 6 \end{pmatrix}$

    @Test
    public void testMatrixMultiplyVector() {
        // Left operand A, a 2x3 matrix:
        // [{2.0,3.0,4.0},
        //  {5.0,8.0,2.0}]
        Matrix left = new Array2DMatrix(new double[][]{
                {2, 3, 4},
                {5, 8, 2}
        });
        System.out.println(left.toStringByStand());

        // Right operand B: the vector (2, 1, 6) laid out as one column (flag = true), i.e. 3x1:
        // [{2.0},
        //  {1.0},
        //  {6.0}]
        Matrix columnVector = new Array2DMatrix(new double[]{2, 1, 6}, true);
        System.out.println(columnVector.toStringByStand());

        // A x B is 2x1; expected output:
        // [{31.0},
        //  {30.0}]
        Matrix product = left.matrixMultiply(columnVector);
        System.out.println(product.toStringByStand());
    }
结果为一个两行一列的矩阵 $\begin{pmatrix} 31 \\ 30 \end{pmatrix}$,其中 31 = 2×2 + 3×1 + 4×6,30 = 5×2 + 8×1 + 2×6。

这里先将向量转换为只有一列的矩阵再相乘,本质上就是矩阵与矩阵的乘积,因此矩阵与矩阵相乘不再单独举例。

矩阵转置

矩阵转置例子

 @Test
    public void testTransposition() {
        // Each entry: {label printed before transposing, label printed after transposing}.
        String[][] labels = {
                {"矩阵B转置前====", "矩阵B转置后===="},
                {"矩阵A转置前====", "矩阵A转置后===="},
                {"矩阵C转置前====", "矩阵C转置后===="},
                {"矩阵D转置前====", "矩阵D转置后===="},
                {"矩阵E转置前====", "矩阵E转置后===="}
        };
        Matrix[] samples = {
                // B: 3x3 square matrix
                new Array2DMatrix(new double[][]{{1, 2, 3}, {4, 5, 6}, {1.1D, 1.1D, 1.1D}}),
                // A: 2x3 matrix
                new Array2DMatrix(new double[][]{{1, 2, 3}, {4, 5, 6}}),
                // C: 3x2 matrix
                new Array2DMatrix(new double[][]{{1, 2}, {4, 5}, {1.1D, 1.1D}}),
                // D: column vector (flag = true)
                new Array2DMatrix(new double[]{1, 2, 3}, true),
                // E: row vector (single-argument constructor)
                new Array2DMatrix(new double[]{1, 2, 3})
        };

        for (int i = 0; i < samples.length; i++) {
            Matrix sample = samples[i];
            System.out.println(labels[i][0] + sample.toStringByStand());
            System.out.println(labels[i][1] + sample.transposition().toStringByStand());
        }
    }
上面代码运算执行结果

矩阵B转置前====[{1.0,2.0,3.0},
      {4.0,5.0,6.0},
      {1.1,1.1,1.1}]
矩阵B转置后====[{1.0,4.0,1.1},
      {2.0,5.0,1.1},
      {3.0,6.0,1.1}]
矩阵A转置前====[{1.0,2.0,3.0},
      {4.0,5.0,6.0}]
矩阵A转置后====[{1.0,4.0},
      {2.0,5.0},
      {3.0,6.0}]
矩阵C转置前====[{1.0,2.0},
      {4.0,5.0},
      {1.1,1.1}]
矩阵C转置后====[{1.0,4.0,1.1},
      {2.0,5.0,1.1}]
矩阵D转置前====[{1.0},
      {2.0},
      {3.0}]
矩阵D转置后====[{1.0,2.0,3.0}]
矩阵E转置前====[{1.0,2.0,3.0}]
矩阵E转置后====[{1.0},
      {2.0},
      {3.0}]

第一次写博客,如有错误之处,希望大家能够指正,谢谢~

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值