苹果深度学习框架`MLX`简介及编译示例

苹果深度学习框架MLX简介及编译示例

MLX简介

北京时间2023年12月6日,苹果机器学习研究中心 (Apple machine learning research) 在GitHub上开源MLX
项目地址为 https://github.com/ml-explore/mlx

MLX深度学习框架是苹果专门为Apple Silicon芯片优化,号称能够简化MaciPadiPhone平台研究人员设计和部署模型的过程。

MLX的一些主要功能包括:

  • 熟悉的 API
    MLX具有紧密类似NumPyPython API
    MLX还拥有功能齐全的C++ API,它与Python API非常相似。
    MLX具有像mlx.nnmlx.optimizers这样的更高级别的软件包,
    API紧密类似PyTorch,用于简化构建更复杂的模型。

  • 可组合函数转换
    MLX支持自动微分、自动矢量化和计算图优化等可组合函数的转换。

  • 惰性计算
    MLX中的计算是惰性计算。数组只是在需要时生成。

  • 动态图构建
    MLX中的计算图是动态构建的。更改函数参数的形状不会导致编译速度减慢,调试简单直观。

  • 多设备
    操作可以在任何支持的设备上运行 (目前是CPUGPU)。

  • 统一内存
    MLX和其他框架的显着区别是统一内存模型。MLX中的数组位于共享内存中。
    MLX数组上的操作可以在任何受支持的设备类型上执行,而不需要传输数据。

MLX是由苹果机器学习研究中心的机器学习研究人员为机器学习研究人员而设计的。
该框架旨在用户友好,但仍然高效训练和部署模型。
框架本身的设计也是概念上很简单。
目的是让研究人员能够轻松扩展和改进MLX,以快速探索新想法为目标。

不能简单的把MLX视为造轮子,苹果既然发布了GPU,就自然的挖掘GPU的运算潜力。

苹果的芯片架构与以往主流的芯片架构不同,是统一内存模型。
统一内存模型不同于TensorFlow等框架需要显式管理内存,大大简化了编程模型。

编译MLX示例

由于编译MLX需要使用XCode,因此不适合在容器中编译,需要直接在MacOS中编译。
为便于管理电脑上下载的各种源代码,使用ghq下载。

下载源代码

ghq get --shallow https://github.com/ml-explore/mlx

访达中浏览项目结构。

open -R "$(ghq list --full-path https://github.com/ml-explore/mlx)"

在命令行中浏览项目结构。

cd "$(ghq list --full-path https://github.com/ml-explore/mlx)" && ls -al

内容如下

total 168
drwxr-xr-x  26 huzhenghui  staff    832 12 16 11:41 .
drwxr-xr-x   4 huzhenghui  staff    128 12 16 08:09 ..
-rw-r--r--@  1 huzhenghui  staff   6148 12 16 08:25 .DS_Store
drwxr-xr-x   3 huzhenghui  staff     96 12 16 08:00 .circleci
-rw-r--r--   1 huzhenghui  staff   2552 12 16 08:00 .clang-format
drwxr-xr-x  13 huzhenghui  staff    416 12 16 08:00 .git
drwxr-xr-x   4 huzhenghui  staff    128 12 16 08:00 .github
-rw-r--r--   1 huzhenghui  staff    733 12 16 08:00 .gitignore
-rw-r--r--   1 huzhenghui  staff    433 12 16 08:00 .pre-commit-config.yaml
-rw-r--r--   1 huzhenghui  staff  12320 12 16 08:00 ACKNOWLEDGMENTS.md
-rw-r--r--   1 huzhenghui  staff   6533 12 16 08:00 CMakeLists.txt
-rw-r--r--   1 huzhenghui  staff   5544 12 16 08:00 CODE_OF_CONDUCT.md
-rw-r--r--   1 huzhenghui  staff   1292 12 16 08:00 CONTRIBUTING.md
-rw-r--r--   1 huzhenghui  staff   1066 12 16 08:00 LICENSE
-rw-r--r--   1 huzhenghui  staff     69 12 16 08:00 MANIFEST.in
-rw-r--r--   1 huzhenghui  staff   3523 12 16 08:00 README.md
drwxr-xr-x   5 huzhenghui  staff    160 12 16 08:00 benchmarks
drwxr-xr-x   3 huzhenghui  staff     96 12 16 08:00 cmake
drwxr-xr-x   9 huzhenghui  staff    288 12 16 08:00 docs
drwxr-xr-x   5 huzhenghui  staff    160 12 16 08:00 examples
drwxr-xr-x  35 huzhenghui  staff   1120 12 16 08:00 mlx
-rw-r--r--   1 huzhenghui  staff   1364 12 16 08:00 mlx.pc.in
-rw-r--r--   1 huzhenghui  staff    118 12 16 08:00 pyproject.toml
drwxr-xr-x   6 huzhenghui  staff    192 12 16 08:00 python
-rw-r--r--   1 huzhenghui  staff   6887 12 16 08:00 setup.py
drwxr-xr-x  21 huzhenghui  staff    672 12 16 08:00 tests

