luapy (16): code generation

10 篇文章 2 订阅

1. func_info.py

from lua_opcode import OpCode
from lua_token import TokenKind
from lua_opcode import Instruction
from lua_value import LuaValue
from prototype import Prototype
from upvalue import Upvalue


class LocalVarInfo:
    """Bookkeeping for one declared local variable.

    ``prev`` links to the binding of the same name that this one shadows
    (or None).  ``captured`` is set once an inner function refers to the
    variable as an upvalue.  ``start_pc``/``end_pc`` hold the variable's
    live range (presumably for the prototype's local-variable debug info);
    ``end_pc`` is read and updated by FuncInfo.fix_end_pc, so it must exist.
    """

    def __init__(self, name, prev, scope_level, slot, start_pc=0, end_pc=0):
        self.name = name
        self.prev = prev
        self.scope_level = scope_level
        self.slot = slot  # register holding the variable
        self.captured = False
        self.start_pc = start_pc
        self.end_pc = end_pc


class UpvalInfo:
    """Describe where one upvalue of a function comes from.

    Only one source is meaningful at a time: ``local_var_slot`` is >= 0
    when the upvalue captures a local register of the enclosing function
    (and ``upval_index`` is -1), otherwise ``upval_index`` is the enclosing
    function's own upvalue index.  ``index`` is this upvalue's position in
    the function's upvalue list.
    """

    def __init__(self, slot, upval_index, index):
        self.local_var_slot, self.upval_index, self.index = slot, upval_index, index


