Mojo中的矩阵乘法

本文介绍了如何在 Mojo 中编写矩阵乘法 (matmul) 算法。我们将从纯 Python 实现开始,过渡到本质上是 Python 实现副本的简单实现,然后添加类型,然后通过矢量化、平铺和并行化实现继续进行优化。

首先,让我们定义矩阵乘法。给定两个密度矩阵A和B的维数分别是MXK和KXN。我们要计算它们的点积C= A·B(也被称为matmul),点积定义为:C+=A·B, C

在这里插入图片描述

请查看我们关于matmul的博客文章,以及为什么它对于ML和DL工作负载很重要。

该notebook的格式是从一个与Python相同的实现开始(实际上是重命名文件扩展名),然后在通过利用现代硬件上可用的矢量化和并行化能力扩展实现之前,看看向实现中添加类型如何帮助性能。在整个执行过程中,我们报告实现的GFlops。

Python 实现


让我们首先在Python中直接从定义中实现matmul。

%%python
def matmul_python(C, A, B):
    for m in range(C.rows):
        for k in range(A.cols):
            for n in range(C.cols):
                C[m, n] += A[m, k] * B[k, n]

让我们使用128乘128的方阵对我们的实现进行基准测试,并报告实现的GFLops。

安装numpy(如果还没有安装的话):

%%python
from importlib.util import find_spec
import shutil
import subprocess

fix = """
-------------------------------------------------------------------------
fix following the steps here:
    https://github.com/modularml/mojo/issues/1085#issuecomment-1771403719
-------------------------------------------------------------------------
"""

def install_if_missing(name: str):
    if find_spec(name):
        return

    print(f"{
     name} not found, installing...")
    try:
        if shutil.which('python3'): python = "python3"
        elif shutil.which('python'): python = "python"
        else: raise ("python not on path" + fix)
        subprocess.check_call([python, "-m", "pip", "install", name])
    except:
        raise ImportError(f"{
     name} not found" + fix)

install_if_missing("numpy")
%%python
from timeit import timeit
import numpy as np

class Matrix:
    def __init__(self, value, rows, cols):
        self.value = value
        self.rows = rows
        self.cols = cols

    def __getitem__(self, idxs):
        return self.value[idxs[0]][idxs[1]]

    def __setitem__(self, idxs, value):
        self.value[idxs[0]][idxs[1]] = value

def benchmark_matmul_python(M, N, K):
    A = Matrix(list(np.random.rand(M, K)), M, K)
    B = Matrix(list(np.random.rand(K, N)), K, N)
    C = Matrix(list(np.zeros((M, N))), M, N)
    secs = timeit(lambda: matmul_python(C, A, B), number=2)/2
    gflops = ((2*M*N*K)/secs) / 1e9
    print(gflops, "GFLOP/s")
    return gflops
python_gflops = benchmark_matmul_python(128, 128, 128).to_float64()

输出:

0.0022564560694965834 GFLOP/s

将 Python 实现导入 Mojo


使用Mojo和使用Python一样简单。首先,我们从Mojo的stdlib中引入要用到的模块:<

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

启航学途

您的鼓励将是我创作的最大动力!

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值