我们现在实现的Kaleidoscope还不够完善,缺少if else控制流,比如不支持如下代码
def fib(x)
if x < 3 then
1
else
fib(x - 1) + fib(x - 2)
首先让我们的Lexer能识别 if then else 三个关键字,增加TOKEN类型
TOKEN_IF = -6, // keyword "if"
TOKEN_THEN = -7, // keyword "then"
TOKEN_ELSE = -8, // keyword "else"
增加识别规则
// Identifiers and keywords: [a-zA-Z][a-zA-Z0-9]*
if (isalpha(last_char)) {
  g_identifier_str = last_char;
  // Keep appending alphanumerics; last_char ends on the first character that
  // is not part of the word, so the next call starts from it.
  for (last_char = getchar(); isalnum(last_char); last_char = getchar()) {
    g_identifier_str += last_char;
  }
  // Keywords get dedicated tokens; any other word is a plain identifier.
  if (g_identifier_str == "def") return TOKEN_DEF;
  if (g_identifier_str == "extern") return TOKEN_EXTERN;
  if (g_identifier_str == "if") return TOKEN_IF;
  if (g_identifier_str == "then") return TOKEN_THEN;
  if (g_identifier_str == "else") return TOKEN_ELSE;
  return TOKEN_IDENTIFIER;
}
增加IfExprAST
// AST node for `if <cond> then <then_expr> else <else_expr>`.
// Both branches are mandatory; the node owns all three sub-expressions.
class IfExprAST : public ExprAST {
 private:
  std::unique_ptr<ExprAST> cond_;       // branch condition
  std::unique_ptr<ExprAST> then_expr_;  // value when the condition holds
  std::unique_ptr<ExprAST> else_expr_;  // value otherwise

 public:
  IfExprAST(std::unique_ptr<ExprAST> condition,
            std::unique_ptr<ExprAST> then_branch,
            std::unique_ptr<ExprAST> else_branch)
      : cond_(std::move(condition)),
        then_expr_(std::move(then_branch)),
        else_expr_(std::move(else_branch)) {}

  llvm::Value* CodeGen() override;
};
增加对IfExprAST的解析
std::unique_ptr<ExprAST> ParseIfExpr() {
GetNextToken(); // eat if
std::unique_ptr<ExprAST> cond = ParseExpression();
GetNextToken(); // eat then
std::unique_ptr<ExprAST> then_expr = ParseExpression();
GetNextToken(); // eat else
std::unique_ptr<ExprAST> else_expr = ParseExpression();
return std::make_unique<IfExprAST>(std::move(cond), std::move(then_expr),
std::move(else_expr));
}
增加到ParsePrimary中
// primary
//   ::= identifierexpr
//   ::= numberexpr
//   ::= parenexpr
//   ::= ifexpr
// Dispatches on the current token. Returns nullptr for a token that cannot
// start a primary expression.
std::unique_ptr<ExprAST> ParsePrimary() {
  const auto token = g_current_token;
  if (token == TOKEN_IDENTIFIER) {
    return ParseIdentifierExpr();
  }
  if (token == TOKEN_NUMBER) {
    return ParseNumberExpr();
  }
  if (token == '(') {
    return ParseParenExpr();
  }
  if (token == TOKEN_IF) {
    return ParseIfExpr();
  }
  return nullptr;
}
完成了lex和parse,接下来是最有意思的codegen
// Generates IR for `if cond then a else b`.
//
// The current block ends with a conditional branch on (cond != 0.0) to the
// "then" or "else" block. Both arms branch to "ifcont", where a phi node
// selects the value of whichever arm actually ran. Returns nullptr if any
// sub-expression fails to generate code.
llvm::Value* IfExprAST::CodeGen() {
  llvm::Value* cond_value = cond_->CodeGen();
  if (!cond_value) {
    return nullptr;
  }
  // Kaleidoscope values are doubles. `fcmp one` (ordered, not equal) turns
  // the condition into an i1: cond_value != 0.0.
  cond_value = g_ir_builder.CreateFCmpONE(
      cond_value, llvm::ConstantFP::get(g_llvm_context, llvm::APFloat(0.0)),
      "ifcond");
  // The builder is positioned in a block of the function being generated; its
  // parent is the function that receives the new blocks.
  llvm::Function* func = g_ir_builder.GetInsertBlock()->getParent();
  // "then" is attached right away. "else" and "ifcont" are attached later so
  // they come after any blocks created while generating the then arm
  // (e.g. a nested if).
  llvm::BasicBlock* then_block =
      llvm::BasicBlock::Create(g_llvm_context, "then", func);
  llvm::BasicBlock* else_block =
      llvm::BasicBlock::Create(g_llvm_context, "else");
  llvm::BasicBlock* final_block =
      llvm::BasicBlock::Create(g_llvm_context, "ifcont");
  g_ir_builder.CreateCondBr(cond_value, then_block, else_block);

  // then arm
  g_ir_builder.SetInsertPoint(then_block);
  llvm::Value* then_value = then_expr_->CodeGen();
  if (!then_value) {
    // Attach the pending blocks so they are owned by the function rather than
    // left detached (else_block is already referenced by the branch above).
    else_block->insertInto(func);
    final_block->insertInto(func);
    return nullptr;
  }
  g_ir_builder.CreateBr(final_block);
  // Generating the arm may have moved the builder into another block (e.g. a
  // nested if's "ifcont"). That block is the real predecessor of final_block.
  then_block = g_ir_builder.GetInsertBlock();

  // else arm. insertInto() appends to the function and, unlike
  // getBasicBlockList().push_back(), is still public in newer LLVM releases.
  else_block->insertInto(func);
  g_ir_builder.SetInsertPoint(else_block);
  llvm::Value* else_value = else_expr_->CodeGen();
  if (!else_value) {
    final_block->insertInto(func);
    return nullptr;
  }
  g_ir_builder.CreateBr(final_block);
  else_block = g_ir_builder.GetInsertBlock();

  // Merge point: pick the value according to the incoming edge.
  final_block->insertInto(func);
  g_ir_builder.SetInsertPoint(final_block);
  llvm::PHINode* pn = g_ir_builder.CreatePHI(
      llvm::Type::getDoubleTy(g_llvm_context), 2, "iftmp");
  pn->addIncoming(then_value, then_block);
  pn->addIncoming(else_value, else_block);
  return pn;
}
这里使用了上一节SSA中提到的phi function(phi 节点):在 ifcont 块中,phi 根据控制流是从 then 块还是 else 块跳转过来,选择对应分支计算出的值作为整个 if 表达式的结果。
输入
def foo(x)
if x < 3 then
1
else
foo(x - 1) + foo(x - 2)
foo(1)
foo(2)
foo(3)
foo(4)
得到输出
parsed a function definition
define double @foo(double %x) {
entry:
%cmptmp = fcmp ult double %x, 3.000000e+00
%booltmp = uitofp i1 %cmptmp to double
%ifcond = fcmp one double %booltmp, 0.000000e+00
br i1 %ifcond, label %then, label %else
then: ; preds = %entry
br label %ifcont
else: ; preds = %entry
%subtmp = fsub double %x, 1.000000e+00
%calltmp = call double @foo(double %subtmp)
%subtmp1 = fsub double %x, 2.000000e+00
%calltmp2 = call double @foo(double %subtmp1)
%addtmp = fadd double %calltmp, %calltmp2
br label %ifcont
ifcont: ; preds = %else, %then
%iftmp = phi double [ 1.000000e+00, %then ], [ %addtmp, %else ]
ret double %iftmp
}
parsed a top level expr
define double @__anon_expr() {
entry:
%calltmp = call double @foo(double 1.000000e+00)
ret double %calltmp
}
1
parsed a top level expr
define double @__anon_expr() {
entry:
%calltmp = call double @foo(double 2.000000e+00)
ret double %calltmp
}
1
parsed a top level expr
define double @__anon_expr() {
entry:
%calltmp = call double @foo(double 3.000000e+00)
ret double %calltmp
}
2
parsed a top level expr
define double @__anon_expr() {
entry:
%calltmp = call double @foo(double 4.000000e+00)
ret double %calltmp
}
3
成功完成了斐波那契数列的计算。
接下来我们需要增加循环的支持,在此之前我们实现一个printd函数
// Runtime helper callable from Kaleidoscope code after `extern printd(x)`.
// It prints x followed by a newline and always returns 0.0. extern "C" keeps
// the symbol name unmangled, so the declaration `double @printd(double)` in
// the generated IR can refer to it by plain name.
extern "C" double printd(double x) {
  printf("%lf\n", x);  // was "%lfn": the backslash of the newline escape was lost
  return 0.0;
}
编译(`-rdynamic` 会把可执行文件中的符号导出到动态符号表,这样 JIT 才能按名字找到 printd):
clang++ -g main.cpp `llvm-config --cxxflags --ldflags --libs` -Wl,-no-as-needed -rdynamic
输入
extern printd(x)
printd(12)
得到输出
parsed a extern
declare double @printd(double)
parsed a top level expr
define double @__anon_expr() {
entry:
%calltmp = call double @printd(double 1.200000e+01)
ret double %calltmp
}
12.000000
0
可以看到,我们成功给Kaleidoscope添加了printd函数。
接下来看我们需要实现的循环语法, 使用C++代码作为注释
def printstar(n)
for i = 1, i < n, 1.0 in # for (double i = 1.0; i < n; i += 1.0)
printd(n)
同样,我们增加for和in的TOKEN
// Token kinds produced by the lexer. All values are negative, so they cannot
// collide with single characters, which the parser compares against the same
// token variable (e.g. '(').
enum Token {
  TOKEN_EOF = -1,         // end of input
  TOKEN_DEF = -2,         // keyword "def"
  TOKEN_EXTERN = -3,      // keyword "extern"
  TOKEN_IDENTIFIER = -4,  // identifier; its text is kept in g_identifier_str
  TOKEN_NUMBER = -5,      // numeric literal
  TOKEN_IF = -6,          // keyword "if"
  TOKEN_THEN = -7,        // keyword "then"
  TOKEN_ELSE = -8,        // keyword "else"
  TOKEN_FOR = -9,         // keyword "for"
  TOKEN_IN = -10,         // keyword "in"
};
增加TOKEN的识别
// Identifiers and keywords: [a-zA-Z][a-zA-Z0-9]*
if (isalpha(last_char)) {
  g_identifier_str = last_char;
  // Keep appending alphanumerics; last_char ends on the first character that
  // is not part of the word, so the next call starts from it.
  for (last_char = getchar(); isalnum(last_char); last_char = getchar()) {
    g_identifier_str += last_char;
  }
  // Keywords get dedicated tokens; any other word is a plain identifier.
  if (g_identifier_str == "def") return TOKEN_DEF;
  if (g_identifier_str == "extern") return TOKEN_EXTERN;
  if (g_identifier_str == "if") return TOKEN_IF;
  if (g_identifier_str == "then") return TOKEN_THEN;
  if (g_identifier_str == "else") return TOKEN_ELSE;
  if (g_identifier_str == "for") return TOKEN_FOR;
  if (g_identifier_str == "in") return TOKEN_IN;
  return TOKEN_IDENTIFIER;
}
增加ForExprAST
// AST node for `for var_name = start_expr, end_expr, step_expr in body_expr`.
// The node owns all four sub-expressions.
class ForExprAST : public ExprAST {
 public:
  // var_name is a sink parameter: it is taken by value and moved into place,
  // so callers passing a temporary or std::move'd string avoid a copy.
  // Callers passing an lvalue still work unchanged.
  ForExprAST(std::string var_name, std::unique_ptr<ExprAST> start_expr,
             std::unique_ptr<ExprAST> end_expr,
             std::unique_ptr<ExprAST> step_expr,
             std::unique_ptr<ExprAST> body_expr)
      : var_name_(std::move(var_name)),
        start_expr_(std::move(start_expr)),
        end_expr_(std::move(end_expr)),
        step_expr_(std::move(step_expr)),
        body_expr_(std::move(body_expr)) {}
  llvm::Value* CodeGen() override;

 private:
  std::string var_name_;                 // name of the loop variable
  std::unique_ptr<ExprAST> start_expr_;  // initial value of the loop variable
  std::unique_ptr<ExprAST> end_expr_;    // loop condition
  std::unique_ptr<ExprAST> step_expr_;   // increment added each iteration
  std::unique_ptr<ExprAST> body_expr_;   // loop body
};
增加ForExprAST的解析函数ParseForExpr,并将其添加到ParsePrimary中
// forexpr ::= for var_name = start_expr, end_expr, step_expr in body_expr
std::unique_ptr<ExprAST> ParseForExpr() {
GetNextToken(); // eat for
std::string var_name = g_identifier_str;
GetNextToken(); // eat var_name
GetNextToken(); // eat =
std::unique_ptr<ExprAST> start_expr = ParseExpression();
GetNextToken(); // eat ,
std::unique_ptr<ExprAST> end_expr = ParseExpression();
GetNextToken(); // eat ,
std::unique_ptr<ExprAST> step_expr = ParseExpression();
GetNextToken(); // eat in
std::unique_ptr<ExprAST> body_expr = ParseExpression();
return std::make_unique<ForExprAST>(var_name, std::move(start_expr),
std::move(end_expr), std::move(step_expr),
std::move(body_expr));
}
// primary
//   ::= identifierexpr
//   ::= numberexpr
//   ::= parenexpr
//   ::= ifexpr
//   ::= forexpr
// Dispatches on the current token. Returns nullptr for a token that cannot
// start a primary expression.
std::unique_ptr<ExprAST> ParsePrimary() {
  const auto token = g_current_token;
  if (token == TOKEN_IDENTIFIER) {
    return ParseIdentifierExpr();
  }
  if (token == TOKEN_NUMBER) {
    return ParseNumberExpr();
  }
  if (token == '(') {
    return ParseParenExpr();
  }
  if (token == TOKEN_IF) {
    return ParseIfExpr();
  }
  if (token == TOKEN_FOR) {
    return ParseForExpr();
  }
  return nullptr;
}
开始codegen
// Generates IR for `for var = start, end, step in body`.
//
// The loop is a single "loop" block whose phi node is the loop variable. The
// phi takes start on entry from the preceding block and var + step on the
// back edge. Body, step and end are all evaluated with var bound to the phi,
// i.e. the value *before* this iteration's increment. The end condition is
// tested after the body, so the body always runs at least once. The for
// expression itself always evaluates to 0.0.
//
// If step_expr_ is null the step defaults to 1.0. Returns nullptr if any
// sub-expression fails to generate code.
llvm::Value* ForExprAST::CodeGen() {
  llvm::Value* start_val = start_expr_->CodeGen();
  if (!start_val) {
    return nullptr;
  }
  llvm::Function* func = g_ir_builder.GetInsertBlock()->getParent();
  // The block that falls into the loop; it is the phi's predecessor for
  // start_val.
  llvm::BasicBlock* pre_block = g_ir_builder.GetInsertBlock();
  llvm::BasicBlock* loop_block =
      llvm::BasicBlock::Create(g_llvm_context, "loop", func);
  g_ir_builder.CreateBr(loop_block);

  g_ir_builder.SetInsertPoint(loop_block);
  llvm::PHINode* var = g_ir_builder.CreatePHI(
      llvm::Type::getDoubleTy(g_llvm_context), 2, var_name_.c_str());
  var->addIncoming(start_val, pre_block);

  // The loop variable may shadow an existing binding of the same name (e.g. a
  // function parameter). Remember that binding and restore it when the loop
  // ends, instead of erasing it for the rest of the function.
  const auto shadowed = g_named_values.find(var_name_);
  const bool had_old_value = shadowed != g_named_values.end();
  llvm::Value* const old_value = had_old_value ? shadowed->second : nullptr;
  g_named_values[var_name_] = var;
  auto restore_binding = [&]() {
    if (had_old_value) {
      g_named_values[var_name_] = old_value;
    } else {
      g_named_values.erase(var_name_);
    }
  };

  // The body's value is discarded; only its side effects matter.
  if (!body_expr_->CodeGen()) {
    restore_binding();
    return nullptr;
  }
  llvm::Value* step_value =
      step_expr_ ? step_expr_->CodeGen()
                 : llvm::ConstantFP::get(g_llvm_context, llvm::APFloat(1.0));
  if (!step_value) {
    restore_binding();
    return nullptr;
  }
  llvm::Value* next_value =
      g_ir_builder.CreateFAdd(var, step_value, "nextvar");

  llvm::Value* end_value = end_expr_->CodeGen();
  if (!end_value) {
    restore_binding();
    return nullptr;
  }
  // Convert the double condition to i1: end_value != 0.0.
  end_value = g_ir_builder.CreateFCmpONE(
      end_value, llvm::ConstantFP::get(g_llvm_context, llvm::APFloat(0.0)),
      "loopcond");

  // As with if/then/else, the body may have moved the builder into another
  // block. That block is the source of the back edge.
  llvm::BasicBlock* loop_end_block = g_ir_builder.GetInsertBlock();
  llvm::BasicBlock* after_block =
      llvm::BasicBlock::Create(g_llvm_context, "afterloop", func);
  g_ir_builder.CreateCondBr(end_value, loop_block, after_block);

  g_ir_builder.SetInsertPoint(after_block);
  var->addIncoming(next_value, loop_end_block);

  restore_binding();
  return llvm::Constant::getNullValue(llvm::Type::getDoubleTy(g_llvm_context));
}
输入
extern printd(x)
def foo(x)
if x < 3 then
1
else
foo(x - 1) + foo(x - 2)
for i = 1, i < 10, 1.0 in
printd(foo(i))
输出
parsed a extern
declare double @printd(double)
parsed a function definition
define double @foo(double %x) {
entry:
%cmptmp = fcmp ult double %x, 3.000000e+00
%booltmp = uitofp i1 %cmptmp to double
%ifcond = fcmp one double %booltmp, 0.000000e+00
br i1 %ifcond, label %then, label %else
then: ; preds = %entry
br label %ifcont
else: ; preds = %entry
%subtmp = fsub double %x, 1.000000e+00
%calltmp = call double @foo(double %subtmp)
%subtmp1 = fsub double %x, 2.000000e+00
%calltmp2 = call double @foo(double %subtmp1)
%addtmp = fadd double %calltmp, %calltmp2
br label %ifcont
ifcont: ; preds = %else, %then
%iftmp = phi double [ 1.000000e+00, %then ], [ %addtmp, %else ]
ret double %iftmp
}
parsed a top level expr
define double @__anon_expr() {
entry:
br label %loop
loop: ; preds = %loop, %entry
%i = phi double [ 1.000000e+00, %entry ], [ %nextvar, %loop ]
%calltmp = call double @foo(double %i)
%calltmp1 = call double @printd(double %calltmp)
%nextvar = fadd double %i, 1.000000e+00
%cmptmp = fcmp ult double %i, 1.000000e+01
%booltmp = uitofp i1 %cmptmp to double
%loopcond = fcmp one double %booltmp, 0.000000e+00
br i1 %loopcond, label %loop, label %afterloop
afterloop: ; preds = %loop
ret double 0.000000e+00
}
1.000000
1.000000
2.000000
3.000000
5.000000
8.000000
13.000000
21.000000
34.000000
55.000000
0
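成功遍历了斐波那契数列。注意这里输出了 10 个数(i = 1..10),而不是 9 个:循环体至少执行一次,并且结束条件 i < 10 是在循环体之后、用本轮递增之前的 i 判断的,所以当 i = 9 时条件仍为真,会再以 i = 10 执行一次循环体。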