class FuncInfo:
    """Per-function code-generation state.

    Tracks register allocation, lexical scopes and their local variables,
    upvalues, the constant table, pending ``break`` jumps and the list of
    emitted instructions for one function being compiled.  ``to_proto``
    converts the collected state into a ``Prototype``.
    """

    # Binary operator tokens that map one-to-one onto a single ABC opcode.
    arith_and_bitwise_binops = {
        TokenKind.OP_ADD: OpCode.ADD,
        TokenKind.OP_SUB: OpCode.SUB,
        TokenKind.OP_MUL: OpCode.MUL,
        TokenKind.OP_MOD: OpCode.MOD,
        TokenKind.OP_POW: OpCode.POW,
        TokenKind.OP_DIV: OpCode.DIV,
        TokenKind.OP_IDIV: OpCode.IDIV,
        TokenKind.OP_BAND: OpCode.BAND,
        TokenKind.OP_BOR: OpCode.BOR,
        TokenKind.OP_BXOR: OpCode.BXOR,
        TokenKind.OP_SHL: OpCode.SHL,
        TokenKind.OP_SHR: OpCode.SHR,
    }

    def __init__(self, parent, fd):
        """Create state for the function definition *fd*.

        :param parent: FuncInfo of the enclosing function, or None.
        :param fd: function-definition expression; ``par_list`` and
            ``is_var_arg`` are read from it.
        """
        self.parent = parent
        self.sub_funcs = []
        self.used_regs = 0
        self.max_regs = 0
        self.scope_level = 0
        self.local_vars = []    # every local ever declared, in order
        self.local_names = {}   # name -> currently visible LocalVarInfo
        self.upvalues = {}      # name -> UpvalInfo
        # (type, value) -> constant index.  The type is part of the key
        # because Python treats 1 and 1.0 as the same dict key, while Lua
        # must keep integer and float constants distinct.
        self.constants = {}
        # One entry per open scope: a list of pending break-jump pcs for
        # breakable (loop) scopes, None otherwise.  Entry 0 is the body.
        self.breaks = [None]
        self.insts = []
        self.num_params = 0 if fd.par_list is None else len(fd.par_list)
        self.is_vararg = fd.is_var_arg

    # ------------------------------------------------------------------
    # constants and registers
    # ------------------------------------------------------------------

    def index_of_constant(self, k):
        """Return the constant-table index of *k*, adding it if new."""
        key = (type(k), k)
        idx = self.constants.get(key)
        if idx is None:
            idx = len(self.constants)
            self.constants[key] = idx
        return idx

    def alloc_reg(self):
        """Allocate the next free register and return its index.

        :raises Exception: when 255 or more registers would be in use.
        """
        self.used_regs += 1
        if self.used_regs >= 255:
            raise Exception('function or expression needs too many registers')

        if self.used_regs > self.max_regs:
            self.max_regs = self.used_regs

        return self.used_regs - 1

    def free_reg(self):
        """Release the most recently allocated register."""
        self.used_regs -= 1

    def alloc_regs(self, n):
        """Allocate *n* consecutive registers and return the first index."""
        for _ in range(n):
            self.alloc_reg()
        return self.used_regs - n

    def free_regs(self, n):
        """Release the *n* most recently allocated registers."""
        for _ in range(n):
            self.free_reg()

    # ------------------------------------------------------------------
    # scopes and local variables
    # ------------------------------------------------------------------

    def enter_scope(self, break_able):
        """Open a nested scope; *break_able* scopes collect break jumps."""
        self.scope_level += 1
        self.breaks.append([] if break_able else None)

    def exit_scope(self, end_pc=None):
        """Close the innermost scope.

        Pending ``break`` jumps of the scope are patched to target the next
        instruction to be emitted, then every local declared in the scope
        is removed and its register released.

        :param end_pc: optional pc stored as ``end_pc`` on each local that
            goes out of scope; ignored when None.
        """
        if self.breaks:
            pending_break_jmps = self.breaks.pop()
            if pending_break_jmps:
                a = self.get_jmp_arg_a()
                for pc in pending_break_jmps:
                    sbx = self.pc() - pc
                    self.insts[pc] = ((sbx + Instruction.MAXARG_sBx) << 14) | (a << 6) | OpCode.JMP

        self.scope_level -= 1
        # Iterate over a snapshot: remove_local_var mutates local_names.
        for name in list(self.local_names):
            local_var = self.local_names[name]
            if local_var.scope_level > self.scope_level:
                if end_pc is not None:
                    local_var.end_pc = end_pc
                self.remove_local_var(local_var)

    def add_local_var(self, name, start_pc=0):
        """Declare local *name* in the current scope; return its register.

        An existing binding of the same name is kept as ``prev`` so it
        becomes visible again when this one goes out of scope.

        :param start_pc: pc from which the variable is live (bookkeeping
            only; stored as ``start_pc`` on the variable record).
        """
        new_var = LocalVarInfo(name, self.local_names.get(name),
                               self.scope_level, self.alloc_reg())
        # Set here so exit_scope/fix_end_pc can always read/update them.
        new_var.start_pc = start_pc
        new_var.end_pc = 0

        self.local_vars.append(new_var)
        self.local_names[name] = new_var
        return new_var.slot

    def remove_local_var(self, local_var):
        """Drop *local_var* (and same-scope bindings it shadows)."""
        self.free_reg()
        if local_var.prev is None:
            self.local_names.pop(local_var.name)
        elif local_var.prev.scope_level == local_var.scope_level:
            # The shadowed binding lives in the same scope: it goes too.
            self.remove_local_var(local_var.prev)
        else:
            self.local_names[local_var.name] = local_var.prev

    def slot_of_local_var(self, name):
        """Return the register of visible local *name*, or -1."""
        local_var = self.local_names.get(name)
        return -1 if local_var is None else local_var.slot

    def add_break_jmp(self, pc):
        """Record the break jump at *pc* in the innermost breakable scope.

        :raises Exception: when no enclosing scope is breakable.
        """
        for i in range(self.scope_level, -1, -1):
            if self.breaks[i] is not None:
                self.breaks[i].append(pc)
                return

        raise Exception('<break> not inside a loop!')

    def index_of_upval(self, name):
        """Return the upvalue index for *name*, creating it if needed.

        Looks in this function's upvalues, then the parent's locals (which
        become captured), then recursively the parent's upvalues.
        Returns -1 when the name is not found, i.e. it is a global.
        """
        if name in self.upvalues:
            return self.upvalues[name].index

        if self.parent is not None:
            if name in self.parent.local_names:
                idx = len(self.upvalues)
                local_var = self.parent.local_names[name]
                self.upvalues[name] = UpvalInfo(local_var.slot, -1, idx)
                local_var.captured = True
                return idx

            upval_index = self.parent.index_of_upval(name)
            if upval_index >= 0:
                idx = len(self.upvalues)
                self.upvalues[name] = UpvalInfo(-1, upval_index, idx)
                return idx

        return -1

    def close_open_upvals(self):
        """Emit a closing JMP when locals of this scope are captured."""
        a = self.get_jmp_arg_a()
        if a > 0:
            self.emit_jmp(a, 0)

    def get_jmp_arg_a(self):
        """Return the A operand for a JMP that closes this scope's upvalues.

        If any local declared at the current scope level is captured, the
        result is one more than the lowest register holding a user-visible
        local of this scope; otherwise 0 (nothing to close).  Compiler-
        internal locals such as ``(for index)`` (names starting with '(')
        are skipped.
        """
        has_captured_local_vars = False
        min_slot_of_local_vars = self.max_regs
        for local_var in self.local_names.values():
            if local_var.scope_level == self.scope_level:
                v = local_var
                while v is not None and v.scope_level == self.scope_level:
                    if v.captured:
                        has_captured_local_vars = True
                    # Must be "not internal": the old `== '('` test ignored
                    # every ordinary local, so captured locals in do/while/
                    # if scopes were never closed.
                    if v.slot < min_slot_of_local_vars and not v.name.startswith('('):
                        min_slot_of_local_vars = v.slot
                    v = v.prev

        if has_captured_local_vars:
            return min_slot_of_local_vars + 1

        return 0

    # ------------------------------------------------------------------
    # instruction buffer
    # ------------------------------------------------------------------

    def pc(self):
        """Return the index of the last emitted instruction."""
        return len(self.insts) - 1

    def fix_sbx(self, pc, sbx):
        """Rewrite the sBx field of the instruction at *pc*."""
        i = self.insts[pc]
        i = ((i << 18) & 0xffffffff) >> 18  # keep only opcode and A (low 14 bits)
        i = (i | (sbx + Instruction.MAXARG_sBx) << 14) & 0xffffffff
        self.insts[pc] = i

    def fix_end_pc(self, name, delta):
        """Add *delta* to ``end_pc`` of the latest local named *name*."""
        for local_var in reversed(self.local_vars):
            if local_var.name == name:
                local_var.end_pc += delta
                return

    # Raw encoders.  Each prints a trace line and appends one 32-bit word:
    # opcode in bits 0-5, A in bits 6-13, then C/B, Bx, sBx or Ax above.

    def emit_abc(self, opcode, a, b, c):
        print("%5s %8d %8d %8d %8d" % ('ABC', opcode, a, b, c))
        i = ((b << 23) | (c << 14) | (a << 6) | opcode) & 0xffffffff
        self.insts.append(i)

    def emit_a_bx(self, opcode, a, bx):
        print("%5s %8d %8d %8d" % ('ABx', opcode, a, bx))
        i = ((bx << 14) | (a << 6) | opcode) & 0xffffffff
        self.insts.append(i)

    def emit_as_bx(self, opcode, a, sbx):
        print("%5s %8d %8d %8d" % ('AsBx', opcode, a, sbx))
        i = (((sbx + Instruction.MAXARG_sBx) << 14) | (a << 6) | opcode) & 0xffffffff
        self.insts.append(i)

    def emit_ax(self, opcode, ax):
        print("%5s %8d %8d" % ('AX', opcode, ax))
        i = ((ax << 6) | opcode) & 0xffffffff
        self.insts.append(i)

    # Per-opcode helpers.  Counts passed to CALL/RETURN/VARARG/TAILCALL are
    # stored biased by one, so a count of -1 is encoded as 0.

    def emit_move(self, a, b):
        self.emit_abc(OpCode.MOVE, a, b, 0)

    def emit_load_nil(self, a, n):
        self.emit_abc(OpCode.LOADNIL, a, n-1, 0)

    def emit_load_bool(self, a, b, c):
        self.emit_abc(OpCode.LOADBOOL, a, b, c)

    def emit_load_k(self, a, k):
        """Load constant *k* into register *a* (LOADKX when Bx overflows)."""
        idx = self.index_of_constant(k)
        if idx < (1 << 18):
            self.emit_a_bx(OpCode.LOADK, a, idx)
        else:
            self.emit_a_bx(OpCode.LOADKX, a, 0)
            self.emit_ax(OpCode.EXTRAARG, idx)

    def emit_vararg(self, a, n):
        self.emit_abc(OpCode.VARARG, a, n+1, 0)

    def emit_closure(self, a, bx):
        self.emit_a_bx(OpCode.CLOSURE, a, bx)

    def emit_new_table(self, a, narr, nrec):
        self.emit_abc(OpCode.NEWTABLE, a, LuaValue.int2fb(narr), LuaValue.int2fb(nrec))

    def emit_set_list(self, a, b, c):
        self.emit_abc(OpCode.SETLIST, a, b, c)

    def emit_get_table(self, a, b, c):
        self.emit_abc(OpCode.GETTABLE, a, b, c)

    def emit_set_table(self, a, b, c):
        self.emit_abc(OpCode.SETTABLE, a, b, c)

    def emit_get_upval(self, a, b):
        self.emit_abc(OpCode.GETUPVAL, a, b, 0)

    def emit_set_upval(self, a, b):
        self.emit_abc(OpCode.SETUPVAL, a, b, 0)

    def emit_get_tabup(self, a, b, c):
        self.emit_abc(OpCode.GETTABUP, a, b, c)

    def emit_set_tabup(self, a, b, c):
        self.emit_abc(OpCode.SETTABUP, a, b, c)

    def emit_call(self, a, nargs, nret):
        self.emit_abc(OpCode.CALL, a, nargs+1, nret+1)

    def emit_tail_call(self, a, nargs):
        self.emit_abc(OpCode.TAILCALL, a, nargs+1, 0)

    def emit_return(self, a, n):
        self.emit_abc(OpCode.RETURN, a, n+1, 0)

    def emit_self(self, a, b, c):
        self.emit_abc(OpCode.SELF, a, b, c)

    def emit_jmp(self, a, sbx):
        """Emit JMP and return its pc (for later fix_sbx patching)."""
        self.emit_as_bx(OpCode.JMP, a, sbx)
        return len(self.insts) - 1

    def emit_test(self, a, c):
        self.emit_abc(OpCode.TEST, a, 0, c)

    def emit_test_set(self, a, b, c):
        self.emit_abc(OpCode.TESTSET, a, b, c)

    def emit_for_prep(self, a, sbx):
        self.emit_as_bx(OpCode.FORPREP, a, sbx)
        return len(self.insts) - 1

    def emit_for_loop(self, a, sbx):
        self.emit_as_bx(OpCode.FORLOOP, a, sbx)
        return len(self.insts) - 1

    def emit_tfor_call(self, a, c):
        self.emit_abc(OpCode.TFORCALL, a, 0, c)

    def emit_tfor_loop(self, a, sbx):
        self.emit_as_bx(OpCode.TFORLOOP, a, sbx)

    def emit_unary_op(self, op, a, b):
        """Emit the unary instruction for token *op*: R(a) := op R(b)."""
        if op == TokenKind.OP_NOT:
            self.emit_abc(OpCode.NOT, a, b, 0)
        elif op == TokenKind.OP_BNOT:
            self.emit_abc(OpCode.BNOT, a, b, 0)
        elif op == TokenKind.OP_LEN:
            self.emit_abc(OpCode.LEN, a, b, 0)
        elif op == TokenKind.OP_UNM:
            self.emit_abc(OpCode.UNM, a, b, 0)

    def emit_binary_op(self, op, a, b, c):
        """Emit code computing R(a) := R(b) op R(c).

        Arithmetic/bitwise operators map to one instruction.  Comparisons
        emit a compare, a JMP over one instruction, and two LOADBOOLs that
        materialise the boolean result in R(a).  GT/GE swap their operands
        and reuse LT/LE.
        """
        if op in FuncInfo.arith_and_bitwise_binops:
            self.emit_abc(FuncInfo.arith_and_bitwise_binops[op], a, b, c)
        else:
            if op == TokenKind.OP_EQ:
                self.emit_abc(OpCode.EQ, 1, b, c)
            elif op == TokenKind.OP_NE:
                self.emit_abc(OpCode.EQ, 0, b, c)
            elif op == TokenKind.OP_LT:
                self.emit_abc(OpCode.LT, 1, b, c)
            elif op == TokenKind.OP_GT:
                self.emit_abc(OpCode.LT, 1, c, b)
            elif op == TokenKind.OP_LE:
                self.emit_abc(OpCode.LE, 1, b, c)
            elif op == TokenKind.OP_GE:
                self.emit_abc(OpCode.LE, 1, c, b)

            self.emit_jmp(0, 1)
            self.emit_load_bool(a, 0, 1)
            self.emit_load_bool(a, 1, 0)

    # ------------------------------------------------------------------
    # prototype construction
    # ------------------------------------------------------------------

    def to_proto(self):
        """Build a Prototype for this function and, recursively, its children."""
        proto = Prototype()
        proto.num_params = self.num_params
        proto.max_stack_size = self.max_regs
        proto.code = self.insts
        proto.constants = self.get_constants()
        proto.upvalues = self.get_upvalues()
        proto.protos = []
        proto.line_infos = []
        proto.local_vars = self.local_vars
        proto.upvalue_names = self.get_upvalue_names()

        for fi in self.sub_funcs:
            proto.protos.append(fi.to_proto())

        if proto.get_max_stack_size() < 2:
            proto.max_stack_size = 2
        if self.is_vararg:
            proto.is_vararg = 1
        return proto

    def get_upvalues(self):
        """Return Upvalue descriptors ordered by upvalue index."""
        upvals = [None] * len(self.upvalues)
        for upval_info in self.upvalues.values():
            if upval_info.local_var_slot >= 0:
                upval = Upvalue(True, upval_info.local_var_slot)
            else:
                upval = Upvalue(False, upval_info.upval_index)
            upvals[upval_info.index] = upval
        return upvals

    def get_upvalue_names(self):
        """Return upvalue names ordered by upvalue index."""
        names = [''] * len(self.upvalues)
        for name, upval_info in self.upvalues.items():
            names[upval_info.index] = name
        return names

    def get_constants(self):
        """Return the constant table as a dict mapping index -> value."""
        return {idx: value for (_, value), idx in self.constants.items()}

