Lua实现矩阵的加减乘除

Lua实现矩阵的加减乘除

参考文章:

https://blog.csdn.net/qq_54180412/article/details/122943327

https://www.bilibili.com/video/BV1434y187uj/?spm_id_from=333.337.search-card.all.click&vd_source=c52914e7c156e681a9f8a7825139815d

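  矩阵的加、减、乘无需多说,这里主要介绍矩阵的除法。矩阵并没有真正意义上的除法,本文将 $A / B$ 定义为 $A$ 右乘 $B$ 的逆矩阵,即 $A / B = A B^{-1}$;逆矩阵则借助伴随矩阵与行列式求得:$B^{-1} = \frac{1}{\det B}\,\operatorname{adj}(B)$,因此要求 $B$ 为方阵且 $\det B \neq 0$。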

版本一

--- Matrix class (version 1): a plain table that is both the metatable of every
-- matrix object (it holds the operator metamethods) and its method table
-- (Matrix:New sets Matrix.__index = Matrix).
-- A matrix object has the fields:
--   tbData  : array of rows, each row an array of numbers (1-based)
--   nRow    : number of rows
--   nColumn : number of columns
-- The operators +, -, * and / return new matrices; their operands are left unchanged.
Matrix = {
    -- Element-wise sum. On a shape mismatch prints a warning and returns tbSource.
    __add = function(tbSource, tbDest)
        return Matrix:_ElementWise(tbSource, tbDest, function(a, b) return a + b end)
    end,

    -- Element-wise difference. On a shape mismatch prints a warning and returns tbSource.
    __sub = function(tbSource, tbDest)
        return Matrix:_ElementWise(tbSource, tbDest, function(a, b) return a - b end)
    end,

    -- Matrix product (see Matrix:_MartixMul).
    __mul = function(tbSource, tbDest)
        return Matrix:_MartixMul(tbSource, tbDest)
    end,

    -- "Division" tbSource / tbDest, defined as tbSource * inverse(tbDest) where
    -- inverse(tbDest) = adjugate(tbDest) / det(tbDest).
    -- Prints a message and returns nil when tbDest is not square or is singular.
    __div = function(tbSource, tbDest)
        assert(tbSource,"tbSource not exist")
        assert(tbDest,  "tbDest not exist")

        -- Only square matrices have an inverse (and _GetDetValue requires one)
        if tbDest.nRow ~= tbDest.nColumn then
            print("matrix no inverse matrix...")
            return nil
        end

        local nDet = Matrix:_GetDetValue(tbDest)
        if nDet == 0 then
            print("matrix no inverse matrix...")
            return nil
        end
        -- The adjugate is a freshly built matrix, so scaling it in place is safe
        local tbInverseDest = Matrix:_MatrixNumMul(Matrix:_GetCompanyMatrix(tbDest), 1 / nDet)
        return Matrix:_MartixMul(tbSource, tbInverseDest)
    end
}

--- Construct a matrix object that wraps `data` (stored by reference, not copied).
-- @param data array of rows; the column count is taken from the first row
-- @return a table whose metatable is Matrix
function Matrix:New(data)
    assert(data,"data not exist")
    local tbMatrix = {}
    setmetatable(tbMatrix, self)
    -- Let objects find methods on the class table
    self.__index    = self

    tbMatrix.tbData = data
    tbMatrix.nRow   = #data
    if tbMatrix.nRow > 0 then
        tbMatrix.nColumn = (#data[1])
    else
        tbMatrix.nColumn = 0
    end
    return tbMatrix
end

--- Write the matrix to stdout: one row per line, every element followed by ','.
function Matrix:Print()
    for _, tbRow in ipairs(self.tbData) do
        for _, value in ipairs(tbRow) do
            io.write(value, ',')
        end
        print('')
    end
end

--- Combine two same-shaped matrices element by element into a NEW matrix.
-- @param tbSource left operand
-- @param tbDest   right operand
-- @param fnOp     function(a, b) returning the combined element
-- @return the new matrix, or tbSource itself (after printing a warning) on a shape mismatch
function Matrix:_ElementWise(tbSource, tbDest, fnOp)
    assert(tbSource,"tbSource not exist")
    assert(tbDest,  "tbDest not exist")
    if tbSource.nRow ~= tbDest.nRow
        or tbSource.nColumn ~= tbDest.nColumn then
        print("row or column not equal...")
        return tbSource
    end

    local tbData = {}
    for i = 1, tbSource.nRow do
        local tbRow = {}
        for j = 1, tbSource.nColumn do
            tbRow[j] = fnOp(tbSource.tbData[i][j], tbDest.tbData[i][j])
        end
        tbData[i] = tbRow
    end
    local tbRes   = Matrix:New(tbData)
    tbRes.nRow    = tbSource.nRow
    tbRes.nColumn = tbSource.nColumn
    return tbRes
end

--- Return a new matrix equal to tbMatrix with row rowIndex and column colIndex removed
-- (the minor used for cofactor expansion).
function Matrix:_CutoffMatrix(tbMatrix, rowIndex, colIndex)
    assert(tbMatrix,"tbMatrix not exist")
    assert(rowIndex >= 1,"rowIndex < 1")
    assert(colIndex >= 1,"colIndex < 1")
    local tbRes   = Matrix:New({})
    tbRes.nRow    = tbMatrix.nRow    - 1
    tbRes.nColumn = tbMatrix.nColumn - 1
    for i = 1, tbRes.nRow do
        -- Rows/columns at or after the removed one are read from one position further on
        local nSrcRow = (i >= rowIndex) and (i + 1) or i
        local tbRow = {}
        for j = 1, tbRes.nColumn do
            local nSrcCol = (j >= colIndex) and (j + 1) or j
            tbRow[j] = tbMatrix.tbData[nSrcRow][nSrcCol]
        end
        tbRes.tbData[i] = tbRow
    end
    return tbRes
end

--- Determinant of a square matrix by cofactor expansion along the first row.
-- Exponential in the matrix order; intended for small matrices.
-- @return the determinant; 1 for the empty (0x0) matrix
function Matrix:_GetDetValue(tbMatrix)
    assert(tbMatrix,"tbMatrix not exist")
    assert(tbMatrix.nRow == tbMatrix.nColumn, "matrix is not square")
    -- det of the 0x0 matrix is the empty product 1; without this the adjugate
    -- of a 1x1 matrix would come out as {{0}} instead of {{1}}
    if tbMatrix.nRow == 0 then
        return 1
    end
    -- For a 1x1 matrix the determinant is its only element
    if tbMatrix.nRow == 1 then
        return tbMatrix.tbData[1][1]
    end

    local nAns  = 0
    local nSign = 1   -- (-1)^(1 + i): alternates +, -, +, ...
    for i = 1, tbMatrix.nColumn do
        nAns = nAns + nSign * tbMatrix.tbData[1][i]
            * Matrix:_GetDetValue(Matrix:_CutoffMatrix(tbMatrix, 1, i))
        nSign = -nSign
    end
    return nAns
end

--- Adjugate (classical adjoint) of tbMatrix: the transpose of its cofactor matrix.
-- @return a new matrix with element (j, i) = (-1)^(i+j) * det(minor(i, j))
function Matrix:_GetCompanyMatrix(tbMatrix)
    assert(tbMatrix,"tbMatrix not exist")
    local tbRes   = Matrix:New({})
    -- The adjugate is transposed relative to the cofactor matrix
    tbRes.nRow    = tbMatrix.nColumn
    tbRes.nColumn = tbMatrix.nRow

    for i = 1, tbMatrix.nRow do
        for j = 1, tbMatrix.nColumn do
            local nSign = ((i + j) % 2 == 0) and 1 or -1
            tbRes.tbData[j] = tbRes.tbData[j] or {}
            -- Cofactor C(i, j) is stored at (j, i)
            tbRes.tbData[j][i] =
                nSign * Matrix:_GetDetValue(Matrix:_CutoffMatrix(tbMatrix, i, j))
        end
    end
    return tbRes
end

--- Multiply every element of tbMatrix by num IN PLACE.
-- @return tbMatrix itself (mutated)
function Matrix:_MatrixNumMul(tbMatrix, num)
    for i = 1, tbMatrix.nRow do
        for j = 1, tbMatrix.nColumn do
            tbMatrix.tbData[i][j] = tbMatrix.tbData[i][j] * num
        end
    end
    return tbMatrix
end

--- Matrix product tbSource * tbDest as a new nRow(tbSource) x nColumn(tbDest) matrix.
-- On incompatible shapes prints a warning and returns tbSource.
function Matrix:_MartixMul(tbSource, tbDest)
    assert(tbSource,"tbSource not exist")
    assert(tbDest,  "tbDest not exist")
    if tbSource.nColumn ~= tbDest.nRow then
        print("column not equal row...")
        return tbSource
    end

    local tbRes = Matrix:New({})
    for i = 1, tbSource.nRow do
        local tbRow = {}
        for j = 1, tbDest.nColumn do
            local nSum = 0
            for k = 1, tbSource.nColumn do
                nSum = nSum + tbSource.tbData[i][k] * tbDest.tbData[k][j]
            end
            tbRow[j] = nSum
        end
        tbRes.tbData[i] = tbRow
    end
    tbRes.nRow    = tbSource.nRow
    tbRes.nColumn = tbDest.nColumn
    return tbRes
end









-- Demo (version 1): exercise every overloaded operator and print each result.
local SEPARATOR = "-----------------------------------"

-- Addition: element-wise sum of two 2x3 matrices
local tbSum = Matrix:New({{1,2,3},{4,5,6}}) + Matrix:New({{2,3,4},{5,6,7}})
tbSum:Print()
print(SEPARATOR)

-- Subtraction: subtract an all-ones 2x3 matrix from the sum above
local tbDifference = tbSum - Matrix:New({{1,1,1},{1,1,1}})
tbDifference:Print()
print(SEPARATOR)

-- Multiplication: (2x3) * (3x2) gives a 2x2 matrix
local tbLeftFactor  = Matrix:New({{1,2,3},{4,5,6}})
local tbRightFactor = Matrix:New({{7,8},{9,10},{11,12}})
local tbProduct     = tbLeftFactor * tbRightFactor
tbProduct:Print()
print(SEPARATOR)

-- Division: multiply by the inverse of the (invertible) right-hand matrix
local tbDividend = Matrix:New({{1,2,3},{4,5,6},{7,8,0}})
local tbDivisor  = Matrix:New({{1,2,1},{1,1,2},{2,1,1}})
local tbQuotient = tbDividend / tbDivisor
tbQuotient:Print()

  输出结果:

效果图

版本二

  在完成第一版实现后,我发现自己对Lua面向对象的理解还不够清晰,故查阅网上资料,实现了以下版本。

  Lua面向对象类的封装这里不做过多阐述,可参考文章:https://blog.csdn.net/qq135595696/article/details/128827618

-- Registry mapping each class type table to its method table (vtbl).
local _class = {}

--- Define a class, optionally derived from `super` (a table previously returned by class()).
-- The returned class type carries three raw fields:
--   Ctor  : constructor, initially false; define it with `function Cls:Ctor(...) end`
--   super : the parent class type (or nil)
--   New   : factory; `Cls.New(...)` creates a fresh table, runs every Ctor from the
--           root ancestor down to Cls on it with the same arguments, and returns it
-- Any other field assigned on the class type (e.g. `function Cls:Foo() end`) is stored
-- in the class's vtbl rather than on the class type, and becomes an instance method.
-- With a super class, a name missing from the vtbl is looked up in the parent's vtbl
-- on first access and then cached in the child's vtbl.
function class(super)
    local tbClassType = {}
    tbClassType.Ctor  = false
    tbClassType.super = super

    tbClassType.New = function(...)
        local tbInstance = {}

        -- Collect the inheritance chain, most-derived class first ...
        local tbChain  = {}
        local tbCursor = tbClassType
        while tbCursor do
            tbChain[#tbChain + 1] = tbCursor
            tbCursor = tbCursor.super
        end
        -- ... then run the constructors root-first, so derived Ctors see base fields
        for nIndex = #tbChain, 1, -1 do
            local fnCtor = tbChain[nIndex].Ctor
            if fnCtor then
                fnCtor(tbInstance, ...)
            end
        end

        -- A Ctor may have installed its own metatable (e.g. for operator overloading);
        -- keep it and only point its __index at this class's methods
        local tbMeta = getmetatable(tbInstance)
        if tbMeta then
            tbMeta.__index = _class[tbClassType]
        else
            setmetatable(tbInstance, { __index = _class[tbClassType] })
        end
        return tbInstance
    end

    local tbMethods = {}
    _class[tbClassType] = tbMethods

    -- New fields on the class type (method definitions) are routed into the vtbl
    setmetatable(tbClassType, {
        __newindex = function(_, key, value)
            tbMethods[key] = value
        end
    })

    if super then
        -- Inherited names are copied down from the parent's vtbl on first access
        setmetatable(tbMethods, {
            __index = function(_, key)
                local inherited = _class[super][key]
                tbMethods[key]  = inherited
                return inherited
            end
        })
    end
    return tbClassType
end







--- Matrix class (version 2), built with class().
-- A matrix object has the fields:
--   tbData  : array of rows, each row an array of numbers (1-based)
--   nRow    : number of rows
--   nColumn : number of columns
-- The operators +, -, * and / return new matrices; their operands are left unchanged.
Matrix = class()

-- Operator metamethods, defined once so that constructing a matrix does not
-- allocate new closures. They call helpers through the left operand, because
-- methods live in the class vtbl and are reachable only via an instance.

-- Element-wise sum. On a shape mismatch prints a warning and returns tbSource.
local function fnMatrixAdd(tbSource, tbDest)
    return tbSource:_ElementWise(tbSource, tbDest, function(a, b) return a + b end)
end

-- Element-wise difference. On a shape mismatch prints a warning and returns tbSource.
local function fnMatrixSub(tbSource, tbDest)
    return tbSource:_ElementWise(tbSource, tbDest, function(a, b) return a - b end)
end

-- Matrix product (see Matrix:_MartixMul).
local function fnMatrixMul(tbSource, tbDest)
    return tbSource:_MartixMul(tbSource, tbDest)
end

-- "Division" tbSource / tbDest = tbSource * inverse(tbDest), where
-- inverse(tbDest) = adjugate(tbDest) / det(tbDest).
-- Prints a message and returns nil when tbDest is not square or is singular.
local function fnMatrixDiv(tbSource, tbDest)
    assert(tbSource,"tbSource not exist")
    assert(tbDest,  "tbDest not exist")

    -- Only square matrices have an inverse (and _GetDetValue requires one)
    if tbDest.nRow ~= tbDest.nColumn then
        print("matrix no inverse matrix...")
        return nil
    end

    local nDet = tbSource:_GetDetValue(tbDest)
    if nDet == 0 then
        print("matrix no inverse matrix...")
        return nil
    end
    -- The adjugate is a freshly built matrix, so scaling it in place is safe
    local tbInverseDest = tbSource:_MatrixNumMul(tbSource:_GetCompanyMatrix(tbDest), 1 / nDet)
    return tbSource:_MartixMul(tbSource, tbInverseDest)
end

--- Constructor (invoked through Matrix.New(data)); wraps `data` by reference.
-- @param data array of rows; the column count is taken from the first row
function Matrix:Ctor(data)
    assert(data, "data not exist")
    self.tbData = data
    self.nRow   = #data
    if self.nRow > 0 then
        self.nColumn = (#data[1])
    else
        self.nColumn = 0
    end

    -- Each instance still gets its own metatable table: class().New writes __index
    -- into whatever metatable Ctor installs, so one shared table would be
    -- re-pointed by every subclass constructed from it.
    setmetatable(self, {
        __add = fnMatrixAdd,
        __sub = fnMatrixSub,
        __mul = fnMatrixMul,
        __div = fnMatrixDiv,
    })
end

--- Write the matrix to stdout: one row per line, every element followed by ','.
function Matrix:Print()
    for _, tbRow in ipairs(self.tbData) do
        for _, value in ipairs(tbRow) do
            io.write(value, ',')
        end
        print('')
    end
end

--- Combine two same-shaped matrices element by element into a NEW matrix.
-- @param tbSource left operand
-- @param tbDest   right operand
-- @param fnOp     function(a, b) returning the combined element
-- @return the new matrix, or tbSource itself (after printing a warning) on a shape mismatch
function Matrix:_ElementWise(tbSource, tbDest, fnOp)
    assert(tbSource,"tbSource not exist")
    assert(tbDest,  "tbDest not exist")
    if tbSource.nRow ~= tbDest.nRow
        or tbSource.nColumn ~= tbDest.nColumn then
        print("row or column not equal...")
        return tbSource
    end

    local tbData = {}
    for i = 1, tbSource.nRow do
        local tbRow = {}
        for j = 1, tbSource.nColumn do
            tbRow[j] = fnOp(tbSource.tbData[i][j], tbDest.tbData[i][j])
        end
        tbData[i] = tbRow
    end
    local tbRes   = Matrix.New(tbData)
    tbRes.nRow    = tbSource.nRow
    tbRes.nColumn = tbSource.nColumn
    return tbRes
end

--- Return a new matrix equal to tbMatrix with row rowIndex and column colIndex removed
-- (the minor used for cofactor expansion).
function Matrix:_CutoffMatrix(tbMatrix, rowIndex, colIndex)
    assert(tbMatrix,"tbMatrix not exist")
    assert(rowIndex >= 1,"rowIndex < 1")
    assert(colIndex >= 1,"colIndex < 1")
    local tbRes   = Matrix.New({})
    tbRes.nRow    = tbMatrix.nRow    - 1
    tbRes.nColumn = tbMatrix.nColumn - 1
    for i = 1, tbRes.nRow do
        -- Rows/columns at or after the removed one are read from one position further on
        local nSrcRow = (i >= rowIndex) and (i + 1) or i
        local tbRow = {}
        for j = 1, tbRes.nColumn do
            local nSrcCol = (j >= colIndex) and (j + 1) or j
            tbRow[j] = tbMatrix.tbData[nSrcRow][nSrcCol]
        end
        tbRes.tbData[i] = tbRow
    end
    return tbRes
end

--- Determinant of a square matrix by cofactor expansion along the first row.
-- Exponential in the matrix order; intended for small matrices.
-- @return the determinant; 1 for the empty (0x0) matrix
function Matrix:_GetDetValue(tbMatrix)
    assert(tbMatrix,"tbMatrix not exist")
    assert(tbMatrix.nRow == tbMatrix.nColumn, "matrix is not square")
    -- det of the 0x0 matrix is the empty product 1; without this the adjugate
    -- of a 1x1 matrix would come out as {{0}} instead of {{1}}
    if tbMatrix.nRow == 0 then
        return 1
    end
    -- For a 1x1 matrix the determinant is its only element
    if tbMatrix.nRow == 1 then
        return tbMatrix.tbData[1][1]
    end

    local nAns  = 0
    local nSign = 1   -- (-1)^(1 + i): alternates +, -, +, ...
    for i = 1, tbMatrix.nColumn do
        nAns = nAns + nSign * tbMatrix.tbData[1][i]
            * self:_GetDetValue(self:_CutoffMatrix(tbMatrix, 1, i))
        nSign = -nSign
    end
    return nAns
end

--- Adjugate (classical adjoint) of tbMatrix: the transpose of its cofactor matrix.
-- @return a new matrix with element (j, i) = (-1)^(i+j) * det(minor(i, j))
function Matrix:_GetCompanyMatrix(tbMatrix)
    assert(tbMatrix,"tbMatrix not exist")
    local tbRes   = Matrix.New({})
    -- The adjugate is transposed relative to the cofactor matrix
    tbRes.nRow    = tbMatrix.nColumn
    tbRes.nColumn = tbMatrix.nRow

    for i = 1, tbMatrix.nRow do
        for j = 1, tbMatrix.nColumn do
            local nSign = ((i + j) % 2 == 0) and 1 or -1
            tbRes.tbData[j] = tbRes.tbData[j] or {}
            -- Cofactor C(i, j) is stored at (j, i)
            tbRes.tbData[j][i] =
                nSign * self:_GetDetValue(self:_CutoffMatrix(tbMatrix, i, j))
        end
    end
    return tbRes
end

--- Multiply every element of tbMatrix by num IN PLACE.
-- @return tbMatrix itself (mutated)
function Matrix:_MatrixNumMul(tbMatrix, num)
    for i = 1, tbMatrix.nRow do
        for j = 1, tbMatrix.nColumn do
            tbMatrix.tbData[i][j] = tbMatrix.tbData[i][j] * num
        end
    end
    return tbMatrix
end

--- Matrix product tbSource * tbDest as a new nRow(tbSource) x nColumn(tbDest) matrix.
-- On incompatible shapes prints a warning and returns tbSource.
function Matrix:_MartixMul(tbSource, tbDest)
    assert(tbSource,"tbSource not exist")
    assert(tbDest,  "tbDest not exist")
    if tbSource.nColumn ~= tbDest.nRow then
        print("column not equal row...")
        return tbSource
    end

    local tbRes = Matrix.New({})
    for i = 1, tbSource.nRow do
        local tbRow = {}
        for j = 1, tbDest.nColumn do
            local nSum = 0
            for k = 1, tbSource.nColumn do
                nSum = nSum + tbSource.tbData[i][k] * tbDest.tbData[k][j]
            end
            tbRow[j] = nSum
        end
        tbRes.tbData[i] = tbRow
    end
    tbRes.nRow    = tbSource.nRow
    tbRes.nColumn = tbDest.nColumn
    return tbRes
end



-- Demo (version 2): exercise every overloaded operator and print each result.
local SEPARATOR = "-----------------------------------"

-- Addition: element-wise sum of two 2x3 matrices
local tbSum = Matrix.New({{1,2,3},{4,5,6}}) + Matrix.New({{2,3,4},{5,6,7}})
tbSum:Print()
print(SEPARATOR)

-- Subtraction: subtract an all-ones 2x3 matrix from the sum above
local tbDifference = tbSum - Matrix.New({{1,1,1},{1,1,1}})
tbDifference:Print()
print(SEPARATOR)

-- Multiplication: (2x3) * (3x2) gives a 2x2 matrix
local tbLeftFactor  = Matrix.New({{1,2,3},{4,5,6}})
local tbRightFactor = Matrix.New({{7,8},{9,10},{11,12}})
local tbProduct     = tbLeftFactor * tbRightFactor
tbProduct:Print()
print(SEPARATOR)

-- Division: multiply by the inverse of the (invertible) right-hand matrix
local tbDividend = Matrix.New({{1,2,3},{4,5,6},{7,8,0}})
local tbDivisor  = Matrix.New({{1,2,1},{1,1,2},{2,1,1}})
local tbQuotient = tbDividend / tbDivisor
tbQuotient:Print()

  输出结果:

效果图

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

ufgnix0802

总结不易,谢谢大家的支持

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值