Apple's Deep Learning Framework MLX: Introduction and Build Example
MLX Overview
On December 6, 2023 (Beijing time), Apple machine learning research open-sourced MLX on GitHub.
The project is hosted at https://github.com/ml-explore/mlx.
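MLX is a deep learning framework that Apple has optimized specifically for Apple Silicon, and it reportedly simplifies how researchers design and deploy models on Mac, iPad, and iPhone.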
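Some of MLX's main features:
- Familiar APIs: MLX has a Python API that closely follows NumPy, and a fully featured C++ API that closely mirrors the Python API. It also has higher-level packages such as mlx.nn and mlx.optimizers, whose APIs closely follow PyTorch, to simplify building more complex models.
- Composable function transformations: MLX supports composable function transformations for automatic differentiation, automatic vectorization, and computation graph optimization.
- Lazy computation: computations in MLX are lazy; arrays are only materialized when they are needed.
- Dynamic graph construction: computation graphs in MLX are built dynamically, so changing the shapes of function arguments does not trigger slow recompilation, and debugging is simple and intuitive.
- Multi-device: operations can run on any of the supported devices (currently the CPU and the GPU).
- Unified memory: a notable difference between MLX and other frameworks is the unified memory model. Arrays in MLX live in shared memory, so operations on MLX arrays can run on any supported device type without transferring data.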
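MLX is designed by machine learning researchers at Apple machine learning research, for machine learning researchers.
The framework is intended to be user-friendly while still being efficient for training and deploying models.
The design of the framework itself is also conceptually simple.
The aim is to make it easy for researchers to extend and improve MLX, with the goal of quickly exploring new ideas.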
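MLX should not be dismissed as reinventing the wheel: since Apple now ships its own GPUs, it naturally wants to tap their computing potential.
Apple's chips also differ from earlier mainstream architectures in that they use a unified memory model.
Compared with frameworks such as TensorFlow, where data placement between CPU and GPU memory has to be managed explicitly, unified memory greatly simplifies the programming model.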
Building MLX: A Walkthrough
Because building MLX requires Xcode, it is not suited to building inside a container; it has to be built directly on macOS.
To keep the various source trees downloaded on this machine organized, the source is fetched with ghq.
Download the source code:
ghq get --shallow https://github.com/ml-explore/mlx
Browse the project structure in Finder:
open -R "$(ghq list --full-path https://github.com/ml-explore/mlx)"
Browse the project structure on the command line:
cd "$(ghq list --full-path https://github.com/ml-explore/mlx)" && ls -al
The listing looks like this:
total 168
drwxr-xr-x 26 huzhenghui staff 832 12 16 11:41 .
drwxr-xr-x 4 huzhenghui staff 128 12 16 08:09 ..
-rw-r--r--@ 1 huzhenghui staff 6148 12 16 08:25 .DS_Store
drwxr-xr-x 3 huzhenghui staff 96 12 16 08:00 .circleci
-rw-r--r-- 1 huzhenghui staff 2552 12 16 08:00 .clang-format
drwxr-xr-x 13 huzhenghui staff 416 12 16 08:00 .git
drwxr-xr-x 4 huzhenghui staff 128 12 16 08:00 .github
-rw-r--r-- 1 huzhenghui staff 733 12 16 08:00 .gitignore
-rw-r--r-- 1 huzhenghui staff 433 12 16 08:00 .pre-commit-config.yaml
-rw-r--r-- 1 huzhenghui staff 12320 12 16 08:00 ACKNOWLEDGMENTS.md
-rw-r--r-- 1 huzhenghui staff 6533 12 16 08:00 CMakeLists.txt
-rw-r--r-- 1 huzhenghui staff 5544 12 16 08:00 CODE_OF_CONDUCT.md
-rw-r--r-- 1 huzhenghui staff 1292 12 16 08:00 CONTRIBUTING.md
-rw-r--r-- 1 huzhenghui staff 1066 12 16 08:00 LICENSE
-rw-r--r-- 1 huzhenghui staff 69 12 16 08:00 MANIFEST.in
-rw-r--r-- 1 huzhenghui staff 3523 12 16 08:00 README.md
drwxr-xr-x 5 huzhenghui staff 160 12 16 08:00 benchmarks
drwxr-xr-x 3 huzhenghui staff 96 12 16 08:00 cmake
drwxr-xr-x 9 huzhenghui staff 288 12 16 08:00 docs
drwxr-xr-x 5 huzhenghui staff 160 12 16 08:00 examples
drwxr-xr-x 35 huzhenghui staff 1120 12 16 08:00 mlx
-rw-r--r-- 1 huzhenghui staff 1364 12 16 08:00 mlx.pc.in
-rw-r--r-- 1 huzhenghui staff 118 12 16 08:00 pyproject.toml
drwxr-xr-x 6 huzhenghui staff 192 12 16 08:00 python
-rw-r--r-- 1 huzhenghui staff 6887 12 16 08:00 setup.py
drwxr-xr-x 21 huzhenghui staff 672 12 16 08:00 tests
Create a ./build folder for the build:
cd "$(ghq list --full-path https://github.com/ml-explore/mlx)" && mkdir -p build
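Run CMake. To build the examples, set MLX_BUILD_EXAMPLES first (it is a CMake option, so passing -DMLX_BUILD_EXAMPLES=ON to cmake is the conventional alternative to the environment variable used here):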
export MLX_BUILD_EXAMPLES=ON
cd "$(ghq list --full-path https://github.com/ml-explore/mlx)"/build && cmake ..
CMake prints the following:
-- The CXX compiler identification is AppleClang 15.0.0.15000100
-- Detecting CXX compiler ABI info
-- Detecting CXX compiler ABI info - done
-- Check for working CXX compiler: /Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/bin/c++ - skipped
-- Detecting CXX compile features
-- Detecting CXX compile features - done
-- Building MLX for arm64 processor on Darwin
-- Building METAL sources
-- Building with SDK for macOS version 14.2
-- Accelerate found /Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX14.2.sdk/System/Library/Frameworks/Accelerate.framework
CMake Deprecation Warning at build/_deps/doctest-src/CMakeLists.txt:1 (cmake_minimum_required):
Compatibility with CMake < 3.5 will be removed from a future version of
CMake.
Update the VERSION argument <min> value or use a ...<max> suffix to tell
CMake that the project does not need compatibility with older versions.
-- Configuring done (24.2s)
-- Generating done (0.1s)
-- Build files have been written to: /Users/huzhenghui/ghq/github.com/ml-explore/mlx/build
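Run make to compile. To speed things up, use --jobs for a parallel build (with no number given, make places no limit on the number of simultaneous jobs):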
cd "$(ghq list --full-path https://github.com/ml-explore/mlx)"/build && make --jobs
The build output:
[ 1%] Building unary.air
[ 3%] Building reduce.air
[ 4%] Building indexing.air
[ 5%] Building arange.air
[ 6%] Building sort.air
[ 6%] Building softmax.air
[ 9%] Building copy.air
[ 9%] Building random.air
[ 10%] Building scan.air
[ 11%] Building gemv.air
[ 12%] Building conv.air
[ 13%] Building gemm.air
[ 15%] Building binary.air
[ 16%] Building arg_reduce.air
[ 17%] Building mlx.metallib
[ 17%] Built target mlx-metallib
[ 18%] Building CXX object CMakeFiles/mlx.dir/mlx/device.cpp.o
[ 19%] Building CXX object CMakeFiles/mlx.dir/mlx/allocator.cpp.o
[ 22%] Building CXX object CMakeFiles/mlx.dir/mlx/scheduler.cpp.o
[ 22%] Building CXX object CMakeFiles/mlx.dir/mlx/graph_utils.cpp.o
[ 23%] Building CXX object CMakeFiles/mlx.dir/mlx/transforms.cpp.o
[ 25%] Building CXX object CMakeFiles/mlx.dir/mlx/primitives.cpp.o
[ 25%] Building CXX object CMakeFiles/mlx.dir/mlx/fft.cpp.o
[ 26%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/common/reduce.cpp.o
[ 27%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/common/fft.cpp.o
[ 29%] Building CXX object CMakeFiles/mlx.dir/mlx/load.cpp.o
[ 30%] Building CXX object CMakeFiles/mlx.dir/mlx/array.cpp.o
[ 32%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/common/conv.cpp.o
[ 32%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/common/indexing.cpp.o
[ 33%] Building CXX object CMakeFiles/mlx.dir/mlx/dtype.cpp.o
[ 36%] Building CXX object CMakeFiles/mlx.dir/mlx/ops.cpp.o
[ 36%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/common/binary.cpp.o
[ 37%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/common/scan.cpp.o
[ 38%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/common/erf.cpp.o
[ 39%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/metal/conv.cpp.o
[ 41%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/accelerate/primitives.cpp.o
[ 41%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/metal/copy.cpp.o
[ 43%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/metal/metal.cpp.o
[ 44%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/accelerate/reduce.cpp.o
[ 45%] Building CXX object CMakeFiles/mlx.dir/mlx/utils.cpp.o
[ 46%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/common/softmax.cpp.o
[ 47%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/common/copy.cpp.o
[ 48%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/common/arg_reduce.cpp.o
[ 50%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/common/primitives.cpp.o
[ 51%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/metal/primitives.cpp.o
[ 52%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/common/load.cpp.o
[ 53%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/accelerate/conv.cpp.o
[ 55%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/accelerate/matmul.cpp.o
[ 55%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/metal/device.cpp.o
[ 58%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/metal/allocator.cpp.o
[ 58%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/metal/fft.cpp.o
[ 59%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/common/threefry.cpp.o
[ 60%] Building CXX object CMakeFiles/mlx.dir/mlx/random.cpp.o
[ 61%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/common/sort.cpp.o
[ 63%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/accelerate/softmax.cpp.o
[ 63%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/metal/scan.cpp.o
[ 65%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/metal/matmul.cpp.o
[ 66%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/metal/indexing.cpp.o
[ 67%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/metal/reduce.cpp.o
[ 68%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/metal/sort.cpp.o
[ 69%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/metal/softmax.cpp.o
[ 70%] Linking CXX static library libmlx.a
[ 70%] Built target mlx
[ 74%] Building CXX object examples/cpp/CMakeFiles/linear_regression.dir/linear_regression.cpp.o
[ 74%] Building CXX object examples/cpp/CMakeFiles/tutorial.dir/tutorial.cpp.o
[ 74%] Building CXX object tests/CMakeFiles/tests.dir/tests.cpp.o
[ 76%] Building CXX object tests/CMakeFiles/tests.dir/array_tests.cpp.o
[ 76%] Building CXX object tests/CMakeFiles/tests.dir/allocator_tests.cpp.o
[ 77%] Building CXX object examples/cpp/CMakeFiles/logistic_regression.dir/logistic_regression.cpp.o
[ 79%] Building CXX object tests/CMakeFiles/tests.dir/autograd_tests.cpp.o
[ 80%] Building CXX object tests/CMakeFiles/tests.dir/arg_reduce_tests.cpp.o
[ 81%] Building CXX object tests/CMakeFiles/tests.dir/blas_tests.cpp.o
[ 82%] Building CXX object tests/CMakeFiles/tests.dir/eval_tests.cpp.o
[ 83%] Building CXX object tests/CMakeFiles/tests.dir/graph_optimize_tests.cpp.o
[ 84%] Building CXX object tests/CMakeFiles/tests.dir/creations_tests.cpp.o
[ 86%] Building CXX object tests/CMakeFiles/tests.dir/device_tests.cpp.o
[ 87%] Building CXX object tests/CMakeFiles/tests.dir/fft_tests.cpp.o
[ 88%] Building CXX object tests/CMakeFiles/tests.dir/load_tests.cpp.o
[ 89%] Building CXX object tests/CMakeFiles/tests.dir/ops_tests.cpp.o
[ 90%] Building CXX object tests/CMakeFiles/tests.dir/metal_tests.cpp.o
[ 91%] Building CXX object tests/CMakeFiles/tests.dir/vmap_tests.cpp.o
[ 93%] Building CXX object tests/CMakeFiles/tests.dir/scheduler_tests.cpp.o
[ 94%] Building CXX object tests/CMakeFiles/tests.dir/utils_tests.cpp.o
[ 95%] Building CXX object tests/CMakeFiles/tests.dir/random_tests.cpp.o
[ 96%] Linking CXX executable linear_regression
[ 97%] Linking CXX executable tutorial
[ 98%] Linking CXX executable logistic_regression
[ 98%] Built target linear_regression
[ 98%] Built target logistic_regression
[ 98%] Built target tutorial
[100%] Linking CXX executable tests
[100%] Built target tests
Three examples were built:
[ 96%] Linking CXX executable linear_regression
[ 97%] Linking CXX executable tutorial
[ 98%] Linking CXX executable logistic_regression
[ 98%] Built target linear_regression
[ 98%] Built target logistic_regression
[ 98%] Built target tutorial
If these lines are missing, the examples were not enabled; check that MLX_BUILD_EXAMPLES is set correctly.
Run the tests:
cd "$(ghq list --full-path https://github.com/ml-explore/mlx)"/build && make test
The test run reports a failure:
Running tests...
Test project /Users/huzhenghui/ghq/github.com/ml-explore/mlx/build
Start 1: tests
1/1 Test #1: tests ............................***Failed 1.06 sec
0% tests passed, 1 tests failed out of 1
Total Test time (real) = 1.06 sec
The following tests FAILED:
1 - tests (Failed)
Errors while running CTest
Output from these tests are in: /Users/huzhenghui/ghq/github.com/ml-explore/mlx/build/Testing/Temporary/LastTest.log
Use "--rerun-failed --output-on-failure" to re-run the failed cases verbosely.
make: *** [test] Error 8
No need to worry yet: a few failing unit tests do not necessarily mean the build is broken. Check the log file to see what failed:
cat /Users/huzhenghui/ghq/github.com/ml-explore/mlx/build/Testing/Temporary/LastTest.log
The test log reads:
Start testing: Dec 16 11:49 CST
----------------------------------------------------------
1/1 Testing: tests
1/1 Test: tests
Command: "/Users/huzhenghui/ghq/github.com/ml-explore/mlx/build/tests/tests"
Directory: /Users/huzhenghui/ghq/github.com/ml-explore/mlx/build/tests
"tests" start time: Dec 16 11:49 CST
Output:
----------------------------------------------------------
[doctest] doctest version is "2.4.9"
[doctest] run with "--help" for options
===============================================================================
/Users/huzhenghui/ghq/github.com/ml-explore/mlx/tests/autograd_tests.cpp:221:
TEST CASE: test grad
/Users/huzhenghui/ghq/github.com/ml-explore/mlx/tests/autograd_tests.cpp:246: ERROR: CHECK_EQ( dfdx(x).item<float>(), std::exp(1.0f) ) is NOT correct!
values: CHECK_EQ( 2.71828, 2.71828 )
/Users/huzhenghui/ghq/github.com/ml-explore/mlx/tests/autograd_tests.cpp:248: ERROR: CHECK_EQ( d2fdx2(x).item<float>(), std::exp(1.0f) ) is NOT correct!
values: CHECK_EQ( 2.71828, 2.71828 )
/Users/huzhenghui/ghq/github.com/ml-explore/mlx/tests/autograd_tests.cpp:250: ERROR: CHECK_EQ( d3fdx3(x).item<float>(), std::exp(1.0f) ) is NOT correct!
values: CHECK_EQ( 2.71828, 2.71828 )
===============================================================================
/Users/huzhenghui/ghq/github.com/ml-explore/mlx/tests/autograd_tests.cpp:358:
TEST CASE: test op vjps
/Users/huzhenghui/ghq/github.com/ml-explore/mlx/tests/autograd_tests.cpp:402: ERROR: CHECK_EQ( out.second.item<float>(), 2.0f * std::exp(1.0f) ) is NOT correct!
values: CHECK_EQ( 5.43656, 5.43656 )
===============================================================================
/Users/huzhenghui/ghq/github.com/ml-explore/mlx/tests/ops_tests.cpp:467:
TEST CASE: test reduction ops
/Users/huzhenghui/ghq/github.com/ml-explore/mlx/tests/ops_tests.cpp:680: WARNING: WARN_EQ( logsumexp(x).item<float>(), -inf ) is NOT correct!
values: WARN_EQ( nan, -inf )
/Users/huzhenghui/ghq/github.com/ml-explore/mlx/tests/ops_tests.cpp:686: WARNING: WARN_EQ( logsumexp(x).item<float>(), inf ) is NOT correct!
values: WARN_EQ( nan, inf )
===============================================================================
/Users/huzhenghui/ghq/github.com/ml-explore/mlx/tests/ops_tests.cpp:741:
TEST CASE: test arithmetic unary ops
/Users/huzhenghui/ghq/github.com/ml-explore/mlx/tests/ops_tests.cpp:858: ERROR: CHECK( array_equal(exp(x), full({2, 2, 2}, std::exp(1.0f))).item<bool>() ) is NOT correct!
values: CHECK( false )
===============================================================================
/Users/huzhenghui/ghq/github.com/ml-explore/mlx/tests/ops_tests.cpp:1929:
TEST CASE: test power
/Users/huzhenghui/ghq/github.com/ml-explore/mlx/tests/ops_tests.cpp:1945: ERROR: CHECK_EQ( (x ^ 0.5).item<float>(), std::pow(2.0f, 0.5f) ) is NOT correct!
values: CHECK_EQ( 1.41421, 1.41421 )
===============================================================================
[doctest] test cases: 113 | 109 passed | 4 failed | 0 skipped
[doctest] assertions: 1962 | 1956 passed | 6 failed |
[doctest] Status: FAILURE!
<end of output>
Test time = 1.06 sec
----------------------------------------------------------
Test Failed.
"tests" end time: Dec 16 11:49 CST
"tests" time elapsed: 00:00:01
----------------------------------------------------------
End testing: Dec 16 11:49 CST
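The pass rate is high, and the failures look benign. All six failed assertions compare a result against a reference computed with std::exp or std::pow using exact equality, and in five of them both sides print the same value (for example 2.71828 vs 2.71828), so the two numbers differ only below the displayed precision. This points to small floating-point rounding differences rather than wrong results. The two logsumexp messages are WARN_EQ warnings and are not counted as failures. The summary: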
[doctest] test cases: 113 | 109 passed | 4 failed | 0 skipped
[doctest] assertions: 1962 | 1956 passed | 6 failed |
Install. Because this writes to system folders, sudo is required:
cd "$(ghq list --full-path https://github.com/ml-explore/mlx)"/build && sudo make install
The output is as follows:
[ 17%] Built target mlx-metallib
[ 70%] Built target mlx
[ 93%] Built target tests
[ 95%] Built target tutorial
[ 97%] Built target linear_regression
[100%] Built target logistic_regression
Install the project...
-- Install configuration: ""
Inspect the installed files:
ls -al /usr/local/lib
The MLX files (cmake, libmlx.a, mlx.metallib) are there:
drwxr-xr-x 5 root wheel 160 12 16 08:13 .
drwxr-xr-x 6 root wheel 192 12 16 08:13 ..
drwxr-xr-x 3 root wheel 96 12 16 08:13 cmake
-rw-r--r-- 1 root wheel 66139056 12 16 08:13 libmlx.a
-rw-r--r-- 1 root wheel 61726901 12 16 08:10 mlx.metallib
Run the linear_regression example:
cd "$(ghq list --full-path https://github.com/ml-explore/mlx)"/build && ./examples/cpp/linear_regression
Output:
Loss array(4.64954e-05, dtype=float32), |w - w*| = 0.00363933, Throughput 2685.41 (it/s).
Run the logistic_regression example:
cd "$(ghq list --full-path https://github.com/ml-explore/mlx)"/build && ./examples/cpp/logistic_regression
Output:
Loss array(0.0289869, dtype=float32), Accuracy, array(1, dtype=float32), Throughput 2251.92 (it/s).
Run the tutorial example:
cd "$(ghq list --full-path https://github.com/ml-explore/mlx)"/build && ./examples/cpp/tutorial
Output:
array([[1, 1],
[1, 1]], dtype=float32)