2. codegen_exp.py

from lua_exp import *
from func_info import FuncInfo
from lua_token import TokenKind
from lua_opcode import OpCode


class ArgAndKind:
    """Pair an instruction operand with a tag describing its kind.

    NOTE(review): not referenced anywhere in the visible code.
    """

    def __init__(self, arg, kind):
        self.arg, self.kind = arg, kind


class CodegenExp:
    """Bytecode generation for Lua expressions.

    Every method takes the FuncInfo ``fi`` of the function being compiled
    and a target register ``a`` that receives the expression's value.
    """

    @staticmethod
    def process_exp(fi, exp, a, n):
        """Emit code evaluating *exp* into register *a*.

        *n* is the number of results wanted.  -1 is passed for a trailing
        vararg/call expression whose results are all kept.
        """
        if isinstance(exp, NilExp):
            fi.emit_load_nil(a, n)
        elif isinstance(exp, FalseExp):
            fi.emit_load_bool(a, 0, 0)
        elif isinstance(exp, TrueExp):
            fi.emit_load_bool(a, 1, 0)
        elif isinstance(exp, IntegerExp):
            fi.emit_load_k(a, exp.val)
        elif isinstance(exp, FloatExp):
            fi.emit_load_k(a, exp.val)
        elif isinstance(exp, StringExp):
            fi.emit_load_k(a, exp.s)
        elif isinstance(exp, ParensExp):
            # Parentheses truncate to exactly one value.
            CodegenExp.process_exp(fi, exp.exp, a, 1)
        elif isinstance(exp, VarArgExp):
            CodegenExp.process_vararg_exp(fi, a, n)
        elif isinstance(exp, FuncDefExp):
            CodegenExp.process_func_def_exp(fi, exp, a)
        elif isinstance(exp, TableConstructorExp):
            CodegenExp.process_table_constructor_exp(fi, exp, a)
        elif isinstance(exp, UnopExp):
            CodegenExp.process_unop_exp(fi, exp, a)
        elif isinstance(exp, BinopExp):
            CodegenExp.process_binop_exp(fi, exp, a)
        elif isinstance(exp, ConcatExp):
            CodegenExp.process_concat_exp(fi, exp, a)
        elif isinstance(exp, NameExp):
            CodegenExp.process_name_exp(fi, exp, a)
        elif isinstance(exp, TableAccessExp):
            CodegenExp.process_table_access_exp(fi, exp, a)
        elif isinstance(exp, FuncCallExp):
            CodegenExp.process_func_call_exp(fi, exp, a, n)

    @staticmethod
    def process_vararg_exp(fi, a, n):
        """Emit VARARG; only legal inside a vararg function."""
        if not fi.is_vararg:
            raise Exception('cannot use "..." outside a vararg function')
        fi.emit_vararg(a, n)

    @staticmethod
    def process_func_def_exp(fi, exp, a):
        """Compile a nested function and emit CLOSURE into register *a*."""
        # Lazy import: codegen_block imports this module.
        from codegen_block import CodegenBlock
        sub_fi = FuncInfo(fi, exp)
        fi.sub_funcs.append(sub_fi)

        if exp.par_list is not None:
            for param in exp.par_list:
                sub_fi.add_local_var(param)

        CodegenBlock.gen_block(sub_fi, exp.block)
        sub_fi.exit_scope()
        sub_fi.emit_return(0, 0)  # implicit "return" at the end of the body

        bx = len(fi.sub_funcs) - 1
        fi.emit_closure(a, bx)

    @staticmethod
    def process_table_constructor_exp(fi, exp, a):
        """Build a table in register *a*.

        Positional items are stored with SETLIST in batches of 50
        (presumably Lua's LFIELDS_PER_FLUSH); keyed items use SETTABLE.
        """
        narr = sum(1 for key_exp in exp.key_exps if key_exp is None)
        nexps = len(exp.key_exps)
        mult_ret = nexps > 0 and ExpHelper.is_vararg_or_func_call(exp.val_exps[-1])
        fi.emit_new_table(a, narr, nexps-narr)

        arr_idx = 0
        for i, (key_exp, val_exp) in enumerate(zip(exp.key_exps, exp.val_exps)):
            if key_exp is None:
                arr_idx += 1
                tmp = fi.alloc_reg()
                if i == nexps - 1 and mult_ret:
                    CodegenExp.process_exp(fi, val_exp, tmp, -1)
                else:
                    CodegenExp.process_exp(fi, val_exp, tmp, 1)

                if arr_idx % 50 == 0 or arr_idx == narr:
                    n = arr_idx % 50
                    if n == 0:
                        n = 50

                    fi.free_regs(n)
                    c = (arr_idx - 1) // 50 + 1  # 1-based batch number
                    if i == nexps - 1 and mult_ret:
                        fi.emit_set_list(a, 0, c)
                    else:
                        fi.emit_set_list(a, n, c)
                continue

            b = fi.alloc_reg()
            CodegenExp.process_exp(fi, key_exp, b, 1)
            c = fi.alloc_reg()
            CodegenExp.process_exp(fi, val_exp, c, 1)
            fi.free_regs(2)

            fi.emit_set_table(a, b, c)

    @staticmethod
    def process_unop_exp(fi, exp, a):
        """Evaluate the operand into a temp, then apply the unary op."""
        b = fi.alloc_reg()
        CodegenExp.process_exp(fi, exp.exp, b, 1)
        fi.emit_unary_op(exp.op, a, b)
        fi.free_reg()

    @staticmethod
    def process_binop_exp(fi, exp, a):
        """Emit a binary operation; ``and``/``or`` short-circuit."""
        if exp.op == TokenKind.OP_AND or exp.op == TokenKind.OP_OR:
            b = fi.alloc_reg()
            CodegenExp.process_exp(fi, exp.exp1, b, 1)
            fi.free_reg()
            if exp.op == TokenKind.OP_AND:
                fi.emit_test_set(a, b, 0)
            else:
                fi.emit_test_set(a, b, 1)
            # Jump over the second operand; target patched below.
            pc_of_jmp = fi.emit_jmp(0, 0)

            b = fi.alloc_reg()
            CodegenExp.process_exp(fi, exp.exp2, b, 1)
            fi.free_reg()
            fi.emit_move(a, b)
            fi.fix_sbx(pc_of_jmp, fi.pc()-pc_of_jmp)
        else:
            b = fi.alloc_reg()
            CodegenExp.process_exp(fi, exp.exp1, b, 1)
            c = fi.alloc_reg()
            CodegenExp.process_exp(fi, exp.exp2, c, 1)
            fi.emit_binary_op(exp.op, a, b, c)
            fi.free_regs(2)

    @staticmethod
    def process_concat_exp(fi, exp, a):
        """Evaluate operands into consecutive registers and emit CONCAT."""
        for sub_exp in exp.exps:
            a1 = fi.alloc_reg()
            CodegenExp.process_exp(fi, sub_exp, a1, 1)

        c = fi.used_regs - 1
        b = c - len(exp.exps) + 1
        fi.free_regs(c - b + 1)
        fi.emit_abc(OpCode.CONCAT, a, b, c)

    @staticmethod
    def process_name_exp(fi, exp, a):
        """Read a name: local (MOVE), upvalue (GETUPVAL) or global (_ENV[name])."""
        r = fi.slot_of_local_var(exp.name)
        if r >= 0:
            fi.emit_move(a, r)
            return

        idx = fi.index_of_upval(exp.name)
        if idx >= 0:
            fi.emit_get_upval(a, idx)
            return

        prefix_exp = NameExp(exp.line, '_ENV')
        key_exp = StringExp(exp.line, exp.name)
        table_access_exp = TableAccessExp(exp.line, prefix_exp, key_exp)
        CodegenExp.process_table_access_exp(fi, table_access_exp, a)

    @staticmethod
    def process_table_access_exp(fi, exp, a):
        """Emit R(a) := prefix[key] via two temps and GETTABLE."""
        b = fi.alloc_reg()
        CodegenExp.process_exp(fi, exp.prefix_exp, b, 1)
        c = fi.alloc_reg()
        CodegenExp.process_exp(fi, exp.key_exp, c, 1)
        fi.emit_get_table(a, b, c)
        fi.free_regs(2)

    @staticmethod
    def process_func_call_exp(fi, exp, a, n):
        """Emit a call whose function sits in *a*, keeping *n* results."""
        nargs = CodegenExp.process_prep_func_call(fi, exp, a)
        fi.emit_call(a, nargs, n)

    @staticmethod
    def process_tail_call_exp(fi, exp, a):
        """Emit a tail call whose function sits in *a*."""
        nargs = CodegenExp.process_prep_func_call(fi, exp, a)
        fi.emit_tail_call(a, nargs)

    @staticmethod
    def process_prep_func_call(fi, exp, a):
        """Load the callee (and receiver for ``obj:m()``) and the arguments.

        Expects *a* to be the most recently allocated register, so that
        newly allocated registers follow it.  Returns the argument count to
        encode (including the implicit ``self``), or -1 when the last
        argument is a vararg/call expression whose results are all passed.
        """
        nargs = len(exp.args)
        last_arg_is_vararg_or_func_call = False

        CodegenExp.process_exp(fi, exp.prefix_exp, a, 1)
        if exp.name_exp is not None:
            # SELF writes the receiver into R(a+1).  Reserve that register so
            # the first argument is not allocated on top of it and max_regs
            # accounts for it.
            fi.alloc_reg()
            idx = fi.index_of_constant(exp.name_exp.s)
            if idx <= 0xff:
                fi.emit_self(a, a, 0x100 + idx)  # RK constant operand
            else:
                # The constant index does not fit an RK operand: load it.
                tmp = fi.alloc_reg()
                fi.emit_load_k(tmp, exp.name_exp.s)
                fi.emit_self(a, a, tmp)
                fi.free_reg()

        for i, arg in enumerate(exp.args):
            tmp = fi.alloc_reg()
            if i == nargs - 1 and ExpHelper.is_vararg_or_func_call(arg):
                last_arg_is_vararg_or_func_call = True
                CodegenExp.process_exp(fi, arg, tmp, -1)
            else:
                CodegenExp.process_exp(fi, arg, tmp, 1)
        fi.free_regs(nargs)

        if exp.name_exp is not None:
            fi.free_reg()  # release the receiver register reserved above
            nargs += 1
        if last_arg_is_vararg_or_func_call:
            nargs = -1

        return nargs

