本文介绍了如何在 Mojo 中编写矩阵乘法 (matmul) 算法。我们将从纯 Python 实现开始,过渡到本质上是 Python 实现副本的简单实现,然后添加类型,然后通过矢量化、平铺和并行化实现继续进行优化。
首先,让我们定义矩阵乘法。给定两个密度矩阵A和B的维数分别是MXK和KXN。我们要计算它们的点积C= A·B(也被称为matmul),点积定义为:C+=A·B, C
请查看我们关于matmul的博客文章,以及为什么它对于ML和DL工作负载很重要。
该notebook的格式是从一个与Python相同的实现开始(实际上是重命名文件扩展名),然后在通过利用现代硬件上可用的矢量化和并行化能力扩展实现之前,看看向实现中添加类型如何帮助性能。在整个执行过程中,我们报告实现的GFlops。
Python 实现
让我们首先在Python中直接从定义中实现matmul。
%%python
def matmul_python(C, A, B):
for m in range(C.rows):
for k in range(A.cols):
for n in range(C.cols):
C[m, n] += A[m, k] * B[k, n]
让我们使用128乘128的方阵对我们的实现进行基准测试,并报告实现的GFLops。
安装numpy(如果还没有安装的话):
%%python
from importlib.util import find_spec
import shutil
import subprocess
fix = """
-------------------------------------------------------------------------
fix following the steps here:
https://github.com/modularml/mojo/issues/1085#issuecomment-1771403719
-------------------------------------------------------------------------
"""
def install_if_missing(name: str):
if find_spec(name):
return
print(f"{
name} not found, installing...")
try:
if shutil.which('python3'): python = "python3"
elif shutil.which('python'): python = "python"
else: raise ("python not on path" + fix)
subprocess.check_call([python, "-m", "pip", "install", name])
except:
raise ImportError(f"{
name} not found" + fix)
install_if_missing("numpy")
%%python
from timeit import timeit
import numpy as np
class Matrix:
def __init__(self, value, rows, cols):
self.value = value
self.rows = rows
self.cols = cols
def __getitem__(self, idxs):
return self.value[idxs[0]][idxs[1]]
def __setitem__(self, idxs, value):
self.value[idxs[0]][idxs[1]] = value
def benchmark_matmul_python(M, N, K):
A = Matrix(list(np.random.rand(M, K)), M, K)
B = Matrix(list(np.random.rand(K, N)), K, N)
C = Matrix(list(np.zeros((M, N))), M, N)
secs = timeit(lambda: matmul_python(C, A, B), number=2)/2
gflops = ((2*M*N*K)/secs) / 1e9
print(gflops, "GFLOP/s")
return gflops
python_gflops = benchmark_matmul_python(128, 128, 128).to_float64()
输出:
0.0022564560694965834 GFLOP/s
将 Python 实现导入 Mojo
使用Mojo和使用Python一样简单。首先,我们从Mojo的stdlib中引入要用到的模块:<