创建./build文件夹用于构建。

cd "$(ghq list --full-path https://github.com/ml-explore/mlx)" && mkdir -p build

运行CMake,为编译示例,需要设置MLX_BUILD_EXAMPLES环境变量。

export MLX_BUILD_EXAMPLES=ON
cd "$(ghq list --full-path https://github.com/ml-explore/mlx)"/build && cmake ..

输出如下内容

-- The CXX compiler identification is AppleClang 15.0.0.15000100
-- Detecting CXX compiler ABI info
-- Detecting CXX compiler ABI info - done
-- Check for working CXX compiler: /Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/bin/c++ - skipped
-- Detecting CXX compile features
-- Detecting CXX compile features - done
-- Building MLX for arm64 processor on Darwin
-- Building METAL sources
-- Building with SDK for macOS version 14.2

-- Accelerate found /Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX14.2.sdk/System/Library/Frameworks/Accelerate.framework
CMake Deprecation Warning at build/_deps/doctest-src/CMakeLists.txt:1 (cmake_minimum_required):
  Compatibility with CMake < 3.5 will be removed from a future version of
  CMake.

  Update the VERSION argument <min> value or use a ...<max> suffix to tell
  CMake that the project does not need compatibility with older versions.


-- Configuring done (24.2s)
-- Generating done (0.1s)
-- Build files have been written to: /Users/huzhenghui/ghq/github.com/ml-explore/mlx/build

运行make编译,为提升速度,使用--jobs并行编译。

cd "$(ghq list --full-path https://github.com/ml-explore/mlx)"/build && make --jobs

运行结果如下。