3. codegen_stat.py

from lua_stat import *
from lua_exp import *
from codegen_exp import CodegenExp


class CodegenStat:
    """Bytecode generation for Lua statements.

    Every method takes the FuncInfo ``fi`` of the function being compiled
    and emits instructions into it.
    """

    @staticmethod
    def _gen_block(fi, block):
        """Compile a nested block.

        Imported lazily because codegen_block imports this module.
        """
        from codegen_block import CodegenBlock
        CodegenBlock.gen_block(fi, block)

    @staticmethod
    def process(fi, stat):
        """Dispatch *stat* to the matching process_* method.

        Statement kinds not listed here emit nothing; label/goto raise.
        """
        if isinstance(stat, FuncCallStat):
            CodegenStat.process_func_call_stat(fi, stat)
        elif isinstance(stat, BreakStat):
            CodegenStat.process_break_stat(fi)
        elif isinstance(stat, DoStat):
            CodegenStat.process_do_stat(fi, stat)
        elif isinstance(stat, WhileStat):
            CodegenStat.process_while_stat(fi, stat)
        elif isinstance(stat, RepeatStat):
            CodegenStat.process_repeat_stat(fi, stat)
        elif isinstance(stat, IfStat):
            CodegenStat.process_if_stat(fi, stat)
        elif isinstance(stat, ForNumStat):
            CodegenStat.process_for_num_stat(fi, stat)
        elif isinstance(stat, ForInStat):
            CodegenStat.process_for_in_stat(fi, stat)
        elif isinstance(stat, AssignStat):
            CodegenStat.process_assign_stat(fi, stat)
        elif isinstance(stat, LocalVarDeclStat):
            CodegenStat.process_local_var_decl_stat(fi, stat)
        elif isinstance(stat, LocalFuncDefStat):
            CodegenStat.process_local_func_def_stat(fi, stat)
        elif isinstance(stat, (LabelStat, GotoStat)):
            raise Exception('label and goto are not supported!')

    @staticmethod
    def process_local_func_def_stat(fi, stat):
        """``local function f``: declare f, then build the closure in its register.

        Declaring first makes f visible to its own body (as an upvalue).
        """
        # Declared with just the name, like every other add_local_var call
        # in this module (the extra pc argument used to raise TypeError).
        r = fi.add_local_var(stat.name)
        CodegenExp.process_func_def_exp(fi, stat.exp, r)

    @staticmethod
    def process_func_call_stat(fi, stat):
        """Call evaluated for side effects only (0 results kept)."""
        r = fi.alloc_reg()
        CodegenExp.process_func_call_exp(fi, stat.exp, r, 0)
        fi.free_reg()

    @staticmethod
    def process_break_stat(fi):
        """Emit a placeholder JMP, patched when the loop scope exits."""
        pc = fi.emit_jmp(0, 0)
        fi.add_break_jmp(pc)

    @staticmethod
    def process_do_stat(fi, stat):
        """``do ... end``: a non-breakable scope around the block."""
        fi.enter_scope(False)
        CodegenStat._gen_block(fi, stat.block)
        fi.close_open_upvals()
        # exit_scope takes no pc here, consistent with every other caller
        # (passing one used to raise TypeError).
        fi.exit_scope()

    @staticmethod
    def process_while_stat(fi, stat):
        """``while exp do block end``."""
        pc_before_exp = fi.pc()

        r = fi.alloc_reg()
        CodegenExp.process_exp(fi, stat.exp, r, 1)
        fi.free_reg()

        fi.emit_test(r, 0)
        pc_jmp_to_end = fi.emit_jmp(0, 0)

        fi.enter_scope(True)
        CodegenStat._gen_block(fi, stat.block)
        fi.close_open_upvals()
        fi.emit_jmp(0, pc_before_exp - fi.pc() - 1)  # back to the condition
        fi.exit_scope()

        fi.fix_sbx(pc_jmp_to_end, fi.pc() - pc_jmp_to_end)

    @staticmethod
    def process_repeat_stat(fi, stat):
        """``repeat block until exp``; exp sees the block's locals."""
        fi.enter_scope(True)

        pc_before_block = fi.pc()
        CodegenStat._gen_block(fi, stat.block)

        r = fi.alloc_reg()
        CodegenExp.process_exp(fi, stat.exp, r, 1)
        fi.free_reg()

        fi.emit_test(r, 0)
        fi.emit_jmp(fi.get_jmp_arg_a(), pc_before_block-fi.pc()-1)
        fi.close_open_upvals()

        fi.exit_scope()

    @staticmethod
    def process_if_stat(fi, stat):
        """``if``/``elseif``/``else`` chain: one test + block per expression."""
        pc_jmp_to_ends = []
        pc_jmp_to_next_exp = -1

        for i, exp in enumerate(stat.exps):
            if pc_jmp_to_next_exp >= 0:
                fi.fix_sbx(pc_jmp_to_next_exp, fi.pc() - pc_jmp_to_next_exp)

            r = fi.alloc_reg()
            CodegenExp.process_exp(fi, exp, r, 1)
            fi.free_reg()

            fi.emit_test(r, 0)
            pc_jmp_to_next_exp = fi.emit_jmp(0, 0)

            fi.enter_scope(False)
            CodegenStat._gen_block(fi, stat.blocks[i])
            fi.close_open_upvals()
            fi.exit_scope()

            if i < len(stat.exps)-1:
                pc_jmp_to_ends.append(fi.emit_jmp(0, 0))
            else:
                pc_jmp_to_ends.append(pc_jmp_to_next_exp)

        for pc in pc_jmp_to_ends:
            fi.fix_sbx(pc, fi.pc() - pc)

    @staticmethod
    def process_for_num_stat(fi, stat):
        """Numeric ``for``: three hidden locals + the loop variable, FORPREP/FORLOOP."""
        fi.enter_scope(True)
        local_var_stat = LocalVarDeclStat(0,
                                          ['(for index)', '(for limit)', '(for step)'],
                                          [stat.init_exp, stat.limit_exp, stat.step_exp])
        CodegenStat.process_local_var_decl_stat(fi, local_var_stat)
        fi.add_local_var(stat.var_name)

        a = fi.used_regs - 4  # register of '(for index)'
        pc_for_prep = fi.emit_for_prep(a, 0)
        CodegenStat._gen_block(fi, stat.block)
        fi.close_open_upvals()
        pc_for_loop = fi.emit_for_loop(a, 0)

        fi.fix_sbx(pc_for_prep, pc_for_loop-pc_for_prep-1)
        fi.fix_sbx(pc_for_loop, pc_for_prep-pc_for_loop)

        fi.exit_scope()

    @staticmethod
    def process_for_in_stat(fi, stat):
        """Generic ``for``: three hidden locals + names, TFORCALL/TFORLOOP."""
        fi.enter_scope(True)

        local_var = LocalVarDeclStat(0,
                                     ['(for generator)', '(for state)', '(for control)'],
                                     stat.exp_list)
        CodegenStat.process_local_var_decl_stat(fi, local_var)
        for name in stat.name_list:
            fi.add_local_var(name)

        pc_jmp_to_tfc = fi.emit_jmp(0, 0)
        CodegenStat._gen_block(fi, stat.block)
        fi.close_open_upvals()
        fi.fix_sbx(pc_jmp_to_tfc, fi.pc()-pc_jmp_to_tfc)

        r = fi.slot_of_local_var('(for generator)')
        fi.emit_tfor_call(r, len(stat.name_list))
        fi.emit_tfor_loop(r + 2, pc_jmp_to_tfc - fi.pc() - 1)

        fi.exit_scope()

    @staticmethod
    def process_local_var_decl_stat(fi, stat):
        """``local n1, n2 = e1, e2``: evaluate into fresh registers, then bind.

        Surplus expressions are evaluated and discarded; missing values come
        from a trailing multi-result expression or are nil-filled.
        """
        exps = ExpHelper.remove_tail_nils(stat.exp_list)
        nexps = len(exps)
        nnames = len(stat.name_list)

        old_regs = fi.used_regs
        if nexps == nnames:
            for exp in exps:
                a = fi.alloc_reg()
                CodegenExp.process_exp(fi, exp, a, 1)
        elif nexps > nnames:
            for i, exp in enumerate(exps):
                a = fi.alloc_reg()
                if i == nexps-1 and ExpHelper.is_vararg_or_func_call(exp):
                    CodegenExp.process_exp(fi, exp, a, 0)
                else:
                    CodegenExp.process_exp(fi, exp, a, 1)
        else:
            mult_ret = False
            for i, exp in enumerate(exps):
                a = fi.alloc_reg()
                if i == nexps-1 and ExpHelper.is_vararg_or_func_call(exp):
                    mult_ret = True
                    n = nnames - nexps + 1
                    CodegenExp.process_exp(fi, exp, a, n)
                    fi.alloc_regs(n-1)
                else:
                    CodegenExp.process_exp(fi, exp, a, 1)

            if not mult_ret:
                n = nnames - nexps
                a = fi.alloc_regs(n)
                fi.emit_load_nil(a, n)

        # The values already sit in the registers the new locals will get.
        fi.used_regs = old_regs
        for name in stat.name_list:
            fi.add_local_var(name)

    @staticmethod
    def process_assign_stat(fi, stat):
        """Multiple assignment ``v1, v2 = e1, e2``.

        Table/key operands are evaluated first, then all values into
        consecutive registers, and only then are the stores emitted.
        """
        exps = ExpHelper.remove_tail_nils(stat.exp_list)
        nexps = len(exps)
        nvars = len(stat.var_list)

        tregs = [0] * nvars
        kregs = [0] * nvars
        vregs = [0] * nvars
        old_regs = fi.used_regs

        for i, var in enumerate(stat.var_list):
            if isinstance(var, TableAccessExp):
                tregs[i] = fi.alloc_reg()
                CodegenExp.process_exp(fi, var.prefix_exp, tregs[i], 1)
                kregs[i] = fi.alloc_reg()
                CodegenExp.process_exp(fi, var.key_exp, kregs[i], 1)
            else:
                name = var.name
                if fi.slot_of_local_var(name) < 0 and fi.index_of_upval(name) < 0:
                    # Global: the key is normally an RK constant operand
                    # (0x100 + index).  Indices above 0xff do not fit, so
                    # the name must be loaded into a register instead.
                    kregs[i] = -1
                    if fi.index_of_constant(name) > 0xff:
                        kregs[i] = fi.alloc_reg()
                        fi.emit_load_k(kregs[i], name)

        for i in range(nvars):
            vregs[i] = fi.used_regs + i

        if nexps >= nvars:
            for i, exp in enumerate(exps):
                a = fi.alloc_reg()
                if i >= nvars and i == nexps-1 and ExpHelper.is_vararg_or_func_call(exp):
                    CodegenExp.process_exp(fi, exp, a, 0)
                else:
                    CodegenExp.process_exp(fi, exp, a, 1)
        else:
            mult_ret = False
            for i, exp in enumerate(exps):
                a = fi.alloc_reg()
                if i == nexps-1 and ExpHelper.is_vararg_or_func_call(exp):
                    mult_ret = True
                    n = nvars - nexps + 1
                    CodegenExp.process_exp(fi, exp, a, n)
                    fi.alloc_regs(n-1)
                else:
                    CodegenExp.process_exp(fi, exp, a, 1)

            if not mult_ret:
                n = nvars - nexps
                a = fi.alloc_regs(n)
                fi.emit_load_nil(a, n)

        for i, var in enumerate(stat.var_list):
            if not isinstance(var, NameExp):
                fi.emit_set_table(tregs[i], kregs[i], vregs[i])
                continue

            var_name = var.name
            a = fi.slot_of_local_var(var_name)
            if a >= 0:
                fi.emit_move(a, vregs[i])
                continue

            b = fi.index_of_upval(var_name)
            if b >= 0:
                fi.emit_set_upval(vregs[i], b)
                continue

            # Global through a local _ENV ...
            a = fi.slot_of_local_var('_ENV')
            if a >= 0:
                if kregs[i] < 0:
                    b = 0x100 + fi.index_of_constant(var_name)
                    fi.emit_set_table(a, b, vregs[i])
                else:
                    fi.emit_set_table(a, kregs[i], vregs[i])
                continue

            # ... or through the _ENV upvalue.
            a = fi.index_of_upval('_ENV')
            if kregs[i] < 0:
                b = 0x100 + fi.index_of_constant(var_name)
                fi.emit_set_tabup(a, b, vregs[i])
            else:
                fi.emit_set_tabup(a, kregs[i], vregs[i])

        fi.used_regs = old_regs
        fi.used_regs = old_regs

