CMU 15-445 (Fall 2021) Project 0: C++ Primer
In this project, you will implement three classes: Matrix, RowMatrix, and RowMatrixOperations. These matrices are simple two-dimensional matrices that must support addition, matrix multiplication, and a simplified General Matrix Multiply (GEMM) operation.
You will only need to modify a single file, p0_starter.h, which you can find in the BusTub repository at src/include/primer/p0_starter.h.
In this header file, we define the three classes that you must implement. The Matrix abstract class defines the common functions for the derived class RowMatrix. The RowMatrixOperations class uses RowMatrix objects to achieve the operations mentioned in the overview above. The function prototypes and member variables are specified in the file. The project requires you to fill in the implementations of all the constructors, destructors, and member functions. Do not add any additional function prototypes or member variables. Your implementation should consist solely of implementing the functions that we have defined for you.
For this lab, you only need to work in src/include/primer/p0_starter.h and follow the hints in its code comments step by step.
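Matrix
- Constructor: store the row and column counts, then allocate the element buffer (rows * cols elements) and initialize it.
- Destructor: release that buffer with delete[].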
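A minimal sketch of the base class, assuming the member names rows_, cols_ and linear_ (linear_ being the flat buffer that RowMatrix builds on below). This is a simplified stand-in written for illustration, not the actual p0_starter.h; the deleted copy operations are an addition here to prevent a double delete[].

```cpp
// Simplified stand-in for the abstract Matrix class (not the real header).
template <typename T>
class Matrix {
 public:
  // The class owns a raw buffer, so a shallow copy would free it twice.
  Matrix(const Matrix &) = delete;
  Matrix &operator=(const Matrix &) = delete;

  // Release the flat element buffer that was allocated with new[].
  virtual ~Matrix() { delete[] linear_; }

  // Common interface implemented by derived classes such as RowMatrix.
  virtual int GetRowCount() const = 0;
  virtual int GetColumnCount() const = 0;
  virtual T GetElement(int i, int j) const = 0;
  virtual void SetElement(int i, int j, T val) = 0;

 protected:
  // Record the shape, then allocate rows * cols elements; the trailing "()"
  // value-initializes them, so numeric T starts out as zero.
  Matrix(int rows, int cols) : rows_(rows), cols_(cols), linear_(new T[rows * cols]()) {}

  int rows_;
  int cols_;
  T *linear_;  // flattened, row-major element storage
};
```

A concrete subclass such as RowMatrix calls the protected constructor from its own initializer list, so the buffer already exists by the time the subclass constructor body runs.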
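RowMatrix
- In the constructor, allocate an array of row pointers for data_, then initialize each pointer (note that the rows must point into Matrix's linear_ buffer rather than into newly allocated storage).
- The MatImport function sets the matrix's elements from the array arr.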
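The RowMatrix layout can be sketched as follows. The base class is repeated in condensed form so the snippet compiles on its own. The MatImport(const T *arr) signature, the bounds check, and the use of std::out_of_range (rather than whatever exception type the starter code expects) are assumptions for illustration.

```cpp
#include <stdexcept>

// Condensed stand-in for the Matrix base class (assumption, not the real header).
template <typename T>
class Matrix {
 public:
  Matrix(const Matrix &) = delete;
  Matrix &operator=(const Matrix &) = delete;
  virtual ~Matrix() { delete[] linear_; }

 protected:
  Matrix(int rows, int cols) : rows_(rows), cols_(cols), linear_(new T[rows * cols]()) {}
  int rows_;
  int cols_;
  T *linear_;
};

template <typename T>
class RowMatrix : public Matrix<T> {
 public:
  // Allocate only the row-pointer array; row i starts at offset i * cols
  // inside the base class's linear_ buffer, so no element storage is added.
  RowMatrix(int rows, int cols) : Matrix<T>(rows, cols), data_(new T *[rows]) {
    for (int i = 0; i < rows; i++) {
      data_[i] = this->linear_ + i * cols;
    }
  }

  // RowMatrix owns only the pointer array; ~Matrix frees linear_.
  ~RowMatrix() override { delete[] data_; }

  int GetRowCount() const { return this->rows_; }
  int GetColumnCount() const { return this->cols_; }

  T GetElement(int i, int j) const {
    CheckBounds(i, j);
    return data_[i][j];
  }

  void SetElement(int i, int j, T val) {
    CheckBounds(i, j);
    data_[i][j] = val;
  }

  // Hypothetical signature: copy rows * cols values from arr in row-major order.
  void MatImport(const T *arr) {
    for (int k = 0; k < this->rows_ * this->cols_; k++) {
      this->linear_[k] = arr[k];
    }
  }

 private:
  // Reject indices outside [0, rows) x [0, cols).
  void CheckBounds(int i, int j) const {
    if (i < 0 || i >= this->rows_ || j < 0 || j >= this->cols_) {
      throw std::out_of_range("matrix index out of range");
    }
  }

  T **data_;  // data_[i] points at the first element of row i in linear_
};
```

Because data_[i] is just linear_ + i * cols, data_[i][j] and linear_[i * cols + j] name the same element, which is why MatImport can copy straight into linear_.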
RowMatrixOperations
- Implement matrix addition, multiplication, and GEMM.
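Add and Multiply can be sketched as below, assuming (as the GEMM code in this post does) that each returns a std::unique_ptr<RowMatrix<T>> and returns nullptr when the shapes do not match. To keep the snippet short and self-contained, RowMatrix here is a hypothetical stand-in backed by a std::vector rather than linear_/data_; std::make_unique requires C++14 or later.

```cpp
#include <memory>
#include <vector>

// Hypothetical stand-in exposing the accessors the operations rely on.
template <typename T>
class RowMatrix {
 public:
  RowMatrix(int rows, int cols) : rows_(rows), cols_(cols), elems_(rows * cols, T()) {}
  int GetRowCount() const { return rows_; }
  int GetColumnCount() const { return cols_; }
  T GetElement(int i, int j) const { return elems_[i * cols_ + j]; }
  void SetElement(int i, int j, T val) { elems_[i * cols_ + j] = val; }

 private:
  int rows_;
  int cols_;
  std::vector<T> elems_;
};

template <typename T>
class RowMatrixOperations {
 public:
  // Element-wise sum; the shapes must match exactly, otherwise nullptr.
  static std::unique_ptr<RowMatrix<T>> Add(const RowMatrix<T> *matrixA, const RowMatrix<T> *matrixB) {
    int rows = matrixA->GetRowCount();
    int cols = matrixA->GetColumnCount();
    if (rows != matrixB->GetRowCount() || cols != matrixB->GetColumnCount()) {
      return nullptr;
    }
    auto result = std::make_unique<RowMatrix<T>>(rows, cols);
    for (int i = 0; i < rows; i++) {
      for (int j = 0; j < cols; j++) {
        result->SetElement(i, j, matrixA->GetElement(i, j) + matrixB->GetElement(i, j));
      }
    }
    return result;
  }

  // (m x k) * (k x n) -> (m x n); the inner dimensions must agree.
  static std::unique_ptr<RowMatrix<T>> Multiply(const RowMatrix<T> *matrixA, const RowMatrix<T> *matrixB) {
    int m = matrixA->GetRowCount();
    int k = matrixA->GetColumnCount();
    int n = matrixB->GetColumnCount();
    if (k != matrixB->GetRowCount()) {
      return nullptr;
    }
    auto result = std::make_unique<RowMatrix<T>>(m, n);
    for (int i = 0; i < m; i++) {
      for (int j = 0; j < n; j++) {
        // Dot product of row i of A with column j of B.
        T sum = T();
        for (int p = 0; p < k; p++) {
          sum += matrixA->GetElement(i, p) * matrixB->GetElement(p, j);
        }
        result->SetElement(i, j, sum);
      }
    }
    return result;
  }
};
```

For example, multiplying the 2x3 matrix [[1, 2, 3], [4, 5, 6]] by the 3x2 matrix [[7, 8], [9, 10], [11, 12]] yields [[58, 64], [139, 154]], while Add on those same two matrices returns nullptr because their shapes differ.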
Here is an example GEMM implementation:
static std::unique_ptr<RowMatrix<T>> GEMM(const RowMatrix<T> *matrixA, const RowMatrix<T> *matrixB,
                                          const RowMatrix<T> *matrixC) {
  // Compute A * B first; Multiply returns nullptr if A's columns != B's rows.
  auto gemm_result = Multiply(matrixA, matrixB);
  if (gemm_result == nullptr) {
    return nullptr;
  }
  // The product must have exactly the same shape as C before C can be added.
  int rows = matrixC->GetRowCount();
  int cols = matrixC->GetColumnCount();
  if (gemm_result->GetRowCount() != rows || gemm_result->GetColumnCount() != cols) {
    return nullptr;
  }
  // Add C element by element in place, giving A * B + C.
  for (int i = 0; i < rows; i++) {
    for (int j = 0; j < cols; j++) {
      gemm_result->SetElement(i, j, gemm_result->GetElement(i, j) + matrixC->GetElement(i, j));
    }
  }
  return gemm_result;
}
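For reference, this is the quantity the GEMM function computes, under the shape requirements its two checks enforce (A is m x k, B is k x n, C is m x n), followed by a small worked example:

```latex
\mathrm{GEMM}(A, B, C)_{ij} = \sum_{p=1}^{k} A_{ip} B_{pj} + C_{ij}, \qquad 1 \le i \le m,\ 1 \le j \le n

\begin{pmatrix} 1 & 2 \\ 3 & 4 \end{pmatrix}
\begin{pmatrix} 5 & 6 \\ 7 & 8 \end{pmatrix}
+ \begin{pmatrix} 1 & 1 \\ 1 & 1 \end{pmatrix}
= \begin{pmatrix} 19 & 22 \\ 43 & 50 \end{pmatrix}
+ \begin{pmatrix} 1 & 1 \\ 1 & 1 \end{pmatrix}
= \begin{pmatrix} 20 & 23 \\ 44 & 51 \end{pmatrix}
```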
Testing
First, remove the DISABLED_ prefix from the test cases in test/primer/starter_test.cpp, then build and run the test:
$ cd build
$ make starter_test
$ ./test/starter_test
Running main() from gmock_main.cc
[==========] Running 5 tests from 1 test suite.
[----------] Global test environment set-up.
[----------] 5 tests from StarterTest
[ RUN ] StarterTest.SampleTest
[ OK ] StarterTest.SampleTest (0 ms)
[ RUN ] StarterTest.InitializationTest
Exception Type :: Out of Range
Message :: OUT_OF_RANGE
Exception Type :: Out of Range
Message :: OUT_OF_RANGE
[ OK ] StarterTest.InitializationTest (0 ms)
[ RUN ] StarterTest.ElementAccessTest
Exception Type :: Out of Range
Message :: OUT_OF_RANGE
Exception Type :: Out of Range
Message :: OUT_OF_RANGE
Exception Type :: Out of Range
Message :: OUT_OF_RANGE
Exception Type :: Out of Range
Message :: OUT_OF_RANGE
Exception Type :: Out of Range
Message :: OUT_OF_RANGE
Exception Type :: Out of Range
Message :: OUT_OF_RANGE
Exception Type :: Out of Range
Message :: OUT_OF_RANGE
Exception Type :: Out of Range
Message :: OUT_OF_RANGE
[ OK ] StarterTest.ElementAccessTest (0 ms)
[ RUN ] StarterTest.AdditionTest
[ OK ] StarterTest.AdditionTest (0 ms)
[ RUN ] StarterTest.MultiplicationTest
[ OK ] StarterTest.MultiplicationTest (0 ms)
[----------] 5 tests from StarterTest (1 ms total)
[----------] Global test environment tear-down
[==========] 5 tests from 1 test suite ran. (1 ms total)
[ PASSED ] 5 tests.
Success: all 5 tests pass.