[  1%] Building unary.air
[  3%] Building reduce.air
[  4%] Building indexing.air
[  5%] Building arange.air
[  6%] Building sort.air
[  6%] Building softmax.air
[  9%] Building copy.air
[  9%] Building random.air
[ 10%] Building scan.air
[ 11%] Building gemv.air
[ 12%] Building conv.air
[ 13%] Building gemm.air
[ 15%] Building binary.air
[ 16%] Building arg_reduce.air
[ 17%] Building mlx.metallib
[ 17%] Built target mlx-metallib
[ 18%] Building CXX object CMakeFiles/mlx.dir/mlx/device.cpp.o
[ 19%] Building CXX object CMakeFiles/mlx.dir/mlx/allocator.cpp.o
[ 22%] Building CXX object CMakeFiles/mlx.dir/mlx/scheduler.cpp.o
[ 22%] Building CXX object CMakeFiles/mlx.dir/mlx/graph_utils.cpp.o
[ 23%] Building CXX object CMakeFiles/mlx.dir/mlx/transforms.cpp.o
[ 25%] Building CXX object CMakeFiles/mlx.dir/mlx/primitives.cpp.o
[ 25%] Building CXX object CMakeFiles/mlx.dir/mlx/fft.cpp.o
[ 26%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/common/reduce.cpp.o
[ 27%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/common/fft.cpp.o
[ 29%] Building CXX object CMakeFiles/mlx.dir/mlx/load.cpp.o
[ 30%] Building CXX object CMakeFiles/mlx.dir/mlx/array.cpp.o
[ 32%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/common/conv.cpp.o
[ 32%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/common/indexing.cpp.o
[ 33%] Building CXX object CMakeFiles/mlx.dir/mlx/dtype.cpp.o
[ 36%] Building CXX object CMakeFiles/mlx.dir/mlx/ops.cpp.o
[ 36%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/common/binary.cpp.o
[ 37%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/common/scan.cpp.o
[ 38%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/common/erf.cpp.o
[ 39%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/metal/conv.cpp.o
[ 41%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/accelerate/primitives.cpp.o
[ 41%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/metal/copy.cpp.o
[ 43%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/metal/metal.cpp.o
[ 44%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/accelerate/reduce.cpp.o
[ 45%] Building CXX object CMakeFiles/mlx.dir/mlx/utils.cpp.o
[ 46%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/common/softmax.cpp.o
[ 47%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/common/copy.cpp.o
[ 48%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/common/arg_reduce.cpp.o
[ 50%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/common/primitives.cpp.o
[ 51%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/metal/primitives.cpp.o
[ 52%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/common/load.cpp.o
[ 53%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/accelerate/conv.cpp.o
[ 55%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/accelerate/matmul.cpp.o
[ 55%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/metal/device.cpp.o
[ 58%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/metal/allocator.cpp.o
[ 58%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/metal/fft.cpp.o
[ 59%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/common/threefry.cpp.o
[ 60%] Building CXX object CMakeFiles/mlx.dir/mlx/random.cpp.o
[ 61%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/common/sort.cpp.o
[ 63%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/accelerate/softmax.cpp.o
[ 63%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/metal/scan.cpp.o
[ 65%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/metal/matmul.cpp.o
[ 66%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/metal/indexing.cpp.o
[ 67%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/metal/reduce.cpp.o
[ 68%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/metal/sort.cpp.o
[ 69%] Building CXX object CMakeFiles/mlx.dir/mlx/backend/metal/softmax.cpp.o
[ 70%] Linking CXX static library libmlx.a
[ 70%] Built target mlx
[ 74%] Building CXX object examples/cpp/CMakeFiles/linear_regression.dir/linear_regression.cpp.o
[ 74%] Building CXX object examples/cpp/CMakeFiles/tutorial.dir/tutorial.cpp.o
[ 74%] Building CXX object tests/CMakeFiles/tests.dir/tests.cpp.o
[ 76%] Building CXX object tests/CMakeFiles/tests.dir/array_tests.cpp.o
[ 76%] Building CXX object tests/CMakeFiles/tests.dir/allocator_tests.cpp.o
[ 77%] Building CXX object examples/cpp/CMakeFiles/logistic_regression.dir/logistic_regression.cpp.o
[ 79%] Building CXX object tests/CMakeFiles/tests.dir/autograd_tests.cpp.o
[ 80%] Building CXX object tests/CMakeFiles/tests.dir/arg_reduce_tests.cpp.o
[ 81%] Building CXX object tests/CMakeFiles/tests.dir/blas_tests.cpp.o
[ 82%] Building CXX object tests/CMakeFiles/tests.dir/eval_tests.cpp.o
[ 83%] Building CXX object tests/CMakeFiles/tests.dir/graph_optimize_tests.cpp.o
[ 84%] Building CXX object tests/CMakeFiles/tests.dir/creations_tests.cpp.o
[ 86%] Building CXX object tests/CMakeFiles/tests.dir/device_tests.cpp.o
[ 87%] Building CXX object tests/CMakeFiles/tests.dir/fft_tests.cpp.o
[ 88%] Building CXX object tests/CMakeFiles/tests.dir/load_tests.cpp.o
[ 89%] Building CXX object tests/CMakeFiles/tests.dir/ops_tests.cpp.o
[ 90%] Building CXX object tests/CMakeFiles/tests.dir/metal_tests.cpp.o
[ 91%] Building CXX object tests/CMakeFiles/tests.dir/vmap_tests.cpp.o
[ 93%] Building CXX object tests/CMakeFiles/tests.dir/scheduler_tests.cpp.o
[ 94%] Building CXX object tests/CMakeFiles/tests.dir/utils_tests.cpp.o
[ 95%] Building CXX object tests/CMakeFiles/tests.dir/random_tests.cpp.o
[ 96%] Linking CXX executable linear_regression
[ 97%] Linking CXX executable tutorial
[ 98%] Linking CXX executable logistic_regression
[ 98%] Built target linear_regression
[ 98%] Built target logistic_regression
[ 98%] Built target tutorial
[100%] Linking CXX executable tests
[100%] Built target tests

可以看到编译了三个示例。

[ 96%] Linking CXX executable linear_regression
[ 97%] Linking CXX executable tutorial
[ 98%] Linking CXX executable logistic_regression
[ 98%] Built target linear_regression
[ 98%] Built target logistic_regression
[ 98%] Built target tutorial

如果没有看到,说明没有正确设置环境变量MLX_BUILD_EXAMPLES

运行测试。

cd "$(ghq list --full-path https://github.com/ml-explore/mlx)"/build && make test

测试结果报错。

Running tests...
Test project /Users/huzhenghui/ghq/github.com/ml-explore/mlx/build
    Start 1: tests
1/1 Test #1: tests ............................***Failed    1.06 sec

0% tests passed, 1 tests failed out of 1

Total Test time (real) =   1.06 sec

The following tests FAILED:
          1 - tests (Failed)
Errors while running CTest
Output from these tests are in: /Users/huzhenghui/ghq/github.com/ml-explore/mlx/build/Testing/Temporary/LastTest.log
Use "--rerun-failed --output-on-failure" to re-run the failed cases verbosely.
make: *** [test] Error 8

不必担忧,单元测试没有全部通过是正常的,全部通过才罕见呢,看日志文件。

cat /Users/huzhenghui/ghq/github.com/ml-explore/mlx/build/Testing/Temporary/LastTest.log

测试日志文件内容如下。

Start testing: Dec 16 11:49 CST
----------------------------------------------------------
1/1 Testing: tests
1/1 Test: tests
Command: "/Users/huzhenghui/ghq/github.com/ml-explore/mlx/build/tests/tests"
Directory: /Users/huzhenghui/ghq/github.com/ml-explore/mlx/build/tests
"tests" start time: Dec 16 11:49 CST
Output:
----------------------------------------------------------
[doctest] doctest version is "2.4.9"
[doctest] run with "--help" for options
===============================================================================
/Users/huzhenghui/ghq/github.com/ml-explore/mlx/tests/autograd_tests.cpp:221:
TEST CASE:  test grad

/Users/huzhenghui/ghq/github.com/ml-explore/mlx/tests/autograd_tests.cpp:246: ERROR: CHECK_EQ( dfdx(x).item<float>(), std::exp(1.0f) ) is NOT correct!
  values: CHECK_EQ( 2.71828, 2.71828 )

/Users/huzhenghui/ghq/github.com/ml-explore/mlx/tests/autograd_tests.cpp:248: ERROR: CHECK_EQ( d2fdx2(x).item<float>(), std::exp(1.0f) ) is NOT correct!
  values: CHECK_EQ( 2.71828, 2.71828 )

/Users/huzhenghui/ghq/github.com/ml-explore/mlx/tests/autograd_tests.cpp:250: ERROR: CHECK_EQ( d3fdx3(x).item<float>(), std::exp(1.0f) ) is NOT correct!
  values: CHECK_EQ( 2.71828, 2.71828 )

===============================================================================
/Users/huzhenghui/ghq/github.com/ml-explore/mlx/tests/autograd_tests.cpp:358:
TEST CASE:  test op vjps

/Users/huzhenghui/ghq/github.com/ml-explore/mlx/tests/autograd_tests.cpp:402: ERROR: CHECK_EQ( out.second.item<float>(), 2.0f * std::exp(1.0f) ) is NOT correct!
  values: CHECK_EQ( 5.43656, 5.43656 )

===============================================================================
/Users/huzhenghui/ghq/github.com/ml-explore/mlx/tests/ops_tests.cpp:467:
TEST CASE:  test reduction ops

/Users/huzhenghui/ghq/github.com/ml-explore/mlx/tests/ops_tests.cpp:680: WARNING: WARN_EQ( logsumexp(x).item<float>(), -inf ) is NOT correct!
  values: WARN_EQ( nan, -inf )

/Users/huzhenghui/ghq/github.com/ml-explore/mlx/tests/ops_tests.cpp:686: WARNING: WARN_EQ( logsumexp(x).item<float>(), inf ) is NOT correct!
  values: WARN_EQ( nan, inf )

===============================================================================
/Users/huzhenghui/ghq/github.com/ml-explore/mlx/tests/ops_tests.cpp:741:
TEST CASE:  test arithmetic unary ops

/Users/huzhenghui/ghq/github.com/ml-explore/mlx/tests/ops_tests.cpp:858: ERROR: CHECK( array_equal(exp(x), full({2, 2, 2}, std::exp(1.0f))).item<bool>() ) is NOT correct!
  values: CHECK( false )

===============================================================================
/Users/huzhenghui/ghq/github.com/ml-explore/mlx/tests/ops_tests.cpp:1929:
TEST CASE:  test power

/Users/huzhenghui/ghq/github.com/ml-explore/mlx/tests/ops_tests.cpp:1945: ERROR: CHECK_EQ( (x ^ 0.5).item<float>(), std::pow(2.0f, 0.5f) ) is NOT correct!
  values: CHECK_EQ( 1.41421, 1.41421 )

===============================================================================
[doctest] test cases:  113 |  109 passed | 4 failed | 0 skipped
[doctest] assertions: 1962 | 1956 passed | 6 failed |
[doctest] Status: FAILURE!
<end of output>
Test time =   1.06 sec
----------------------------------------------------------
Test Failed.
"tests" end time: Dec 16 11:49 CST
"tests" time elapsed: 00:00:01
----------------------------------------------------------

End testing: Dec 16 11:49 CST

可以看到单元测试用例通过的比例挺高的。

[doctest] test cases:  113 |  109 passed | 4 failed | 0 skipped
[doctest] assertions: 1962 | 1956 passed | 6 failed |

安装,因为涉及到文件夹权限,需要使用sudo

cd "$(ghq list --full-path https://github.com/ml-explore/mlx)"/build && sudo make install

输出内容如下。

[ 17%] Built target mlx-metallib
[ 70%] Built target mlx
[ 93%] Built target tests
[ 95%] Built target tutorial
[ 97%] Built target linear_regression
[100%] Built target logistic_regression
Install the project...
-- Install configuration: ""

查看安装的文件。

ls -al /usr/local/lib

可以看到mlx相关文件。

drwxr-xr-x  5 root  wheel       160 12 16 08:13 .
drwxr-xr-x  6 root  wheel       192 12 16 08:13 ..
drwxr-xr-x  3 root  wheel        96 12 16 08:13 cmake
-rw-r--r--  1 root  wheel  66139056 12 16 08:13 libmlx.a
-rw-r--r--  1 root  wheel  61726901 12 16 08:10 mlx.metallib

运行示例中的linear_regression

cd "$(ghq list --full-path https://github.com/ml-explore/mlx)"/build && ./examples/cpp/linear_regression

输出结果如下。

Loss array(4.64954e-05, dtype=float32), |w - w*| = 0.00363933, Throughput 2685.41 (it/s).

运行示例中的logistic_regression

cd "$(ghq list --full-path https://github.com/ml-explore/mlx)"/build && ./examples/cpp/logistic_regression

输出结果如下。

Loss array(0.0289869, dtype=float32), Accuracy, array(1, dtype=float32), Throughput 2251.92 (it/s).

运行示例中的tutorial

cd "$(ghq list --full-path https://github.com/ml-explore/mlx)"/build && ./examples/cpp/tutorial

输出结果如下。

array([[1, 1],
       [1, 1]], dtype=float32)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值