4. codegen_block.py

from codegen_stat import CodegenStat
from codegen_exp import CodegenExp
from lua_exp import *


class CodegenBlock:
    """Code generation for a block: its statements and optional return."""

    @staticmethod
    def gen_block(funcinfo, block):
        """Emit every statement of *block*, then its ``return`` if present."""
        for stat in block.stats:
            CodegenStat.process(funcinfo, stat)

        ret_exps = block.ret_exps
        if ret_exps is None:
            return
        CodegenBlock.process_ret_stat(funcinfo, ret_exps)

    @staticmethod
    def process_ret_stat(fi, exps):
        """Emit a ``return`` of *exps*.

        Special cases: no values; a single local (returned in place); a
        single call (compiled as a tail call).  Otherwise the values are
        evaluated into consecutive registers, the last one keeping all of
        its results when it is a vararg/call expression.
        """
        count = len(exps)
        if not count:
            fi.emit_return(0, 0)
            return

        if count == 1:
            only = exps[0]
            if isinstance(only, NameExp):
                slot = fi.slot_of_local_var(only.name)
                if slot >= 0:
                    fi.emit_return(slot, 1)
                    return
            if isinstance(only, FuncCallExp):
                reg = fi.alloc_reg()
                CodegenExp.process_tail_call_exp(fi, only, reg)
                fi.free_reg()
                fi.emit_return(reg, -1)
                return

        mult_ret = ExpHelper.is_vararg_or_func_call(exps[-1])
        last = count - 1
        for idx, exp in enumerate(exps):
            reg = fi.alloc_reg()
            wanted = -1 if (mult_ret and idx == last) else 1
            CodegenExp.process_exp(fi, exp, reg, wanted)
        fi.free_regs(count)

        # After freeing, used_regs is the first register that held a value.
        first = fi.used_regs
        fi.emit_return(first, -1 if mult_ret else count)

