最近打算学习机器学习,发现其中的算法大量用到矩阵,于是学习了一下矩阵,在这里分享一下学习心得。
基本结构
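- m × n 的矩阵一般写成:$A=\begin{bmatrix}a_{11}&a_{12}&\cdots&a_{1n}\\a_{21}&a_{22}&\cdots&a_{2n}\\\vdots&\vdots&\ddots&\vdots\\a_{m1}&a_{m2}&\cdots&a_{mn}\end{bmatrix}$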
- 矩阵与向量相乘:其实可以看做矩阵与只有一列(或一行)的矩阵相乘。如:矩阵 A[n][m] × B[m][1] = AB[n][1]
- 矩阵与矩阵的乘积:A[m][n] × B[n][k] = AB[m][k],其中 $(AB)_{ij}=\sum_{l=1}^{n}a_{il}\,b_{lj}$;可以这样理解:
- 将矩阵B看做由k个列向量 b[1], b[2], …, b[k] 组成;
- 将矩阵A分别与每个列向量 b[i] 相乘,将结果依次作为矩阵AB的第i列。
矩阵接口实现代码
/**
* 矩阵接口
* QQ群:528344775
*
* @author chengchaochao
* @create 2017-12-29 10:39
**/
public interface Matrix {
/**
* 返回矩阵是否为方阵
* <ul>
* <li>true</li>
* <li>false</li>
* </ul>
*/
boolean isSquare();
/**
* 返回矩阵行大小
*/
int getRowDimension();
/**
* 返回矩阵列大小
*/
int getColDimension();
double [] [] getData();
Matrix createMatrix(int rowSize, int colSize);
/**
* 创建默认值为0 的指定大小的Array2DMatrix
*/
static Matrix createArray2DMatrix(int rowSize, int colSize){
return new Array2DMatrix(rowSize,colSize,0.0D);
}
/**
* 创建指定默认值 的指定大小的Array2DMatrix
*/
static Matrix createArray2DMatrix(int rowSize, int colSize ,double defaultValue){
return new Array2DMatrix(rowSize,colSize,defaultValue);
}
/**
*完全复制
*/
Matrix copy();
/**
*两个矩阵 相加(每个元素相加)
*/
Matrix add(Matrix ohter);
/**
*两个矩阵 相加(每个元素与某数相加)
*/
Matrix add(double num);
/**
*两个矩阵 相减(每个元素相减)
*/
Matrix subtract(Matrix ohter);
/**
*两个矩阵 相减(每个元素与某数相减)
*/
Matrix subtract(double num);
/**
*两个矩阵 相乘(每个元素相乘)
*/
Matrix valueMultiply(Matrix ohter);
/**
*两个矩阵 相乘
* 矩阵 A ,B 必须 A 的行 等于 B 的列
*/
Matrix matrixMultiply(Matrix ohter);
/**
*两个矩阵 相乘(每个元素与某数相乘)
*/
Matrix multiply(double num);
double setValue(int rowIndex ,int colIndex, double value);
double getValue(int rowIndex ,int colIndex);
/***
* 追加列
*/
void addCol(double value);
/***
* 追加多列
*/
void addCol(double [] values);
/**
* 追加行
*/
void addRow(double value);
/**
* 追加多行
*/
void addRow(double [] values);
/**
* 矩阵转置
*/
Matrix transposition();
default Stream<DoubleStream> stream(){
return Arrays.stream(getData()).map(DoubleStream::of);
}
String toStringByStand();
矩阵抽象类实现代码
/**
 * Base implementation of {@link Matrix}. It provides element-wise arithmetic, matrix
 * multiplication, transposition, row/column appending, equality and string rendering on top of
 * the abstract accessors {@link #getValue}, {@link #setValue}, {@link #getData} and
 * {@link #setData}.
 * QQ group: 528344775
 *
 * @author chengchaochao
 * @create 2018-01-04 8:54
 **/
public abstract class AbstractMatrix implements Matrix {

    protected static final BiFunction<Double, Double, Double> ADD_OPERATER = (a, b) -> a + b;
    protected static final BiFunction<Double, Double, Double> SUBTRACT_OPERATER = (a, b) -> a - b;
    protected static final BiFunction<Double, Double, Double> MULTIPLY_OPERATER = (a, b) -> a * b;

    /**
     * Validates the requested dimensions; subclasses allocate the storage themselves.
     *
     * @throws BizRunTimeException if either dimension is smaller than 1
     */
    protected AbstractMatrix(int rowDimension, int columnDimension) {
        if (rowDimension < 1) {
            throw new BizRunTimeException("行尺寸错误,不能小于1");
        } else if (columnDimension < 1) {
            throw new BizRunTimeException("列尺寸错误,不能小于1");
        }
    }

    /**
     * No-argument constructor; performs no dimension validation.
     */
    public AbstractMatrix() {
    }

    @Override
    public boolean isSquare() {
        return this.getColDimension() == this.getRowDimension();
    }

    @Override
    public abstract int getRowDimension();

    @Override
    public abstract int getColDimension();

    @Override
    public abstract Matrix createMatrix(int rowSize, int colSize);

    /**
     * Creates a new matrix of this implementation's type from the given {@code [row][col]} data.
     */
    public abstract Matrix createMatrix(double[][] data);

    @Override
    public abstract Matrix copy();

    /**
     * Applies {@code operater} element-wise to this matrix and {@code ohter}.
     *
     * @param ohter    right-hand operand; must have the same dimensions as this matrix
     * @param operater function applied as {@code operater(this[r][c], ohter[r][c])}
     * @return a new matrix (from {@link #createMatrix(int, int)}) holding the results
     * @throws BizRunTimeException if the dimensions differ
     */
    protected Matrix operation(Matrix ohter, BiFunction<Double, Double, Double> operater) {
        int rowCount = this.getRowDimension();
        int columnCount = this.getColDimension();
        // Both row and column counts must match.
        checkMatrixRowAndColEquase(ohter);
        Matrix result = this.createMatrix(rowCount, columnCount);
        for (int row = 0; row < rowCount; ++row) {
            for (int col = 0; col < columnCount; ++col) {
                result.setValue(row, col, operater.apply(this.getValue(row, col), ohter.getValue(row, col)));
            }
        }
        return result;
    }

    /**
     * Applies {@code operater} to every element of this matrix with {@code num} as the second
     * argument.
     *
     * @return a new matrix (from {@link #createMatrix(int, int)}) holding the results
     */
    protected Matrix operationByScalar(BiFunction<Double, Double, Double> operater, double num) {
        // No dimension check is needed for a scalar operand.
        int rowCount = this.getRowDimension();
        int columnCount = this.getColDimension();
        Matrix result = this.createMatrix(rowCount, columnCount);
        for (int row = 0; row < rowCount; ++row) {
            for (int col = 0; col < columnCount; ++col) {
                result.setValue(row, col, operater.apply(this.getValue(row, col), num));
            }
        }
        return result;
    }

    @Override
    public Matrix add(Matrix ohter) {
        return this.operation(ohter, ADD_OPERATER);
    }

    @Override
    public Matrix add(double num) {
        return this.operationByScalar(ADD_OPERATER, num);
    }

    @Override
    public Matrix subtract(Matrix ohter) {
        return this.operation(ohter, SUBTRACT_OPERATER);
    }

    @Override
    public Matrix subtract(double num) {
        return this.operationByScalar(SUBTRACT_OPERATER, num);
    }

    @Override
    public Matrix valueMultiply(Matrix ohter) {
        return this.operation(ohter, MULTIPLY_OPERATER);
    }

    /**
     * Transposes the matrix: rows become columns and columns become rows.
     * The transposed data replaces this matrix's data (this matrix is mutated) and {@code this}
     * is returned.
     *
     * @see <a href="https://en.wikipedia.org/wiki/Transpose">Transpose</a>
     */
    @Override
    public Matrix transposition() {
        int nCols = this.getColDimension();
        int nRows = this.getRowDimension();
        double[][] newData = new double[nCols][nRows];
        for (int row = 0; row < nCols; row++) {
            for (int col = 0; col < nRows; col++) {
                newData[row][col] = this.getValue(col, row);
            }
        }
        this.setData(newData);
        return this;
    }

    /**
     * Matrix product {@code C = A × B} with A = this matrix and B = {@code ohter}.
     * <pre>
     * A and B can be multiplied when A's column count equals B's row count.
     * C has A's row count and B's column count.
     * C[m][n] is the sum of the products of row m of A with column n of B.
     * </pre>
     * The product replaces this matrix's data (this matrix is mutated) and {@code this} is
     * returned.
     *
     * @throws BizRunTimeException if A's column count differs from B's row count
     * @Author: chengchaochao
     * @Date: 2018/1/5 9:16
     */
    @Override
    public Matrix matrixMultiply(Matrix ohter) {
        int nRows = this.getRowDimension();
        int nCols = this.getColDimension();
        int newCols = ohter.getColDimension();
        checkMatrixRowAndColForMultiply(ohter.getRowDimension());
        // getData() may return a defensive copy (Array2DMatrix does), so fetch each operand
        // once instead of once per multiply-add.
        double[][] left = this.getData();
        double[][] right = ohter.getData();
        // Result has A's rows and B's columns.
        double[][] out = new double[nRows][newCols];
        for (int row = 0; row < nRows; row++) {
            for (int col = 0; col < newCols; col++) {
                double sum = 0.0D;
                for (int k = 0; k < nCols; k++) {
                    sum += left[row][k] * right[k][col];
                }
                out[row][col] = sum;
            }
        }
        this.setData(out);
        return this;
    }

    @Override
    public Matrix multiply(double num) {
        return this.operationByScalar(MULTIPLY_OPERATER, num);
    }

    @Override
    public void addCol(double value) {
        this.addCol(new double[]{value});
    }

    /**
     * Appends {@code values.length} columns: every row gets {@code values} appended at its end.
     */
    @Override
    public void addCol(double[] values) {
        double[][] elementData = this.getData();
        for (int i = 0, nRows = elementData.length; i < nRows; i++) {
            int currentColIndex = elementData[i].length;
            int newColLength = currentColIndex + values.length;
            elementData[i] = Arrays.copyOf(elementData[i], newColLength);
            for (int colIndex = currentColIndex, valuesIndex = 0; colIndex < newColLength; colIndex++, valuesIndex++) {
                elementData[i][colIndex] = values[valuesIndex];
            }
        }
        this.setData(elementData);
    }

    @Override
    public void addRow(double value) {
        this.addRow(new double[]{value});
    }

    /**
     * Appends {@code values.length} rows: the k-th new row is filled entirely with
     * {@code values[k]}.
     */
    @Override
    public void addRow(double[] values) {
        double[][] elementData = this.getData();
        int nRows = this.getRowDimension();
        int newRowLength = nRows + values.length;
        int nColLength = this.getColDimension();
        elementData = Arrays.copyOf(elementData, newRowLength);
        for (int rowIndex = nRows, valuesIndex = 0; rowIndex < newRowLength; rowIndex++, valuesIndex++) {
            elementData[rowIndex] = new double[nColLength];
            for (int col = 0; col < nColLength; col++) {
                elementData[rowIndex][col] = values[valuesIndex];
            }
        }
        this.setData(elementData);
    }

    /**
     * Replaces the underlying {@code [row][col]} data of this matrix.
     */
    public abstract void setData(double[][] data);

    @Override
    public abstract double setValue(int rowIndex, int colIndex, double value);

    @Override
    public abstract double getValue(int rowIndex, int colIndex);

    /**
     * Hash over all elements in row-major order, consistent with {@link #equals(Object)}.
     */
    @Override
    public int hashCode() {
        int hashCode = 1;
        int nRows = this.getRowDimension();
        int nCols = this.getColDimension();
        for (int row = 0; row < nRows; ++row) {
            for (int col = 0; col < nCols; ++col) {
                hashCode = 31 * hashCode + Double.hashCode(this.getValue(row, col));
            }
        }
        return hashCode;
    }

    /**
     * Two matrices are equal when they have the same dimensions and every element compares
     * equal under {@link Double#compare(double, double)}. This comparison is consistent with
     * {@link #hashCode()}.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof Matrix)) {
            return false;
        }
        Matrix otherMatrix = (Matrix) obj;
        int rowSize = this.getRowDimension();
        int colSize = this.getColDimension();
        if (otherMatrix.getRowDimension() != rowSize || otherMatrix.getColDimension() != colSize) {
            return false;
        }
        for (int row = 0; row < rowSize; ++row) {
            for (int col = 0; col < colSize; ++col) {
                if (Double.compare(this.getValue(row, col), otherMatrix.getValue(row, col)) != 0) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Checks that this matrix's column count equals {@code bRow}, the row count of the
     * right-hand operand of a matrix product.
     */
    protected void checkMatrixRowAndColForMultiply(int bRow) {
        try {
            checkMatrixColSize(bRow);
        } catch (BizRunTimeException e) {
            throw new BizRunTimeException("乘法运算,列与行不相等");
        }
    }

    /**
     * Checks that {@code matrix} has exactly the same dimensions as this matrix.
     */
    protected void checkMatrixRowAndColEquase(Matrix matrix) {
        checkMatrixColSize(matrix.getColDimension());
        checkMatrixRowSize(matrix.getRowDimension());
    }

    private void checkMatrixRowSize(int row) {
        if (this.getRowDimension() != row) {
            throw new BizRunTimeException("数组运算行不相等");
        }
    }

    private void checkMatrixColSize(int col) {
        if (this.getColDimension() != col) {
            throw new BizRunTimeException("数组运算列不相等");
        }
    }

    /**
     * Validates that ({@code row}, {@code col}) is a valid zero-based element index.
     *
     * @throws BizRunTimeException if either index is out of range
     */
    protected void checkMatrixIndex(int row, int col) {
        checkColIndex(col);
        checkRowIndex(row);
    }

    private void checkRowIndex(int row) {
        // Valid indices are 0 .. rowDimension - 1.
        if (row < 0 || row >= this.getRowDimension()) {
            throw new BizRunTimeException("行下标越界");
        }
    }

    private void checkColIndex(int col) {
        // Valid indices are 0 .. colDimension - 1.
        if (col < 0 || col >= this.getColDimension()) {
            throw new BizRunTimeException("列下标越界");
        }
    }

    @Override
    public String toString() {
        return baseToString(false);
    }

    @Override
    public String toStringByStand() {
        return baseToString(true);
    }

    /**
     * Renders the matrix as {@code [{a,b},{c,d}]}.
     *
     * @param standFlag when {@code true}, rows are separated by {@code ",\n"} (one row per line)
     *                  instead of {@code ","}
     * @return the rendering, or an empty string when the matrix has no rows
     */
    protected String baseToString(boolean standFlag) {
        int rowSize = this.getRowDimension();
        if (rowSize < 1) {
            return "";
        }
        String rowDelimiter = standFlag ? ",\n" : ",";
        return this.stream()
                .map(doubleStream -> doubleStream.boxed().map(String::valueOf)
                        .collect(Collectors.joining(",", "{", "}"))
                ).collect(Collectors.joining(rowDelimiter, "[", "]"));
    }
}
矩阵实现类代码
/**
 * {@link Matrix} backed by a {@code double[][]} array indexed as {@code data[row][col]}.
 * QQ group: 528344775
 *
 * @author chengchaochao
 * @create 2018-01-04 12:44
 **/
public class Array2DMatrix extends AbstractMatrix {
    // Row-major storage; rows are expected to all have the same length.
    private double[][] data;

    /**
     * Creates a matrix holding a copy of {@code data}.
     * The dimensions are taken from {@code data.length} and {@code data[0].length}.
     */
    public Array2DMatrix(double[][] data) {
        this(data.length, data[0].length);
        this.copyDataIn(data);
    }

    /**
     * Creates a single-column or single-row matrix from a vector.
     *
     * @param data the vector elements
     * @param flag {@code true} for a column matrix ({@code data.length x 1}),
     *             {@code false} for a row matrix ({@code 1 x data.length})
     */
    public Array2DMatrix(double[] data, boolean flag) {
        if (flag) {
            int nRows = data.length;
            this.data = new double[nRows][1];
            for (int row = 0; row < nRows; ++row) {
                this.data[row][0] = data[row];
            }
        } else {
            int nCols = data.length;
            this.data = new double[1][nCols];
            for (int col = 0; col < nCols; ++col) {
                this.data[0][col] = data[col];
            }
        }
    }

    /**
     * Creates a single-row matrix ({@code 1 x data.length}) from a vector.
     */
    public Array2DMatrix(double[] data) {
        this(data, false);
    }

    /**
     * Creates a zero-filled square matrix of size {@code rowDimension x rowDimension}.
     */
    public Array2DMatrix(int rowDimension) {
        this(rowDimension, rowDimension);
    }

    /**
     * Creates a zero-filled matrix.
     *
     * @throws BizRunTimeException if either dimension is smaller than 1
     */
    public Array2DMatrix(int rowDimension, int columnDimension) {
        super(rowDimension, columnDimension);
        this.data = new double[rowDimension][columnDimension];
    }

    /**
     * Creates a matrix with every element set to {@code defaultValue}.
     *
     * @throws BizRunTimeException if either dimension is smaller than 1
     */
    public Array2DMatrix(int rowDimension, int columnDimension, double defaultValue) {
        this(rowDimension, columnDimension);
        // Fill directly rather than calling the overridable add(double) from a constructor.
        for (double[] row : this.data) {
            java.util.Arrays.fill(row, defaultValue);
        }
    }

    @Override
    public int getRowDimension() {
        return this.data == null ? 0 : this.data.length;
    }

    /**
     * Returns the length of the first row, or 0 when there is no data or no rows.
     */
    @Override
    public int getColDimension() {
        // Guard data.length > 0: a vector constructor given an empty array yields zero rows.
        return this.data != null && this.data.length > 0 && this.data[0] != null
                ? this.data[0].length
                : 0;
    }

    /**
     * Returns a copy of the elements; mutating the result does not affect this matrix.
     */
    @Override
    public double[][] getData() {
        return this.copyDataOut();
    }

    @Override
    public Matrix createMatrix(int rowSize, int colSize) {
        return new Array2DMatrix(rowSize, colSize);
    }

    @Override
    public Matrix copy() {
        return new Array2DMatrix(this.copyDataOut());
    }

    /**
     * Sets one element.
     *
     * @return {@code value}
     * @throws BizRunTimeException if the index is out of range
     */
    @Override
    public double setValue(int rowIndex, int colIndex, double value) {
        checkMatrixIndex(rowIndex, colIndex);
        return this.data[rowIndex][colIndex] = value;
    }

    /**
     * Returns one element.
     *
     * @throws BizRunTimeException if the index is out of range
     */
    @Override
    public double getValue(int rowIndex, int colIndex) {
        checkMatrixIndex(rowIndex, colIndex);
        return this.data[rowIndex][colIndex];
    }

    /**
     * Adds {@code value} to the element at the given index, in place.
     */
    public void addToValue(int rowIndex, int colIndex, double value) {
        checkMatrixIndex(rowIndex, colIndex);
        this.data[rowIndex][colIndex] += value;
    }

    /**
     * Subtracts {@code value} from the element at the given index, in place.
     */
    public void subtractToValue(int rowIndex, int colIndex, double value) {
        checkMatrixIndex(rowIndex, colIndex);
        this.data[rowIndex][colIndex] -= value;
    }

    /**
     * Multiplies the element at the given index by {@code value}, in place.
     */
    public void multiplyToValue(int rowIndex, int colIndex, double value) {
        checkMatrixIndex(rowIndex, colIndex);
        this.data[rowIndex][colIndex] *= value;
    }

    // Returns a row-by-row copy of the backing array.
    private double[][] copyDataOut() {
        int rowSize = this.getRowDimension();
        double[][] out = new double[rowSize][this.getColDimension()];
        for (int i = 0; i < rowSize; ++i) {
            System.arraycopy(this.data[i], 0, out[i], 0, this.data[i].length);
        }
        return out;
    }

    /**
     * Replaces the backing array with {@code data} itself (no copy is made).
     */
    @Override
    public void setData(double[][] data) {
        this.data = data;
    }

    // Copies the first getColDimension() elements of each row of `data` into the already
    // allocated backing array. A row shorter than that still raises an
    // ArrayIndexOutOfBoundsException, and extra elements in longer rows are ignored.
    private void copyDataIn(double[][] data) {
        int nCols = this.getColDimension();
        for (int row = 0; row < this.data.length; row++) {
            System.arraycopy(data[row], 0, this.data[row], 0, nCols);
        }
    }

    @Override
    public Matrix createMatrix(double[][] data) {
        return new Array2DMatrix(data);
    }
}
矩阵乘以向量
例子:计算 $\begin{bmatrix}2&3&4\\5&8&2\end{bmatrix}\begin{bmatrix}2\\1\\6\end{bmatrix}$
@Test
public void testMatrixMultiplyVector(){
    // Matrix A (2 x 3):
    //   [{2.0,3.0,4.0},
    //    {5.0,8.0,2.0}]
    double[][] valuesOfA = {
            {2, 3, 4},
            {5, 8, 2}
    };
    Matrix matrixA = Matrix.createArray2DMatrix(2, 3);
    for (int r = 0; r < valuesOfA.length; r++) {
        for (int c = 0; c < valuesOfA[r].length; c++) {
            matrixA.setValue(r, c, valuesOfA[r][c]);
        }
    }
    System.out.println(matrixA.toStringByStand());

    // Vector b = (2, 1, 6) as a 3 x 1 matrix; `true` selects the single-column layout:
    //   [{2.0},
    //    {1.0},
    //    {6.0}]
    double[] vectorB = {2, 1, 6};
    Matrix matrixB = new Array2DMatrix(vectorB, true);
    System.out.println(matrixB.toStringByStand());

    // Expected product (2 x 1):
    //   [{31.0},
    //    {30.0}]
    Matrix product = matrixA.matrixMultiply(matrixB);
    System.out.println(product.toStringByStand());
}
结果为一个两行一列的矩阵:$\begin{bmatrix}31\\30\end{bmatrix}$。
这里先将向量转换为只有一列的矩阵再相乘,因此实际上就是矩阵与矩阵的乘积;矩阵与矩阵相乘就不再单独举例。
矩阵转置
矩阵转置例子
@Test
public void testTransposition(){
    // B: 3 x 3 square matrix
    double[][] valuesOfB = {
            {1, 2, 3},
            {4, 5, 6},
            {1.1D, 1.1D, 1.1D}
    };
    Matrix matrixB = Matrix.createArray2DMatrix(3, 3);
    for (int r = 0; r < 3; r++) {
        for (int c = 0; c < 3; c++) {
            matrixB.setValue(r, c, valuesOfB[r][c]);
        }
    }
    System.out.println("矩阵B转置前====" + matrixB.toStringByStand());
    System.out.println("矩阵B转置后====" + matrixB.transposition().toStringByStand());

    // A: 2 x 3 matrix (more columns than rows)
    double[][] valuesOfA = {
            {1, 2, 3},
            {4, 5, 6}
    };
    Matrix matrixA = Matrix.createArray2DMatrix(2, 3);
    for (int r = 0; r < 2; r++) {
        for (int c = 0; c < 3; c++) {
            matrixA.setValue(r, c, valuesOfA[r][c]);
        }
    }
    System.out.println("矩阵A转置前====" + matrixA.toStringByStand());
    System.out.println("矩阵A转置后====" + matrixA.transposition().toStringByStand());

    // C: 3 x 2 matrix (more rows than columns)
    double[][] valuesOfC = {
            {1, 2},
            {4, 5},
            {1.1D, 1.1D}
    };
    Matrix matrixC = Matrix.createArray2DMatrix(3, 2);
    for (int r = 0; r < 3; r++) {
        for (int c = 0; c < 2; c++) {
            matrixC.setValue(r, c, valuesOfC[r][c]);
        }
    }
    System.out.println("矩阵C转置前====" + matrixC.toStringByStand());
    System.out.println("矩阵C转置后====" + matrixC.transposition().toStringByStand());

    // D: column vector (3 x 1)
    double[] vectorD = {1, 2, 3};
    Matrix matrixD = new Array2DMatrix(vectorD, true);
    System.out.println("矩阵D转置前====" + matrixD.toStringByStand());
    System.out.println("矩阵D转置后====" + matrixD.transposition().toStringByStand());

    // E: row vector (1 x 3)
    double[] vectorE = {1, 2, 3};
    Matrix matrixE = new Array2DMatrix(vectorE);
    System.out.println("矩阵E转置前====" + matrixE.toStringByStand());
    System.out.println("矩阵E转置后====" + matrixE.transposition().toStringByStand());
}
上面代码的执行结果:
矩阵B转置前====[{1.0,2.0,3.0},
{4.0,5.0,6.0},
{1.1,1.1,1.1}]
矩阵B转置后====[{1.0,4.0,1.1},
{2.0,5.0,1.1},
{3.0,6.0,1.1}]
矩阵A转置前====[{1.0,2.0,3.0},
{4.0,5.0,6.0}]
矩阵A转置后====[{1.0,4.0},
{2.0,5.0},
{3.0,6.0}]
矩阵C转置前====[{1.0,2.0},
{4.0,5.0},
{1.1,1.1}]
矩阵C转置后====[{1.0,4.0,1.1},
{2.0,5.0,1.1}]
矩阵D转置前====[{1.0},
{2.0},
{3.0}]
矩阵D转置后====[{1.0,2.0,3.0}]
矩阵E转置前====[{1.0,2.0,3.0}]
矩阵E转置后====[{1.0},
{2.0},
{3.0}]
第一次写博客,如有错误的地方,希望大家能够指正,谢谢~