5. codegen.py

from lua_exp import FuncDefExp
from func_info import FuncInfo
from codegen_exp import CodegenExp


class Codegen:
    """Entry point that turns a parsed chunk into a function prototype."""

    @staticmethod
    def gen_proto(chunk):
        """Compile *chunk* (the top-level block) and return its Prototype.

        The chunk is wrapped in a parameterless vararg main function.  A
        synthetic outer FuncInfo declares ``_ENV`` as a local, so name
        lookups inside the main function resolve it as an upvalue.
        """
        main_fn = FuncDefExp(0, chunk.last_line, [], True, chunk)
        outer = FuncInfo(None, main_fn)
        outer.add_local_var('_ENV')

        CodegenExp.process_func_def_exp(outer, main_fn, 0)
        main_info = outer.sub_funcs[0]
        return main_info.to_proto()


6. Compiler driver (LuaCompiler)

from parser import Parser
from lexer import Lexer
from codegen import Codegen


class LuaCompiler:
    """Front end: source text -> AST -> Prototype."""

    @staticmethod
    def compile(chunk, chunk_name):
        """Compile *chunk* and return its main Prototype.

        Prints the AST and the generated code as a debugging aid, then
        stamps *chunk_name* as the source of every prototype.
        """
        parser = Parser()
        ast = parser.parse_block(Lexer(chunk, chunk_name))
        print(ast)
        proto = Codegen.gen_proto(ast)
        proto.print_code()
        LuaCompiler.set_source(proto, chunk_name)
        return proto

    @staticmethod
    def set_source(proto, chunk_name):
        """Set ``source`` on *proto* and all nested prototypes."""
        pending = [proto]
        while pending:
            current = pending.pop()
            current.source = chunk_name
            pending.extend(current.get_protos())

7. Test driver (runs a Lua script)

from lua_state import LuaState
from lua_type import LuaType
from thread_state import ThreadStatus
import sys


def py_print(ls):
    """Implement Lua's ``print`` on top of the LuaState API.

    Each argument (stack indices 1..top) is written to stdout: booleans as
    ``true``/``false``, values accepted by ``ls.is_string`` via
    ``ls.to_string``, and anything else as its type name. Values are
    separated by tabs and the line ends with a newline.

    Returns:
        0 -- nothing is returned to the Lua caller.
    """
    count = ls.get_top()
    for idx in range(1, count + 1):
        if ls.is_boolean(idx):
            text = 'true' if ls.to_boolean(idx) else 'false'
        elif ls.is_string(idx):
            text = ls.to_string(idx)
        else:
            text = ls.type_name(ls.type(idx))
        # A tab follows every value except the last one.
        print(text, end='\t' if idx < count else '')
    print()
    return 0


def get_metatable(ls):
    """Lua ``getmetatable(obj)``: return the metatable of argument 1, or nil.

    ``ls.get_metatable(1)`` is expected to push the metatable of the value at
    stack index 1 and return True, or push nothing and return False when the
    value has none (mirrors ``lua_getmetatable``; confirm against LuaState).
    In the latter case nil is pushed, so exactly one result is on top.

    The previous version tested the bound method ``ls.get_metatable``
    without calling it. That object is always truthy, so nothing was pushed
    and the argument itself was returned.
    """
    if not ls.get_metatable(1):
        ls.push_nil()
    return 1


def set_metatable(ls):
    """Lua ``setmetatable(t, mt)``.

    Calls ``ls.set_metatable(1)`` to attach a metatable to the value at
    stack index 1 (presumably popping ``mt`` from the top, as
    ``lua_setmetatable`` does -- confirm against LuaState), then returns one
    result taken from the top of the stack.
    """
    target = 1
    ls.set_metatable(target)
    result_count = 1
    return result_count


def lua_next(ls):
    """Lua ``next(t [, k])``: step the traversal of table *t*.

    The stack is first resized to exactly two slots (presumably padding a
    missing key with nil). When ``ls.next(1)`` reports another entry, the
    key/value pair it left on the stack is returned (2 results); otherwise
    a single nil is returned.
    """
    ls.set_top(2)
    if not ls.next(1):
        ls.push_nil()
        return 1
    return 2


def pairs(ls):
    """Lua ``pairs(t)``: return the generic-for triple ``next, t, nil``."""
    iterator = lua_next
    ls.push_py_function(iterator)
    ls.push_value(1)  # copy of the table argument
    ls.push_nil()     # initial control value
    result_count = 3
    return result_count


def ipairs(ls):
    """Lua ``ipairs(t)``: return the generic-for triple ``ipairs_aux, t, 0``."""
    iterator = ipairs_aux
    ls.push_py_function(iterator)
    ls.push_value(1)    # copy of the table argument
    ls.push_integer(0)  # start before index 1
    result_count = 3
    return result_count


def ipairs_aux(ls):
    """Iterator step used by ``ipairs``.

    Reads the previous index from argument 2, pushes ``i = prev + 1`` and
    then ``t[i]`` via ``ls.get_i(1, i)`` (presumably returning the pushed
    value's type). If that value is nil, only the top value (that nil) is
    returned, which ends a generic ``for``. Otherwise ``i, t[i]`` is
    returned.
    """
    next_index = ls.to_integer(2) + 1
    ls.push_integer(next_index)
    reached_end = ls.get_i(1, next_index) == LuaType.NIL
    return 1 if reached_end else 2


def error(ls):
    """Lua ``error``: delegate to ``ls.error()`` and return its result."""
    outcome = ls.error()
    return outcome


def pcall(ls):
    """Lua ``pcall(f, ...)``: call ``f`` in protected mode.

    Every stack slot after the function (index 1) is passed as an argument
    and all results are requested (``-1``). A boolean that is True when the
    returned status equals ``ThreadStatus.OK`` is then inserted at index 1,
    so the caller receives the status followed by whatever ``ls.pcall``
    left on the stack. Returns the resulting stack size as the result count.
    """
    arg_count = ls.get_top() - 1
    succeeded = ls.pcall(arg_count, -1, 0) == ThreadStatus.OK
    ls.push_boolean(succeeded)
    ls.insert(1)  # move the status flag below the call results
    return ls.get_top()


def main():
    """Run the Lua chunk named on the command line.

    Creates a LuaState, registers the Python implementations of the basic
    library functions under their Lua names (in a fixed order), loads
    ``sys.argv[1]`` and calls the loaded chunk with 0 arguments and 0
    results.
    """
    state = LuaState()
    for lua_name, py_func in (
        ('print', py_print),
        ('getmetatable', get_metatable),
        ('setmetatable', set_metatable),
        ('next', lua_next),
        ('pairs', pairs),
        ('ipairs', ipairs),
        ('error', error),
        ('pcall', pcall),
    ):
        state.register(lua_name, py_func)

    state.load(sys.argv[1])
    state.call(0, 0)


if __name__ == '__main__':
    # Expect exactly one argument: the Lua chunk to run.
    if len(sys.argv) != 2:
        # Report usage errors on stderr with a non-zero exit status so shells
        # and scripts can detect the failure (previously: stdout, status 0).
        print('Error argument', file=sys.stderr)
        sys.exit(1)
    main()

  1. Hello world
-- Smallest end-to-end test: a single call to the global `print`.
print('hello world')
"LastLine": 1,
"Stats": [{
  <lua_stat.FuncCallStat object at 0x7fbb7b42c1d0>
}]
"RetExps": 
  nil
  ABC        5        1        0        0
  ABx        1        2        0
  ABC        7        0        1        2
  ABx        1        1        1
  ABC       36        0        2        1
  ABC       38        0        1        0
  ABx       44        0        0
	1	[-]	GETUPVAL        1        0
	2	[-]	LOADK           2       -1
	3	[-]	GETTABLE        0        1        2
	4	[-]	LOADK           1       -2
	5	[-]	CALL            0        2        1
	6	[-]	RETURN          0        1
env: {'print': <closure.Closure object at 0x7fbb7b457c88>, 'getmetatable': <closure.Closure object at 0x7fbb7b47d4e0>, 'setmetatable': <closure.Closure object at 0x7fbb7b47d470>, 'next': <closure.Closure object at 0x7fbb7b47d6a0>, 'pairs': <closure.Closure object at 0x7fbb7b48eb38>, 'ipairs': <closure.Closure object at 0x7fbb7b48eb70>, 'error': <closure.Closure object at 0x7fbb7b48eba8>, 'pcall': <closure.Closure object at 0x7fbb7b48ebe0>}
(  0) [01] GETUPVAL     [nil][table][nil]
(  1) [02] LOADK        [nil][table]["print"]
(  2) [03] GETTABLE     [function][table]["print"]
(  3) [04] LOADK        [function]["hello world"]["print"]
hello world
(  4) [05] CALL         [function]["hello world"]["print"]
(  5) [06] RETURN       [function]["hello world"]["print"]
  2. A larger test
-- Regression script exercising the compiler and VM end to end.
print('hello world')


-- Global assignment and basic arithmetic.
a = 1
print(a)

a = a+1
print(a)

b = a*2
print(b)

-- Comparison, arithmetic and bitwise operators.
print(a < b)
print(b, b==a*2, b<=a*2, b/3, b//3, b%3, 2^53, 2~=2, 2~=3)
print(b<<2, b>>2, b&4, b&3)

-- Numeric literals and the short-circuit operators `and` / `or`.
print(4.1e-3, 5E+20, 1==1.0, 0xff)
print(2 and 3)
print(0 and 3)
print(nil and 7)
print(false and 7)
print(1 or 2, 3 or 4, false or 5, true or 6, nil or 7, 8 or nil)

print(3 or 4)
a = 3
b = 4
print(a or b)

print(not 2)

-- Strings: length operator and concatenation.
a = 'abc'
print(a)
print(#a)
b = 'abc' .. 'def'
print(b)
c = 'xxx'
d = b .. c
print(d)

-- Tables indexed by string and integer keys.
a = {}
k = "x"
a[k] = 10
a[20] = "great"
print(a["x"])
k = 20
print(a[k])
a["x"] = a["x"]+1
print(a["x"])

-- Two variables referring to the same table see each other's writes.
a = {}
a["x"] = 10
b = a
print(b["x"])
b["x"]=20
print(a["x"])
a = nil
b = nil

-- List constructor: elements start at index 1, so a[0] is nil.
a = {1, 2, 3, 4}
print(a[0], a[1], a[2], a[3], a[4])

-- Field syntax a.x is sugar for a["x"]; a missing field reads as nil.
a = {}
a.x = 10
print(a.x)
print(a.y)

days = {"Sunday", "Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday"}
print(days[4])


-- Global function definition and calls.
function add(a, b)
    return a+b
end

c = add(2, 3)
print(c)

-- Multiple assignment.
a, b = 12, 34
print(add(a, b))



-- if / else.
a, b = 1, 2
if a > 1 then
    c = a
else
    c = b
end

print(c)

-- if / elseif chain with early returns.
function test(a)
    if a >= 90 then
        return 'A'
    elseif a >= 80 then
        return 'B'
    elseif a >= 70 then
        return 'C'
    else
        return 'D'
    end
end

print(test(100), test(90), test(80), test(70), test(60), test(30))


-- while loop: sum of 1..100.
i = 1
sum = 0
while i <= 100 do
    sum = sum + i
    i = i + 1
end

print(sum)



-- repeat ... until loop: same sum as above.
i = 1
sum = 0

repeat
    sum = sum + i
    i = i + 1
until i > 100

print(sum)


-- Numeric for with an explicit step, left early via break.
sum = 0
for i = 0, 100, 1 do
    sum = sum + i
    if i == 10 then break end
end
print(sum)

-- Vararg function iterating its arguments with ipairs (redefines add).
function add(...)
    local s = 0
    for _, v in ipairs{...} do
        s = s + v
    end
    return s
end

print(add(3, 4, 10, 25, 12))



-- Numeric for with the default step, writing into an existing table.
a = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}
for i = 1, 10 do a[i] = i*2 end
print(a[5])


-- A function statement inside a function body (no `local`) defines a global.
function test()
    function inner(a)
        return a+1
    end
    print(inner(3))
end
test()

Result (output of the larger test):

hello world
1
2
4
true
4	true	true	1.3333333333333333	1	1	9007199254740992.0	false	true
16	1	4	0
0.0041	5e+20	true	255
3
3
nil
false
1	3	5	true	7	8
3
3
false
abc
3
abcdef
abcdefxxx
10
great
11
10
20
nil	1	2	3	4
10
nil
Wednesday
5
46
2
A	A	B	C	D	D
5050
5050
55
54
10
4
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值