《用Go语言自制解释器》之第5章 宏系统
- 第5章 宏系统
- 5.1 宏系统
- 5.2 Monkey的宏系统
- 5.3 quote
- 5.4 unquote
- 5.5 宏扩展
- 5.6 扩展REPL
- 5.7 宏畅想
- 5.8 完整工程
- monkey/token/token.go
- monkey/lexer/lexer.go
- monkey/lexer/lexer_test.go
- monkey/ast/ast.go
- monkey/ast/ast_test.go
- monkey/ast/modify.go
- monkey/ast/modify_test.go
- monkey/parser/parser.go
- monkey/parser/parser_test.go
- monkey/parser/parser_tracing.go
- monkey/object/object.go
- monkey/object/object_test.go
- monkey/object/environment.go
- monkey/evaluator/builtins.go
- monkey/evaluator/evaluator.go
- monkey/evaluator/evaluator_test.go
- monkey/evaluator/quote_unquote.go
- monkey/evaluator/quote_unquote_test.go
- monkey/evaluator/macro_expansion.go
- monkey/evaluator/macro_expansion_test.go
- monkey/repl/repl.go
- monkey/main.go
- monkey/run.txt
第5章 宏系统
5.1 宏系统
与宏有关的编程语言特性,包括宏定义、访问、求值、工作。两大类宏:文本替换宏系统和语法宏系统,相当于搜索替换和代码即数据。
文本替换宏系统,仅用于文本层面,模板系统。
#define GREETING "Hello World"
int main() {
#ifdef DEBUG
printf(GREETING " Debug-Mode!\n");
#else
printf(GREETING " Production-Mode!\n");
#endif
return 0;
}
语法宏系统,将代码视为数据(代码即数据,像修改数据一样修改代码)。
源代码经过词法分析,语法分析后形成AST结构体(AST可以转换为源代码)。
所以可以将代码视为数据,可以传递、修改和生成源代码。
若语言X有语法宏系统,可以用X语言处理X语言编写的源代码。
quote停止对代码求值并将代码视为数据。
unquote“跳出”quote上下文并对其中代码求值。
两者用于控制代码的求值时机和求值方式,并将代码转换为数据。
5.2 Monkey的宏系统
quote接受一个参数并阻断对其的求值,然后返回一个对象来表示引用状态(quoted)的代码。
quote(foobar);
QUOTE(foobar)
quote(foobar+10+5);
QUOTE(((foobar + 10) + 5))
unquote将引用源代码转为非引用状态。
let quotedInfixExpression = quote(4+4);
quote(unquote(4+4)+unquote(quotedInfixExpression));
QUOTE((8 + (4 + 4)))
参数被传递给宏的主体之前不会进行求值。
let reverse = macro(a, b) { quote(unquote(b) - unquote(a)); };
reverse(2+2, 10-5);
1
let evalSecondArg = macro(a, b) { quote(unquote(b)); };
evalSecondArg(puts("not printed"), puts("printed"));
printed
5.3 quote
只在宏中使用,调用时不对参数求值,而是返回表示参数的AST节点。
// object/object.go
const (
QUOTE_OBJ = "QUOTE"
)
// Quote is the object that carries an unevaluated AST node, so that code can
// be handled as data.
type Quote struct {
	Node ast.Node // the quoted AST node
}

// Type returns the object type tag QUOTE_OBJ.
func (q *Quote) Type() ObjectType {
	return QUOTE_OBJ
}

// Inspect renders the quoted node as "QUOTE(<node.String()>)".
func (q *Quote) Inspect() string {
	source := q.Node.String()
	return "QUOTE(" + source + ")"
}
// evaluator/evaluator.go
func Eval(node ast.Node, env *object.Environment) object.Object {
case *ast.CallExpression:
// quote()函数调用
if node.Function.String() == "quote" {
return quote(node.Arguments[0])
}
}
// evaluator/quote_unquote.go
package evaluator
import (
"monkey/ast"
"monkey/object"
)
// quote wraps node, without evaluating it, in an *object.Quote.
func quote(node ast.Node) object.Object {
	quoted := new(object.Quote)
	quoted.Node = node
	return quoted
}
// evaluator/quote_unquote_test.go
package evaluator
import (
"testing"
"monkey/object"
)
// TestQuote checks that evaluating quote(<expr>) yields an *object.Quote whose
// node prints back as the unevaluated expression.
//
// The name must start with "Test": `go test` only runs functions named
// TestXxx, so the previous lower-case name was silently never executed.
func TestQuote(t *testing.T) {
	tests := []struct {
		input    string
		expected string
	}{
		{`quote(5)`, `5`},
		{`quote(5+8)`, `(5 + 8)`},
		{`quote(foobar)`, `foobar`},
		{`quote(foobar+barfoo)`, `(foobar + barfoo)`},
	}

	for _, tt := range tests {
		evaluated := testEval(tt.input)

		quote, ok := evaluated.(*object.Quote)
		if !ok {
			t.Fatalf("expected *object.Quote. got=%T (%+v)", evaluated, evaluated)
		}
		if quote.Node == nil {
			t.Fatalf("quote.Node is nil")
		}
		if got := quote.Node.String(); got != tt.expected {
			t.Errorf("not equal. got=%q, want=%q", got, tt.expected)
		}
	}
}
5.4 unquote
quote是不对参数求值且保留为ast.Node(Eval解析quote()函数时,“跳过”quote()函数内部的ast.Node,不进行求值),unquote对内部内容求值。
unquote调用只是在quote调用的参数中使用,不能通过Eval递归查找unquote调用并求值。
5.4.1 遍历树
将unquote调用的结果转换为新的AST节点,并替换(修改)现有的unquote调用。
- 第一步
// ast/modify.go
package ast
// ModifierFunc is a function that modifies a node: it takes a Node and returns a Node.
type ModifierFunc func(Node) Node
// Modify walks node, rewrites its children in place with the results of the
// recursive calls, and finally returns modifier(node).
//
// At this stage only *Program and *ExpressionStatement have their children
// visited; every other node type is passed straight to modifier.
func Modify(node Node, modifier ModifierFunc) Node {
	switch node := node.(type) {
	case *Program:
		for i, statement := range node.Statements {
			// The assertion result is deliberately ignored: if the modifier
			// returns something that is not a Statement, the slot becomes nil.
			node.Statements[i], _ = Modify(statement, modifier).(Statement)
		}
	case *ExpressionStatement:
		// Same here: a non-Expression result leaves Expression nil.
		node.Expression, _ = Modify(node.Expression, modifier).(Expression)
	}

	// Children first, then the node itself.
	return modifier(node)
}
// ast/modify_test.go
package ast
import (
"reflect"
"testing"
)
// TestModify runs Modify with a modifier that turns every integer literal 1
// into 2 and compares the resulting tree with the expected one.
func TestModify(t *testing.T) {
	one := func() Expression { return &IntegerLiteral{Value: 1} }
	two := func() Expression { return &IntegerLiteral{Value: 2} }

	// turnOneIntoTwo mutates integer literals equal to 1 in place and leaves
	// every other node as it is.
	turnOneIntoTwo := func(node Node) Node {
		if lit, isInt := node.(*IntegerLiteral); isInt && lit.Value == 1 {
			lit.Value = 2
		}
		return node
	}

	tests := []struct {
		input    Node
		expected Node
	}{
		{
			input:    one(),
			expected: two(),
		},
		{
			input: &Program{
				Statements: []Statement{
					&ExpressionStatement{Expression: one()},
				},
			},
			expected: &Program{
				Statements: []Statement{
					&ExpressionStatement{Expression: two()},
				},
			},
		},
	}

	for _, tt := range tests {
		modified := Modify(tt.input, turnOneIntoTwo)
		if !reflect.DeepEqual(modified, tt.expected) {
			t.Errorf("not equal. got=%#v, want=%#v", modified, tt.expected)
		}
	}
}
go test ./ast
- 完成遍历
(1)中缀表达式
// ast/modify.go
func Modify(node Node, modifier ModifierFunc) Node {
// [...]
switch node := node.(type) {
case *InfixExpression:
node.Left, _ = Modify(node.Left, modifier).(Expression)
node.Right, _ = Modify(node.Right, modifier).(Expression)
}// [...]
}
// ast/modify_test.go
func TestModify(t *testing.T) {
// [...]
tests := []struct {
input Node
expected Node
}{
// [...]
{
&InfixExpression{Left: one(), Token: token.Token{Type: token.PLUS, Literal: "+"}, Right: two()},
&InfixExpression{Left: two(), Token: token.Token{Type: token.PLUS, Literal: "+"}, Right: two()},
},
{
&InfixExpression{Left: two(), Token: token.Token{Type: token.PLUS, Literal: "+"}, Right: one()},
&InfixExpression{Left: two(), Token: token.Token{Type: token.PLUS, Literal: "+"}, Right: two()},
},
}
// [...]
}
(2)前缀表达式
// ast/modify.go
func Modify(node Node, modifier ModifierFunc) Node {
// [...]
switch node := node.(type) {
case *PrefixExpression:
node.Right, _ = Modify(node.Right, modifier).(Expression)
}// [...]
}
// ast/modify_test.go
func TestModify(t *testing.T) {
// [...]
tests := []struct {
input Node
expected Node
}{
// [...]
{
&PrefixExpression{Token: token.Token{Type: token.MINUS, Literal: "-"}, Right: one()},
&PrefixExpression{Token: token.Token{Type: token.MINUS, Literal: "-"}, Right: two()},
},
}
// [...]
}
(3)索引表达式
// ast/modify.go
func Modify(node Node, modifier ModifierFunc) Node {
// [...]
switch node := node.(type) {
case *IndexExpression:
node.Left, _ = Modify(node.Left, modifier).(Expression)
node.Index, _ = Modify(node.Index, modifier).(Expression)
}// [...]
}
// ast/modify_test.go
func TestModify(t *testing.T) {
// [...]
tests := []struct {
input Node
expected Node
}{
// [...]
{
&IndexExpression{Left: one(), Index: one()},
&IndexExpression{Left: two(), Index: two()},
},
}
// [...]
}
(4)IF表达式
// ast/modify.go
func Modify(node Node, modifier ModifierFunc) Node {
// [...]
switch node := node.(type) {case *IfExpression:
node.Condition, _ = Modify(node.Condition, modifier).(Expression)
node.Consequence, _ = Modify(node.Consequence, modifier).(*BlockStatement)
if node.Alternative != nil {
node.Alternative, _ = Modify(node.Alternative, modifier).(*BlockStatement)
}
case *BlockStatement:
for i := range node.Statements {
node.Statements[i], _ = Modify(node.Statements[i], modifier).(Statement)
}
}// [...]
}
// ast/modify_test.go
func TestModify(t *testing.T) {
// [...]
tests := []struct {
input Node
expected Node
}{
// [...]
{
&IfExpression{
Condition: one(),
Consequence: &BlockStatement{
Statements: []Statement{
&ExpressionStatement{Expression: one()},
},
},
Alternative: &BlockStatement{
Statements: []Statement{
&ExpressionStatement{Expression: one()},
},
},
},
&IfExpression{
Condition: two(),
Consequence: &BlockStatement{
Statements: []Statement{
&ExpressionStatement{Expression: two()},
},
},
Alternative: &BlockStatement{
Statements: []Statement{
&ExpressionStatement{Expression: two()},
},
},
},
},
}
// [...]
}
(5)return语句
// ast/modify.go
func Modify(node Node, modifier ModifierFunc) Node {
// [...]
switch node := node.(type) {
case *ReturnStatement:
node.ReturnValue, _ = Modify(node.ReturnValue, modifier).(Expression)
}// [...]
}
// ast/modify_test.go
func TestModify(t *testing.T) {
// [...]
tests := []struct {
input Node
expected Node
}{
// [...]
{
&ReturnStatement{ReturnValue: one()},
&ReturnStatement{ReturnValue: two()},
},
}
// [...]
}
(6)let语句
// ast/modify.go
func Modify(node Node, modifier ModifierFunc) Node {
// [...]
switch node := node.(type) {
case *LetStatement:
node.Value, _ = Modify(node.Value, modifier).(Expression)
}// [...]
}
// ast/modify_test.go
func TestModify(t *testing.T) {
// [...]
tests := []struct {
input Node
expected Node
}{
// [...]
{
&LetStatement{Value: one()},
&LetStatement{Value: two()},
},
}
// [...]
}
(7)函数字面量
// ast/modify.go
func Modify(node Node, modifier ModifierFunc) Node {
// [...]
switch node := node.(type) {
case *FunctionLiteral:
for i := range node.Parameters {
node.Parameters[i], _ = Modify(node.Parameters[i], modifier).(*Identifier)
}
node.Body, _ = Modify(node.Body, modifier).(*BlockStatement)
}// [...]
}
// ast/modify_test.go
func TestModify(t *testing.T) {
// [...]
tests := []struct {
input Node
expected Node
}{
// [...]
{
&FunctionLiteral{
Parameters: []*Identifier{},
Body: &BlockStatement{
Statements: []Statement{
&ExpressionStatement{Expression: one()},
},
},
},
&FunctionLiteral{
Parameters: []*Identifier{},
Body: &BlockStatement{
Statements: []Statement{
&ExpressionStatement{Expression: two()},
},
},
},
},
}
// [...]
}
(8)数组字面量
// ast/modify.go
func Modify(node Node, modifier ModifierFunc) Node {
// [...]
switch node := node.(type) {
case *ArrayLiteral:
for i := range node.Elements {
node.Elements[i], _ = Modify(node.Elements[i], modifier).(Expression)
}
}// [...]
}
// ast/modify_test.go
func TestModify(t *testing.T) {
// [...]
tests := []struct {
input Node
expected Node
}{
// [...]
{
&ArrayLiteral{Elements: []Expression{one(), one()}},
&ArrayLiteral{Elements: []Expression{two(), two()}},
},
}
// [...]
}
(9)哈希字面量
// ast/modify.go
func Modify(node Node, modifier ModifierFunc) Node {
// [...]
case *HashLiteral:
newPairs := make(map[Expression]Expression)
for key, val := range node.Pairs {
newKey, _ := Modify(key, modifier).(Expression)
newVal, _ := Modify(val, modifier).(Expression)
newPairs[newKey] = newVal
}
node.Pairs = newPairs
}
// ast/modify_test.go
// TestModify (hash-literal part): after Modify, every key and value of the
// hash literal must be the integer literal 2.
func TestModify(t *testing.T) {
	// [...]
	tests := []struct {
		input    Node
		expected Node
	}{
		// [...]
	}

	// The hash literal is checked by hand instead of through the table: its
	// map keys are pointers, which reflect.DeepEqual compares with ==, so two
	// separately built maps would never match.
	hashLiteral := &HashLiteral{
		Pairs: map[Expression]Expression{
			one(): one(),
		},
	}
	Modify(hashLiteral, turnOneIntoTwo)

	for k, v := range hashLiteral.Pairs {
		// Check the assertions so a wrong node type fails the test instead of
		// dereferencing a nil pointer.
		key, ok := k.(*IntegerLiteral)
		if !ok {
			t.Fatalf("key is not *IntegerLiteral. got=%T", k)
		}
		if key.Value != 2 {
			t.Errorf("value is not %d, got=%d", 2, key.Value)
		}

		val, ok := v.(*IntegerLiteral)
		if !ok {
			t.Fatalf("val is not *IntegerLiteral. got=%T", v)
		}
		if val.Value != 2 {
			t.Errorf("value is not %d, got=%d", 2, val.Value)
		}
	}
}
5.4.2 替换unquote调用
// evaluator/quote_unquote.go
package evaluator
import (
"monkey/ast"
"monkey/object"
"monkey/token"
"strconv"
)
// quote evaluates the unquote calls inside node using env and wraps the
// resulting node in an *object.Quote.
func quote(node ast.Node, env *object.Environment) object.Object {
	return &object.Quote{Node: evalUnquoteCalls(node, env)}
}
// evalUnquoteCalls walks quoted with ast.Modify and replaces every unquote(x)
// call with the AST node built from evaluating x in env.
func evalUnquoteCalls(quoted ast.Node, env *object.Environment) ast.Node {
	replaceUnquote := func(node ast.Node) ast.Node {
		if !isUnquoteCall(node) {
			return node
		}
		// isUnquoteCall has already established the type, so the ok result
		// is not needed.
		call, _ := node.(*ast.CallExpression)
		result := Eval(call.Arguments[0], env)
		return convertObjectToASTNode(result)
	}
	return ast.Modify(quoted, replaceUnquote)
}
// isUnquoteCall reports whether node is a call expression whose function
// prints as "unquote" and that has exactly one argument.
func isUnquoteCall(node ast.Node) bool {
	call, isCall := node.(*ast.CallExpression)
	if !isCall {
		return false
	}
	namedUnquote := call.Function.String() == "unquote"
	return namedUnquote && len(call.Arguments) == 1
}
// convertObjectToASTNode converts the object produced by an unquote call back
// into an AST node. Only *object.Integer is handled here; any other object
// yields nil.
func convertObjectToASTNode(obj object.Object) ast.Node {
	switch obj := obj.(type) {
	case *object.Integer:
		// FormatInt formats the int64 directly. Itoa(int(v)) would truncate
		// the value on platforms where int is 32 bits wide, so the token
		// literal could disagree with Value.
		tok := token.Token{
			Type:    token.INT,
			Literal: strconv.FormatInt(obj.Value, 10),
		}
		return &ast.IntegerLiteral{Token: tok, Value: obj.Value}
	default:
		return nil
	}
}
// evaluator/quote_unquote_test.go
package evaluator
import (
"monkey/object"
"testing"
)
// TestQuote checks quote and unquote: plain quote(...) keeps its argument
// unevaluated, while unquote(...) inside it is replaced by the evaluated value.
//
// The name must start with "Test": `go test` only runs functions named
// TestXxx, so the previous lower-case name was silently never executed.
func TestQuote(t *testing.T) {
	tests := []struct {
		input    string
		expected string
	}{
		{`quote(5)`, `5`},
		{`quote(5+8)`, `(5 + 8)`},
		{`quote(foobar)`, `foobar`},
		{`quote(foobar+barfoo)`, `(foobar + barfoo)`},
		{
			`let foobar=8;
quote(foobar)`,
			`foobar`,
		},
		{
			`let foobar=8;
quote(unquote(foobar))`,
			`8`,
		},
	}

	for _, tt := range tests {
		evaluated := testEval(tt.input)

		quote, ok := evaluated.(*object.Quote)
		if !ok {
			t.Fatalf("expected *object.Quote. got=%T (%+v)", evaluated, evaluated)
		}
		if quote.Node == nil {
			t.Fatalf("quote.Node is nil")
		}
		if got := quote.Node.String(); got != tt.expected {
			t.Errorf("not equal. got=%q, want=%q", got, tt.expected)
		}
	}
}
- 布尔值转为AST节点
func convertObjectToASTNode(obj object.Object) ast.Node {
switch obj := obj.(type) {
// [...]
case *object.Boolean:
var t token.Token
if obj.Value {
t = token.Token{Type: token.TRUE, Literal: "true"}
} else {
t = token.Token{Type: token.FALSE, Literal: "false"}
}
return &ast.Boolean{Token: t, Value: obj.Value}
// [...]
}
}
func TestQuote(t *testing.T) {
tests := []struct {
input string
expected string
}{
// [...]
{
`quote(unquote(true))`,
`true`,
},
{
`quote(unquote(true==false))`,
`false`,
},
}
// [...]
}
- quote-unquote-quote嵌套
// evaluator/quote_unquote.go
func convertObjectToASTNode(obj object.Object) ast.Node {
switch obj := obj.(type) {
// [...]
case *object.Quote:
return obj.Node
// [...]
}
}
// evaluator/quote_unquote_test.go
func TestQuoteUnquote(t *testing.T) {
tests := []struct {
input string
expected string
}{
// [...]
{
`quote(quote(123))`,
`quote(123)`,
},
{
`quote(unquote(quote(4+4)))`,
`(4 + 4)`,
},
}
// [...]
}
5.5 宏扩展
源代码解释步骤:
一、词法分析(字符串转换为词法单元);
二、语法分析(词法单元转换为AST);
三、宏扩展(获取AST,修改后返回AST。源代码中宏的调用求值,并替换原来的宏);
四、求值(递归求值AST中节点,逐个处理每条语句和表达式)。
宏扩展步骤:
一、遍历AST,找到所有宏定义,保存到环境变量中,然后从AST中删除宏定义;
let myMacro = macro(x, y) { quote(unquote(x) + unquote(y)); }
二、找到宏的调用,并求值;
宏调用求值阶段,宏调用的参数在宏主体中以未求值的ast.Node形式访问。
三、宏调用的结果重新插回AST。
5.5.1 macro关键字
// token/token.go
const (
// [...]
MACRO = "MACRO"
)
var keywords = map[string]TokenType{
// [...]
"macro": MACRO,
}
// lexer/lexer_test.go
func TestNextToken(t *testing.T) {
input := `macro(x,y){x+y;};
`
tests := []struct {
expectedType token.TokenType
expectedLiteral string
}{
{token.MACRO, "macro"},
{token.LPAREN, "("},
{token.IDENT, "x"},
{token.COMMA, ","},
{token.IDENT, "y"},
{token.RPAREN, ")"},
{token.LBRACE, "{"},
{token.IDENT, "x"},
{token.PLUS, "+"},
{token.IDENT, "y"},
{token.SEMICOLON, ";"},
{token.RBRACE, "}"},
{token.SEMICOLON, ";"},
{token.EOF, ""},
}
// [...]
}
go test ./lexer
5.5.2 宏字面量语法分析
// ast/ast.go
// MacroLiteral is the AST node for a macro literal.
type MacroLiteral struct {
	Parameters []*Identifier
	Body       *BlockStatement
}

// String renders the literal as "macro(<p1>, <p2>, ...) <body>".
func (ml *MacroLiteral) String() string {
	names := make([]string, 0, len(ml.Parameters))
	for i := range ml.Parameters {
		names = append(names, ml.Parameters[i].String())
	}

	var out bytes.Buffer
	out.WriteString("macro(")
	out.WriteString(strings.Join(names, ", "))
	out.WriteString(") ")
	out.WriteString(ml.Body.String())
	return out.String()
}
// parser/parser.go
func New(l *lexer.Lexer) *Parser {
// [...]
p.registerPrefix(token.MACRO, p.parseMacroLiteral) // macro
// [...]
}
// parseMacroLiteral parses a macro literal expression of the form
//
//	macro(<parameters>) { <body> }
//
// It returns nil when the "(" or "{" it expects next is missing.
func (p *Parser) parseMacroLiteral() ast.Expression {
	if !p.expectPeek(token.LPAREN) {
		return nil
	}
	params := p.parseFunctionParameters()

	if !p.expectPeek(token.LBRACE) {
		return nil
	}
	body := p.parseBlockStatement()

	return &ast.MacroLiteral{Parameters: params, Body: body}
}
// parser/parser_test.go
func TestMacroLiteralParsing(t *testing.T) {
input := `macro(x, y) { x + y; }`
l := lexer.New(input)
p := New(l)
program := p.ParseProgram()
checkParserErrors(t, p)
if len(program.Statements) != 1 {
t.Fatalf("program.Statements does not contain %d statements. got=%d\n",
1, len(program.Statements))
}
stmt, ok := program.Statements[0].(*ast.ExpressionStatement)
if !ok {
t.Fatalf("program.Statements[0] is not ast.ExpressionStatement. got=%T",
program.Statements[0])
}
macro, ok := stmt.Expression.(*ast.MacroLiteral)
if !ok {
t.Fatalf("stmt.Expression is not ast.MacroLiteral. got=%T",
stmt.Expression)
}
if len(macro.Parameters) != 2 {
t.Fatalf("macro literal parameters wrong. want 2, got=%d\n",
len(macro.Parameters))
}
testLiteralExpression(t, macro.Parameters[0], "x")
testLiteralExpression(t, macro.Parameters[1], "y")
if len(macro.Body.Statements) != 1 {
t.Fatalf("macro.Body.Statements has not 1 statements. got=%d\n",
len(macro.Body.Statements))
}
bodyStmt, ok := macro.Body.Statements[0].(*ast.ExpressionStatement)
if !ok {
t.Fatalf("macro body stmt is not ast.ExpressionStatement. got=%T",
macro.Body.Statements[0])
}
testInfixExpression(t, bodyStmt.Expression, "x", "+", "y")
}
go test ./parser
5.5.3 定义宏
// evaluator/macro_expansion.go
package evaluator
import (
"monkey/ast"
"monkey/object"
)
// DefineMacros registers every top-level macro definition of program in env
// and removes those statements from program.Statements. The remaining
// statements keep their relative order.
func DefineMacros(program *ast.Program, env *object.Environment) {
	// Filter in place: kept shares the backing array with program.Statements
	// and never gets ahead of the read position.
	kept := program.Statements[:0]
	for _, statement := range program.Statements {
		if isMacroDefinition(statement) {
			addMacro(statement, env)
			continue
		}
		kept = append(kept, statement)
	}
	program.Statements = kept
}
// isMacroDefinition reports whether node is a let statement whose value is a
// macro literal.
func isMacroDefinition(node ast.Statement) bool {
	letStatement, isLet := node.(*ast.LetStatement)
	if !isLet {
		return false
	}
	_, isMacro := letStatement.Value.(*ast.MacroLiteral)
	return isMacro
}
// addMacro builds an *object.Macro from the macro literal bound by stmt and
// stores it in env under the let statement's name. env is also recorded as
// the macro's own environment.
func addMacro(stmt ast.Statement, env *object.Environment) {
	// The assertion results are ignored: DefineMacros only calls addMacro
	// after isMacroDefinition has accepted the statement.
	letStatement, _ := stmt.(*ast.LetStatement)
	macroLiteral, _ := letStatement.Value.(*ast.MacroLiteral)

	env.Set(letStatement.Name.String(), &object.Macro{
		Parameters: macroLiteral.Parameters,
		Env:        env,
		Body:       macroLiteral.Body,
	})
}
// evaluator/macro_expansion_test.go
package evaluator
import (
"monkey/ast"
"monkey/lexer"
"monkey/object"
"monkey/parser"
"testing"
)
// TestDefineMacros checks that DefineMacros removes the macro definition from
// the program, stores it in the environment as an *object.Macro, and leaves
// ordinary let statements undefined in that environment.
func TestDefineMacros(t *testing.T) {
	input := `let number=1;
let function=fn(x,y){x+y};
let mymacro=macro(x,y){x+y;};`

	env := object.NewEnvironment()
	program := testParseProgram(input)

	DefineMacros(program, env)

	if len(program.Statements) != 2 {
		t.Fatalf("Wrong number of statements. got=%d",
			len(program.Statements))
	}

	// DefineMacros must not evaluate or bind anything except macros.
	_, ok := env.Get("number")
	if ok {
		t.Fatalf("number should not be defined")
	}
	_, ok = env.Get("function")
	if ok {
		t.Fatalf("function should not be defined")
	}

	obj, ok := env.Get("mymacro")
	if !ok {
		t.Fatalf("macro not in environment.")
	}

	macro, ok := obj.(*object.Macro)
	if !ok {
		// Distinct message: the name was found, but with the wrong type.
		t.Fatalf("object is not Macro. got=%T (%+v)", obj, obj)
	}

	if len(macro.Parameters) != 2 {
		t.Fatalf("Wrong number of macro parameters. got=%d",
			len(macro.Parameters))
	}
	if macro.Parameters[0].String() != "x" {
		t.Fatalf("parameter is not 'x'. got=%q",
			macro.Parameters[0])
	}
	if macro.Parameters[1].String() != "y" {
		t.Fatalf("parameter is not 'y'. got=%q",
			macro.Parameters[1])
	}

	expectedBody := "{\n\t(x + y);\n}"
	if macro.Body.String() != expectedBody {
		t.Fatalf("body is not %q. got=%q",
			expectedBody, macro.Body.String())
	}
}
// testParseProgram lexes and parses input and returns the resulting program.
// Parser errors are not checked here.
func testParseProgram(input string) *ast.Program {
	return parser.New(lexer.New(input)).ParseProgram()
}
go test ./evaluator
5.5.4 展开宏
对宏的调用求值,然后将求值结果重新插回AST中替换原始的调用表达式。
必须从宏返回*object.Quote。
// evaluator/macro_expansion.go
// ExpandMacros walks program with ast.Modify and replaces every call to a
// macro defined in env with the node quoted by that macro's body.
//
// It panics if a macro body evaluates to anything other than *object.Quote.
func ExpandMacros(program ast.Node, env *object.Environment) ast.Node {
	expand := func(node ast.Node) ast.Node {
		call, isCall := node.(*ast.CallExpression)
		if !isCall {
			return node
		}
		macro, isMacro := isMacroCall(call, env)
		if !isMacro {
			return node
		}

		// Arguments reach the macro body as quoted, unevaluated nodes.
		evalEnv := extendMacroEnv(macro, quoteArgs(call))

		quoted, isQuote := Eval(macro.Body, evalEnv).(*object.Quote)
		if !isQuote {
			panic("we only support returning AST-nodes from macros")
		}
		return quoted.Node
	}
	return ast.Modify(program, expand)
}
// isMacroCall returns the macro being called and true when exp calls an
// identifier that env binds to an *object.Macro; otherwise it returns
// (nil, false).
func isMacroCall(exp *ast.CallExpression, env *object.Environment) (*object.Macro, bool) {
	identifier, isIdent := exp.Function.(*ast.Identifier)
	if !isIdent {
		return nil, false
	}

	obj, found := env.Get(identifier.String())
	if !found {
		return nil, false
	}

	macro, isMacro := obj.(*object.Macro)
	if !isMacro {
		return nil, false
	}
	return macro, true
}
// quoteArgs wraps each argument of exp, unevaluated, in an *object.Quote.
// The result is never nil, even when the call has no arguments.
func quoteArgs(exp *ast.CallExpression) []*object.Quote {
	quoted := make([]*object.Quote, 0, len(exp.Arguments))
	for i := range exp.Arguments {
		quoted = append(quoted, &object.Quote{Node: exp.Arguments[i]})
	}
	return quoted
}
// extendMacroEnv creates an environment enclosed by the macro's own
// environment and binds each parameter name to the quoted argument at the
// same position.
//
// args is indexed by parameter position, so a call with fewer arguments than
// parameters panics with an index-out-of-range error.
func extendMacroEnv(macro *object.Macro, args []*object.Quote) *object.Environment {
	scope := object.NewEnclosedEnvironment(macro.Env)
	for i := range macro.Parameters {
		scope.Set(macro.Parameters[i].String(), args[i])
	}
	return scope
}
// evaluator/macro_expansion_test.go
// TestExpandMacros defines the macros of each input program, expands the
// macro calls, and compares the printed result with the printed form of the
// expected source.
func TestExpandMacros(t *testing.T) {
	tests := []struct {
		input    string
		expected string
	}{
		{
			input: `let infixExpression=macro(){quote(1+2);};
infixExpression()`,
			expected: `(1 + 2)`,
		},
		{
			input: `let reverse=macro(a,b){quote(unquote(b) - unquote(a));};
reverse(2+2,10-5);`,
			expected: `(10 - 5) - (2 + 2)`,
		},
	}

	for _, tt := range tests {
		want := testParseProgram(tt.expected).String()

		program := testParseProgram(tt.input)
		env := object.NewEnvironment()
		DefineMacros(program, env)
		got := ExpandMacros(program, env).String()

		if got != want {
			t.Errorf("not equal. want=%q, got=%q", want, got)
		}
	}
}
5.5.5 unless宏
// evaluator/macro_expansion_test.go
func TestExpandMacros(t *testing.T) {
tests := []struct {
input string
expected string
}{
// [...]
{
`let unless = macro(condition, consequence, alternative) {
quote(if(unquote(condition)) {
unquote(consequence);
} else {
unquote(alternative);
});
};
unless(10>5, puts("greater"), puts("not greater"));`,
`if(10 > 5) { puts("greater") } else { puts("not greater") }`,
},
}
// [...]
}
5.6 扩展REPL
// repl/repl.go
func Start(in io.Reader, out io.Writer) {
// [...]
macroEnv := object.NewEnvironment()
for {
// [...]
evaluator.DefineMacros(program, macroEnv)
expended := evaluator.ExpandMacros(program, macroEnv)
// [...]
}
}
let unless = macro(condition, consequence, alternative) {quote(if(unquote(condition)) { unquote(consequence);} else { unquote(alternative);});};
unless(10>5, puts("greater"), puts("not greater"));
5.7 宏畅想
编写生成代码的代码。
块语句传递给quote/unquote调用。
quote() {
let one = 1;
let two = 2;
one + two;
}
5.8 完整工程
monkey/token/token.go
package token
type TokenType string
const (
ILLEGAL = "ILLEGAL"
EOF = "EOF"
// 标识符+字面量
IDENT = "IDENT" // add, foobar, x, y, ...
INT = "INT" // 1343456
STRING = "STRING" // "foobar"
// 运算符
ASSIGN = "="
PLUS = "+"
MINUS = "-"
BANG = "!"
ASTERISK = "*"
SLASH = "/"
LT = "<"
GT = ">"
EQ = "=="
NOT_EQ = "!="
// 分隔符
COMMA = ","
SEMICOLON = ";"
COLON = ":"
LPAREN = "("
RPAREN = ")"
LBRACE = "{"
RBRACE = "}"
LBRACKET = "["
RBRACKET = "]"
// 关键字
FUNCTION = "FUNCTION"
LET = "LET"
TRUE = "TRUE"
FALSE = "FALSE"
IF = "IF"
ELSE = "ELSE"
RETURN = "RETURN"
MACRO = "MACRO"
)
type Token struct {
Type TokenType
Literal string
}
var keywords = map[string]TokenType{
"fn": FUNCTION,
"let": LET,
"true": TRUE,
"false": FALSE,
"if": IF,
"else": ELSE,
"return": RETURN,
"macro": MACRO,
}
// LookupIdent returns the keyword token type registered for ident in the
// keywords table, or IDENT when ident is not a keyword.
func LookupIdent(ident string) TokenType {
	tok, isKeyword := keywords[ident]
	if !isKeyword {
		return IDENT
	}
	return tok
}
monkey/lexer/lexer.go
package lexer
import "monkey/token"
// Lexer holds the input string and the scanning state used while reading it.
type Lexer struct {
input string
position int // index of the current character
readPosition int // index of the next character
ch byte // current character
}
// New returns a Lexer over input with the first character already read.
func New(input string) *Lexer {
	lex := new(Lexer)
	lex.input = input
	lex.readChar()
	return lex
}
// readChar advances the lexer by one byte: l.ch becomes the byte at
// readPosition (or 0 once the input is exhausted), position moves to
// readPosition, and readPosition is incremented.
func (l *Lexer) readChar() {
	next := byte(0)
	if l.readPosition < len(l.input) {
		next = l.input[l.readPosition]
	}
	l.ch = next
	l.position = l.readPosition
	l.readPosition++
}
// peekChar returns the byte at readPosition without advancing, or 0 when
// readPosition is past the end of the input.
func (l *Lexer) peekChar() byte {
	if l.readPosition < len(l.input) {
		return l.input[l.readPosition]
	}
	return 0
}
// skipWhitespace advances past spaces, tabs, newlines and carriage returns.
func (l *Lexer) skipWhitespace() {
	for {
		switch l.ch {
		case ' ', '\t', '\n', '\r':
			l.readChar()
		default:
			return
		}
	}
}
// newToken builds a token of the given type whose literal is the single
// character ch.
func newToken(tokenType token.TokenType, ch byte) token.Token {
	var tok token.Token
	tok.Type = tokenType
	tok.Literal = string(ch)
	return tok
}
// isLetter reports whether ch is an ASCII letter or an underscore.
func isLetter(ch byte) bool {
	switch {
	case ch >= 'a' && ch <= 'z':
		return true
	case ch >= 'A' && ch <= 'Z':
		return true
	}
	return ch == '_'
}
// isDigit reports whether ch is an ASCII decimal digit.
func isDigit(ch byte) bool {
	if ch < '0' || ch > '9' {
		return false
	}
	return true
}
// readIdentifier consumes the run of letter characters starting at the
// current position and returns it.
func (l *Lexer) readIdentifier() string {
	start := l.position
	for ; isLetter(l.ch); l.readChar() {
	}
	return l.input[start:l.position]
}
// readNumber consumes the run of digit characters starting at the current
// position and returns it.
func (l *Lexer) readNumber() string {
	start := l.position
	for ; isDigit(l.ch); l.readChar() {
	}
	return l.input[start:l.position]
}
// readString returns the text between the opening quote at the current
// position and the next '"' (or the end of input, whichever comes first).
// On return the lexer sits on the closing quote or on the 0 end marker.
func (l *Lexer) readString() string {
	start := l.position + 1
	l.readChar() // step past the opening quote
	for l.ch != '"' && l.ch != 0 {
		l.readChar()
	}
	return l.input[start:l.position]
}
// NextToken skips whitespace and returns the next token of the input.
//
// Identifiers/keywords and integers consume all of their characters
// themselves; every other token is followed by one readChar so the lexer ends
// up on the first character after the token. At the end of input it returns
// an EOF token with an empty literal; an unrecognised character yields an
// ILLEGAL token.
func (l *Lexer) NextToken() token.Token {
	l.skipWhitespace()

	// Multi-character tokens that already leave the lexer past their end.
	if isLetter(l.ch) {
		word := l.readIdentifier()
		return token.Token{Type: token.LookupIdent(word), Literal: word}
	}
	if isDigit(l.ch) {
		return token.Token{Type: token.INT, Literal: l.readNumber()}
	}

	// withEquals handles '=' and '!', which form a two-character operator
	// when directly followed by '='.
	withEquals := func(single, double token.TokenType) token.Token {
		first := l.ch
		if l.peekChar() != '=' {
			return newToken(single, first)
		}
		l.readChar()
		return token.Token{Type: double, Literal: string(first) + string(l.ch)}
	}

	var tok token.Token
	switch ch := l.ch; ch {
	case '=':
		tok = withEquals(token.ASSIGN, token.EQ)
	case '!':
		tok = withEquals(token.BANG, token.NOT_EQ)
	case '"':
		tok = token.Token{Type: token.STRING, Literal: l.readString()}
	case 0:
		tok = token.Token{Type: token.EOF, Literal: ""}
	default:
		// Single-character tokens; anything unknown stays ILLEGAL.
		kind := token.TokenType(token.ILLEGAL)
		switch ch {
		case '+':
			kind = token.PLUS
		case '-':
			kind = token.MINUS
		case '/':
			kind = token.SLASH
		case '*':
			kind = token.ASTERISK
		case '<':
			kind = token.LT
		case '>':
			kind = token.GT
		case ';':
			kind = token.SEMICOLON
		case ':':
			kind = token.COLON
		case ',':
			kind = token.COMMA
		case '{':
			kind = token.LBRACE
		case '}':
			kind = token.RBRACE
		case '(':
			kind = token.LPAREN
		case ')':
			kind = token.RPAREN
		case '[':
			kind = token.LBRACKET
		case ']':
			kind = token.RBRACKET
		}
		tok = newToken(kind, ch)
	}

	l.readChar()
	return tok
}
monkey/lexer/lexer_test.go
package lexer
import (
"monkey/token"
"testing"
)
func TestNextToken(t *testing.T) {
input := `let five = 5;
let ten = 10;
let add = fn(x, y) {
x + y;
};
let result = add(five, ten);
!-/*5;
5 < 10 > 5;
if (5 < 10) {
return true;
} else {
return false;
}
10 == 10;
10 != 9;
"foobar"
"foo bar"
[1, 2];
{"foo": "bar"}
macro(x,y){x+y;};
`
tests := []struct {
expectedType token.TokenType
expectedLiteral string
}{
{token.LET, "let"},
{token.IDENT, "five"},
{token.ASSIGN, "="},
{token.INT, "5"},
{token.SEMICOLON, ";"},
{token.LET, "let"},
{token.IDENT, "ten"},
{token.ASSIGN, "="},
{token.INT, "10"},
{token.SEMICOLON, ";"},
{token.LET, "let"},
{token.IDENT, "add"},
{token.ASSIGN, "="},
{token.FUNCTION, "fn"},
{token.LPAREN, "("},
{token.IDENT, "x"},
{token.COMMA, ","},
{token.IDENT, "y"},
{token.RPAREN, ")"},
{token.LBRACE, "{"},
{token.IDENT, "x"},
{token.PLUS, "+"},
{token.IDENT, "y"},
{token.SEMICOLON, ";"},
{token.RBRACE, "}"},
{token.SEMICOLON, ";"},
{token.LET, "let"},
{token.IDENT, "result"},
{token.ASSIGN, "="},
{token.IDENT, "add"},
{token.LPAREN, "("},
{token.IDENT, "five"},
{token.COMMA, ","},
{token.IDENT, "ten"},
{token.RPAREN, ")"},
{token.SEMICOLON, ";"},
{token.BANG, "!"},
{token.MINUS, "-"},
{token.SLASH, "/"},
{token.ASTERISK, "*"},
{token.INT, "5"},
{token.SEMICOLON, ";"},
{token.INT, "5"},
{token.LT, "<"},
{token.INT, "10"},
{token.GT, ">"},
{token.INT, "5"},
{token.SEMICOLON, ";"},
{token.IF, "if"},
{token.LPAREN, "("},
{token.INT, "5"},
{token.LT, "<"},
{token.INT, "10"},
{token.RPAREN, ")"},
{token.LBRACE, "{"},
{token.RETURN, "return"},
{token.TRUE, "true"},
{token.SEMICOLON, ";"},
{token.RBRACE, "}"},
{token.ELSE, "else"},
{token.LBRACE, "{"},
{token.RETURN, "return"},
{token.FALSE, "false"},
{token.SEMICOLON, ";"},
{token.RBRACE, "}"},
{token.INT, "10"},
{token.EQ, "=="},
{token.INT, "10"},
{token.SEMICOLON, ";"},
{token.INT, "10"},
{token.NOT_EQ, "!="},
{token.INT, "9"},
{token.SEMICOLON, ";"},
{token.STRING, "foobar"},
{token.STRING, "foo bar"},
{token.LBRACKET, "["},
{token.INT, "1"},
{token.COMMA, ","},
{token.INT, "2"},
{token.RBRACKET, "]"},
{token.SEMICOLON, ";"},
{token.LBRACE, "{"},
{token.STRING, "foo"},
{token.COLON, ":"},
{token.STRING, "bar"},
{token.RBRACE, "}"},
{token.MACRO, "macro"},
{token.LPAREN, "("},
{token.IDENT, "x"},
{token.COMMA, ","},
{token.IDENT, "y"},
{token.RPAREN, ")"},
{token.LBRACE, "{"},
{token.IDENT, "x"},
{token.PLUS, "+"},
{token.IDENT, "y"},
{token.SEMICOLON, ";"},
{token.RBRACE, "}"},
{token.SEMICOLON, ";"},
{token.EOF, ""},
}
l := New(input)
for i, tt := range tests {
tok := l.NextToken()
if tok.Type != tt.expectedType {
t.Fatalf("tests[%d] - tokentype wrong. expected=%q, got=%q",
i, tt.expectedType, tok.Type)
}
if tok.Literal != tt.expectedLiteral {
t.Fatalf("tests[%d] - literal wrong. expected=%q, got=%q",
i, tt.expectedLiteral, tok.Literal)
}
}
}
monkey/ast/ast.go
package ast
import (
"bytes"
"monkey/token"
"strings"
)
// Node is the base node interface; it requires a String method.
type Node interface {
String() string
}
// Statement is the interface for statement nodes.
type Statement interface {
Node
}
// Expression is the interface for expression nodes.
type Expression interface {
Node
}
// Program is the program node: a list of statements.
type Program struct {
	Statements []Statement
}

// String concatenates the String output of every statement, in order.
func (p *Program) String() string {
	var out bytes.Buffer
	for i := range p.Statements {
		out.WriteString(p.Statements[i].String())
	}
	return out.String()
}
// LetStatement is the node for a let statement.
type LetStatement struct {
	Name  *Identifier // the identifier being bound
	Value Expression  // the expression on the right-hand side
}

// String renders the statement as "let <name> = <value>;" followed by a
// newline.
//
// A nil Name or Value is rendered as empty text instead of panicking, so a
// partially built statement can still be printed.
func (ls *LetStatement) String() string {
	var out bytes.Buffer
	out.WriteString("let ")
	if ls.Name != nil {
		out.WriteString(ls.Name.String())
	}
	out.WriteString(" = ")
	if ls.Value != nil {
		out.WriteString(ls.Value.String())
	}
	out.WriteString(";\n")
	return out.String()
}
// ReturnStatement is the node for a return statement.
type ReturnStatement struct {
	ReturnValue Expression // the expression to the right of "return"
}

// String renders the statement as "return <value>;" followed by a newline.
//
// A nil ReturnValue is rendered as empty text instead of panicking.
func (rs *ReturnStatement) String() string {
	var out bytes.Buffer
	out.WriteString("return ")
	if rs.ReturnValue != nil {
		out.WriteString(rs.ReturnValue.String())
	}
	out.WriteString(";\n")
	return out.String()
}
// ExpressionStatement is a statement that consists of a single expression.
type ExpressionStatement struct {
	Expression Expression
}

// String renders the expression followed by ";" and a newline.
//
// A statement whose Expression is nil renders as the empty string instead of
// panicking.
func (es *ExpressionStatement) String() string {
	if es.Expression == nil {
		return ""
	}
	return es.Expression.String() + ";" + "\n"
}
// BlockStatement is the node for a block: a list of statements.
type BlockStatement struct {
	Statements []Statement
}

// String renders the block as "{", a newline, each statement prefixed with a
// tab, and a closing "}".
func (bs *BlockStatement) String() string {
	var out bytes.Buffer
	out.WriteString("{\n")
	for i := range bs.Statements {
		out.WriteByte('\t')
		out.WriteString(bs.Statements[i].String())
	}
	out.WriteByte('}')
	return out.String()
}
// Identifier is the node for an identifier.
type Identifier struct {
	Token token.Token // the identifier's token
}

// String returns the literal of the identifier's token.
func (i *Identifier) String() string {
	name := i.Token.Literal
	return name
}
// Boolean is the node for a boolean literal.
type Boolean struct {
	Token token.Token
	Value bool
}

// String returns the literal of the boolean's token.
func (b *Boolean) String() string {
	literal := b.Token.Literal
	return literal
}
// IntegerLiteral is the node for an integer literal.
type IntegerLiteral struct {
	Token token.Token
	Value int64
}

// String returns the literal of the integer's token (not a formatting of
// Value).
func (il *IntegerLiteral) String() string {
	literal := il.Token.Literal
	return literal
}
// PrefixExpression is the node for a prefix expression such as !x.
type PrefixExpression struct {
	Token token.Token // the prefix operator token, e.g. !
	Right Expression
}

// String renders the expression as "(<operator><right>)".
func (pe *PrefixExpression) String() string {
	var out bytes.Buffer
	out.WriteByte('(')
	out.WriteString(pe.Token.Literal + pe.Right.String())
	out.WriteByte(')')
	return out.String()
}
// InfixExpression is the node for an infix expression such as a + b.
type InfixExpression struct {
	Token token.Token // the operator token, e.g. +
	Left  Expression
	Right Expression
}

// String renders the expression as "(<left> <operator> <right>)".
func (ie *InfixExpression) String() string {
	var out bytes.Buffer
	out.WriteByte('(')
	out.WriteString(ie.Left.String())
	out.WriteByte(' ')
	out.WriteString(ie.Token.Literal)
	out.WriteByte(' ')
	out.WriteString(ie.Right.String())
	out.WriteByte(')')
	return out.String()
}
// IfExpression is the node for an if expression; Alternative is nil when
// there is no else branch.
type IfExpression struct {
	Condition   Expression
	Consequence *BlockStatement
	Alternative *BlockStatement
}

// String renders "if<condition> <consequence>", followed by
// " else <alternative>" when an alternative is present.
func (ie *IfExpression) String() string {
	var out bytes.Buffer
	out.WriteString("if" + ie.Condition.String() + " ")
	out.WriteString(ie.Consequence.String())
	if alt := ie.Alternative; alt != nil {
		out.WriteString(" else " + alt.String())
	}
	return out.String()
}
// FunctionLiteral is the node for a function literal.
type FunctionLiteral struct {
	Parameters []*Identifier
	Body       *BlockStatement
}

// String renders the literal as "fn(<p1>, <p2>, ...) <body>".
func (fl *FunctionLiteral) String() string {
	names := make([]string, 0, len(fl.Parameters))
	for i := range fl.Parameters {
		names = append(names, fl.Parameters[i].String())
	}

	var out bytes.Buffer
	out.WriteString("fn(")
	out.WriteString(strings.Join(names, ", "))
	out.WriteString(") ")
	out.WriteString(fl.Body.String())
	return out.String()
}
// CallExpression is the node for a call expression.
type CallExpression struct {
	Function  Expression // identifier or function literal
	Arguments []Expression
}

// String renders the call as "<function>(<arg1>, <arg2>, ...)".
func (ce *CallExpression) String() string {
	rendered := make([]string, 0, len(ce.Arguments))
	for i := range ce.Arguments {
		rendered = append(rendered, ce.Arguments[i].String())
	}

	var out bytes.Buffer
	out.WriteString(ce.Function.String())
	out.WriteByte('(')
	out.WriteString(strings.Join(rendered, ", "))
	out.WriteByte(')')
	return out.String()
}
// StringLiteral is the node for a string literal.
type StringLiteral struct {
	Token token.Token
}

// String returns the literal of the string's token.
func (sl *StringLiteral) String() string {
	literal := sl.Token.Literal
	return literal
}
// ArrayLiteral is an array constant made of element expressions.
type ArrayLiteral struct {
	Elements []Expression
}

// String renders the array as "[<e1>, <e2>, ...]".
func (al *ArrayLiteral) String() string {
	items := make([]string, 0, len(al.Elements))
	for _, item := range al.Elements {
		items = append(items, item.String())
	}
	return "[" + strings.Join(items, ", ") + "]"
}
// IndexExpression is an index operation: Left[Index].
type IndexExpression struct {
	Left  Expression
	Index Expression
}

// String renders the expression as "(<left>[<index>])".
func (ie *IndexExpression) String() string {
	target := ie.Left.String()
	subscript := ie.Index.String()
	return "(" + target + "[" + subscript + "])"
}
// HashLiteral is a hash constant mapping key expressions to value expressions.
type HashLiteral struct {
	Pairs map[Expression]Expression
}

// String renders the hash as "{<k1> : <v1>, <k2> : <v2>, ...}".
// Pairs are emitted in map iteration order, so the order of entries is not
// stable between calls.
func (hl *HashLiteral) String() string {
	entries := make([]string, 0, len(hl.Pairs))
	for k, v := range hl.Pairs {
		entries = append(entries, k.String()+" : "+v.String())
	}
	return "{" + strings.Join(entries, ", ") + "}"
}
// MacroLiteral is a macro definition: a parameter list and a body.
type MacroLiteral struct {
	Parameters []*Identifier
	Body       *BlockStatement
}

// String renders the literal as "macro(<p1>, <p2>, ...) <body>".
func (ml *MacroLiteral) String() string {
	names := make([]string, 0, len(ml.Parameters))
	for _, param := range ml.Parameters {
		names = append(names, param.String())
	}
	return "macro(" + strings.Join(names, ", ") + ") " + ml.Body.String()
}
monkey/ast/ast_test.go
package ast
import (
"monkey/token"
"testing"
)
func TestString(t *testing.T) {
program := &Program{
Statements: []Statement{
&LetStatement{
Name: &Identifier{Token: token.Token{Type: token.IDENT, Literal: "myVar"}},
Value: &Identifier{Token: token.Token{Type: token.IDENT, Literal: "anotherVar"}},
},
},
}
if program.String() != "let myVar = anotherVar;\n" {
t.Errorf("program.String() wrong. got=%q", program.String())
}
}
monkey/ast/modify.go
package ast
// ModifierFunc is a node-rewriting callback: it receives an AST node and
// returns the node that Modify should use in its place.
type ModifierFunc func(Node) Node
// Modify walks the tree rooted at node, rewrites every child in place with
// the result of modifying it, and finally returns modifier(node).
//
// Children are processed before their parent, so modifier always sees a node
// whose children have already been rewritten. If a rewritten child does not
// have the type its parent field requires, that field is set to nil.
//
// NOTE(review): node kinds without a case below (for example *CallExpression
// and *MacroLiteral) are passed to modifier but their children are not visited.
func Modify(node Node, modifier ModifierFunc) Node {
	// Adapters that recurse into a child and narrow the result back to the
	// type of the field it came from.
	expr := func(e Expression) Expression {
		out, _ := Modify(e, modifier).(Expression)
		return out
	}
	stmt := func(s Statement) Statement {
		out, _ := Modify(s, modifier).(Statement)
		return out
	}
	block := func(b *BlockStatement) *BlockStatement {
		out, _ := Modify(b, modifier).(*BlockStatement)
		return out
	}

	switch n := node.(type) {
	case *Program:
		for i := range n.Statements {
			n.Statements[i] = stmt(n.Statements[i])
		}
	case *BlockStatement:
		for i := range n.Statements {
			n.Statements[i] = stmt(n.Statements[i])
		}
	case *ExpressionStatement:
		n.Expression = expr(n.Expression)
	case *ReturnStatement:
		n.ReturnValue = expr(n.ReturnValue)
	case *LetStatement:
		n.Value = expr(n.Value)
	case *PrefixExpression:
		n.Right = expr(n.Right)
	case *InfixExpression:
		n.Left = expr(n.Left)
		n.Right = expr(n.Right)
	case *IndexExpression:
		n.Left = expr(n.Left)
		n.Index = expr(n.Index)
	case *IfExpression:
		n.Condition = expr(n.Condition)
		n.Consequence = block(n.Consequence)
		if n.Alternative != nil {
			n.Alternative = block(n.Alternative)
		}
	case *FunctionLiteral:
		for i := range n.Parameters {
			n.Parameters[i], _ = Modify(n.Parameters[i], modifier).(*Identifier)
		}
		n.Body = block(n.Body)
	case *ArrayLiteral:
		for i := range n.Elements {
			n.Elements[i] = expr(n.Elements[i])
		}
	case *HashLiteral:
		// Keys may be rewritten too, so the map is rebuilt instead of updated.
		rebuilt := make(map[Expression]Expression)
		for k, v := range n.Pairs {
			rebuilt[expr(k)] = expr(v)
		}
		n.Pairs = rebuilt
	}

	// Finally rewrite the node itself.
	return modifier(node)
}
monkey/ast/modify_test.go
package ast
import (
"monkey/token"
"reflect"
"testing"
)
// TestModify checks that Modify reaches integer literals nested inside every
// node kind it traverses, using a modifier that turns each literal 1 into 2.
func TestModify(t *testing.T) {
	one := func() Expression { return &IntegerLiteral{Value: 1} }
	two := func() Expression { return &IntegerLiteral{Value: 2} }
	plus := token.Token{Type: token.PLUS, Literal: "+"}
	minus := token.Token{Type: token.MINUS, Literal: "-"}
	// blockOf wraps a single expression in a one-statement block.
	blockOf := func(e Expression) *BlockStatement {
		return &BlockStatement{
			Statements: []Statement{
				&ExpressionStatement{Expression: e},
			},
		}
	}

	turnOneIntoTwo := func(node Node) Node {
		integer, ok := node.(*IntegerLiteral)
		if !ok {
			return node
		}
		if integer.Value != 1 {
			return node
		}
		integer.Value = 2
		return integer
	}

	tests := []struct {
		input    Node
		expected Node
	}{
		{
			one(),
			two(),
		},
		{
			&Program{
				Statements: []Statement{
					&ExpressionStatement{Expression: one()},
				},
			},
			&Program{
				Statements: []Statement{
					&ExpressionStatement{Expression: two()},
				},
			},
		},
		{
			&InfixExpression{Left: one(), Token: plus, Right: two()},
			&InfixExpression{Left: two(), Token: plus, Right: two()},
		},
		{
			&InfixExpression{Left: two(), Token: plus, Right: one()},
			&InfixExpression{Left: two(), Token: plus, Right: two()},
		},
		{
			&PrefixExpression{Token: minus, Right: one()},
			&PrefixExpression{Token: minus, Right: two()},
		},
		{
			&IndexExpression{Left: one(), Index: one()},
			&IndexExpression{Left: two(), Index: two()},
		},
		{
			&IfExpression{
				Condition:   one(),
				Consequence: blockOf(one()),
				Alternative: blockOf(one()),
			},
			&IfExpression{
				Condition:   two(),
				Consequence: blockOf(two()),
				Alternative: blockOf(two()),
			},
		},
		{
			&ReturnStatement{ReturnValue: one()},
			&ReturnStatement{ReturnValue: two()},
		},
		{
			&LetStatement{Value: one()},
			&LetStatement{Value: two()},
		},
		{
			&FunctionLiteral{
				Parameters: []*Identifier{},
				Body:       blockOf(one()),
			},
			&FunctionLiteral{
				Parameters: []*Identifier{},
				Body:       blockOf(two()),
			},
		},
		{
			&ArrayLiteral{Elements: []Expression{one(), one()}},
			&ArrayLiteral{Elements: []Expression{two(), two()}},
		},
	}

	for _, tt := range tests {
		modified := Modify(tt.input, turnOneIntoTwo)
		equal := reflect.DeepEqual(modified, tt.expected)
		if !equal {
			t.Errorf("not equal. got=%#v, want=%#v", modified, tt.expected)
		}
	}

	// Hash literals are checked by hand: the map keys are pointers, so two
	// separately built maps never share a key and reflect.DeepEqual cannot
	// match them.
	hashLiteral := &HashLiteral{
		Pairs: map[Expression]Expression{
			one(): one(),
		},
	}
	Modify(hashLiteral, turnOneIntoTwo)

	// Without this guard an empty map would skip the loop and pass silently.
	if len(hashLiteral.Pairs) != 1 {
		t.Fatalf("hashLiteral.Pairs has wrong length. want=1, got=%d", len(hashLiteral.Pairs))
	}
	for key, val := range hashLiteral.Pairs {
		keyLit, ok := key.(*IntegerLiteral)
		if !ok {
			t.Fatalf("key is not *IntegerLiteral. got=%T", key)
		}
		if keyLit.Value != 2 {
			t.Errorf("value is not %d, got=%d", 2, keyLit.Value)
		}
		valLit, ok := val.(*IntegerLiteral)
		if !ok {
			t.Fatalf("val is not *IntegerLiteral. got=%T", val)
		}
		if valLit.Value != 2 {
			t.Errorf("value is not %d, got=%d", 2, valLit.Value)
		}
	}
}
monkey/parser/parser.go
package parser
import (
"fmt"
"monkey/ast"
"monkey/lexer"
"monkey/token"
"strconv"
)
// Operator precedence levels, from weakest to strongest binding.
// Values start at 1 so that the zero value is never a valid precedence.
const (
	LOWEST      int = iota + 1
	EQUALS          // == !=
	LESSGREATER     // > or <
	SUM             // + -
	PRODUCT         // * /
	PREFIX          // -X or !X
	CALL            // myFunction(X)
	INDEX           // array[index]
)
// precedences maps operator token types to their precedence level.
// It is consulted when deciding how tightly an infix operator binds; token
// types that are missing here are treated as LOWEST by the lookup helpers.
var precedences = map[token.TokenType]int{
	// comparison
	token.EQ:     EQUALS,      // ==
	token.NOT_EQ: EQUALS,      // !=
	token.LT:     LESSGREATER, // <
	token.GT:     LESSGREATER, // >

	// arithmetic
	token.PLUS:     SUM,     // +
	token.MINUS:    SUM,     // -
	token.SLASH:    PRODUCT, // /
	token.ASTERISK: PRODUCT, // *

	// call and index
	token.LPAREN:   CALL,  // (
	token.LBRACKET: INDEX, // [
}
// Parse-function signatures stored in the parser's lookup tables.
type (
	prefixParseFn func() ast.Expression               // prefix parse function: parses an expression starting at the current token
	infixParseFn  func(ast.Expression) ast.Expression // infix parse function: receives the already-parsed left-hand expression
)
// Parser turns the token stream produced by a lexer into an AST, looking up
// a parse function per token type for prefix and infix positions.
type Parser struct {
	l              *lexer.Lexer                      // token source
	errors         []string                          // error messages collected while parsing
	curToken       token.Token                       // token currently being examined
	peekToken      token.Token                       // token that follows curToken
	prefixParseFns map[token.TokenType]prefixParseFn // parse functions for tokens in prefix position
	infixParseFns  map[token.TokenType]infixParseFn  // parse functions for tokens in infix position
}
// New creates a Parser that reads tokens from l, registers every prefix and
// infix parse function, and primes curToken and peekToken.
func New(l *lexer.Lexer) *Parser {
	p := &Parser{
		l:              l,
		errors:         []string{},
		prefixParseFns: make(map[token.TokenType]prefixParseFn),
		infixParseFns:  make(map[token.TokenType]infixParseFn),
	}

	// Prefix parse functions: tokens that can start an expression.
	prefixes := []struct {
		tok token.TokenType
		fn  prefixParseFn
	}{
		{token.IDENT, p.parseIdentifier},
		{token.INT, p.parseIntegerLiteral},
		{token.STRING, p.parseStringLiteral},
		{token.BANG, p.parsePrefixExpression},    // !
		{token.MINUS, p.parsePrefixExpression},   // -
		{token.TRUE, p.parseBoolean},
		{token.FALSE, p.parseBoolean},
		{token.LPAREN, p.parseGroupedExpression}, // grouped expression, e.g. (1 + 2) * 3
		{token.IF, p.parseIfExpression},
		{token.FUNCTION, p.parseFunctionLiteral}, // fn
		{token.LBRACKET, p.parseArrayLiteral},    // array literal, e.g. [1+2, 3*4]
		{token.LBRACE, p.parseHashLiteral},       // hash literal, e.g. {"one" : 1 + 2, "two" : 2}
		{token.MACRO, p.parseMacroLiteral},       // macro
	}
	for _, reg := range prefixes {
		p.registerPrefix(reg.tok, reg.fn)
	}

	// Infix parse functions: the binary operators share one implementation.
	binaryOps := []token.TokenType{
		token.PLUS, token.MINUS, token.SLASH, token.ASTERISK,
		token.EQ, token.NOT_EQ, token.LT, token.GT,
	}
	for _, op := range binaryOps {
		p.registerInfix(op, p.parseInfixExpression)
	}
	p.registerInfix(token.LPAREN, p.parseCallExpression)    // call expression, e.g. add(2, 3)
	p.registerInfix(token.LBRACKET, p.parseIndexExpression) // index expression, e.g. a[0]

	// Advance twice so that both curToken and peekToken are populated.
	p.nextToken()
	p.nextToken()
	return p
}
// nextToken shifts the token window forward by one: the peeked token becomes
// the current token and a fresh token is read from the lexer.
func (p *Parser) nextToken() {
	p.curToken, p.peekToken = p.peekToken, p.l.NextToken()
}
// curTokenIs reports whether the current token has type t.
func (p *Parser) curTokenIs(t token.TokenType) bool {
	return t == p.curToken.Type
}
// peekTokenIs reports whether the upcoming token has type t.
func (p *Parser) peekTokenIs(t token.TokenType) bool {
	return t == p.peekToken.Type
}
// expectPeek checks that the upcoming token has type t. On a match it
// advances onto that token and returns true; otherwise it records a parse
// error, leaves the position untouched and returns false.
func (p *Parser) expectPeek(t token.TokenType) bool {
	if !p.peekTokenIs(t) {
		p.peekError(t)
		return false
	}
	p.nextToken()
	return true
}
// Errors returns the error messages collected so far. The returned slice is
// the parser's own slice, not a copy.
func (p *Parser) Errors() []string {
	collected := p.errors
	return collected
}
// peekError records that the upcoming token was expected to have type t but
// has a different type.
func (p *Parser) peekError(t token.TokenType) {
	p.errors = append(p.errors, fmt.Sprintf(
		"expected next token to be %s, got %s instead", t, p.peekToken.Type))
}
// noPrefixParseFnError records that no prefix parse function is registered
// for token type t.
func (p *Parser) noPrefixParseFnError(t token.TokenType) {
	p.errors = append(p.errors,
		fmt.Sprintf("no prefix parse function for %s found", t))
}
// ParseProgram parses statements until the EOF token is reached and returns
// them as a Program. Statements that parse to nil are left out.
func (p *Parser) ParseProgram() *ast.Program {
	program := &ast.Program{Statements: []ast.Statement{}}
	for ; !p.curTokenIs(token.EOF); p.nextToken() {
		if stmt := p.parseStatement(); stmt != nil {
			program.Statements = append(program.Statements, stmt)
		}
	}
	return program
}
// parseStatement parses one statement, chosen by the type of the current
// token: an empty statement (lone semicolon), a let statement, a return
// statement, or otherwise an expression statement.
//
// It returns a nil interface when there is no statement to add, which is what
// callers test for with `stmt != nil`.
func (p *Parser) parseStatement() ast.Statement {
	switch p.curToken.Type {
	case token.SEMICOLON: // empty statement
		return nil
	case token.LET:
		// parseLetStatement signals a syntax error with a nil *ast.LetStatement.
		// Returning that pointer directly would wrap it in a non-nil
		// ast.Statement, so callers would append a nil statement that panics
		// on first use. Convert it to a true nil interface instead.
		if stmt := p.parseLetStatement(); stmt != nil {
			return stmt
		}
		return nil
	case token.RETURN:
		return p.parseReturnStatement()
	default:
		return p.parseExpressionStatement()
	}
}
// parseLetStatement parses `let <identifier> = <expression>` with an optional
// trailing semicolon. It returns nil, after recording an error, when the
// identifier or the assignment sign is missing.
func (p *Parser) parseLetStatement() *ast.LetStatement {
	if !p.expectPeek(token.IDENT) {
		return nil
	}
	name := &ast.Identifier{Token: p.curToken}

	if !p.expectPeek(token.ASSIGN) {
		return nil
	}

	p.nextToken() // step onto the first token of the value expression
	let := &ast.LetStatement{
		Name:  name,
		Value: p.parseExpression(LOWEST),
	}

	// Consume the semicolon if one follows.
	if p.peekTokenIs(token.SEMICOLON) {
		p.nextToken()
	}
	return let
}
// parseReturnStatement parses `return <expression>` with an optional trailing
// semicolon.
func (p *Parser) parseReturnStatement() *ast.ReturnStatement {
	p.nextToken() // step past `return`
	ret := &ast.ReturnStatement{ReturnValue: p.parseExpression(LOWEST)}

	if p.peekTokenIs(token.SEMICOLON) {
		p.nextToken()
	}
	return ret
}
// parseExpressionStatement parses a statement made of a single expression,
// with an optional trailing semicolon.
func (p *Parser) parseExpressionStatement() *ast.ExpressionStatement {
	// defer untrace(trace("parseExpressionStatement"))
	exprStmt := &ast.ExpressionStatement{Expression: p.parseExpression(LOWEST)}

	if p.peekTokenIs(token.SEMICOLON) {
		p.nextToken()
	}
	return exprStmt
}
// parseExpression parses an expression starting at the current token.
// It first runs the prefix parse function for the current token, then keeps
// folding the result into infix expressions for as long as the next operator
// binds tighter than precedence. It records an error and returns nil when the
// current token has no prefix parse function.
func (p *Parser) parseExpression(precedence int) ast.Expression {
	// defer untrace(trace("parseExpression"))
	parsePrefix := p.prefixParseFns[p.curToken.Type]
	if parsePrefix == nil {
		p.noPrefixParseFnError(p.curToken.Type)
		return nil
	}
	left := parsePrefix()

	for {
		// Stop at the end of the statement, or when the upcoming operator
		// does not bind tighter than the caller's precedence.
		if p.peekTokenIs(token.SEMICOLON) || precedence >= p.peekPrecedence() {
			break
		}
		parseInfix := p.infixParseFns[p.peekToken.Type]
		if parseInfix == nil {
			break
		}
		p.nextToken()
		left = parseInfix(left) // the expression so far becomes the left operand
	}
	return left
}
// peekPrecedence returns the precedence of the upcoming token, or LOWEST when
// the token type has no entry in the precedence table.
func (p *Parser) peekPrecedence() int {
	level, known := precedences[p.peekToken.Type]
	if !known {
		return LOWEST
	}
	return level
}
// curPrecedence returns the precedence of the current token, or LOWEST when
// the token type has no entry in the precedence table.
func (p *Parser) curPrecedence() int {
	level, known := precedences[p.curToken.Type]
	if !known {
		return LOWEST
	}
	return level
}
// parseIdentifier wraps the current token in an identifier node.
func (p *Parser) parseIdentifier() ast.Expression {
	// defer untrace(trace("parseIdentifier"))
	ident := &ast.Identifier{Token: p.curToken}
	return ident
}
// parseIntegerLiteral converts the current token's text into an integer
// literal node. Base 0 is passed to ParseInt, so the base is taken from the
// literal's prefix. When the text is not a valid 64-bit integer, an error is
// recorded and nil is returned.
func (p *Parser) parseIntegerLiteral() ast.Expression {
	// defer untrace(trace("parseIntegerLiteral"))
	text := p.curToken.Literal
	parsed, err := strconv.ParseInt(text, 0, 64)
	if err != nil {
		p.errors = append(p.errors, fmt.Sprintf("could not parse %q as integer", text))
		return nil
	}
	return &ast.IntegerLiteral{Token: p.curToken, Value: parsed}
}
// parseStringLiteral wraps the current token in a string literal node.
func (p *Parser) parseStringLiteral() ast.Expression {
	str := &ast.StringLiteral{Token: p.curToken}
	return str
}
// parsePrefixExpression parses `<operator><operand>` where the current token
// is the operator.
func (p *Parser) parsePrefixExpression() ast.Expression {
	// defer untrace(trace("parsePrefixExpression"))
	operator := p.curToken
	p.nextToken()
	// The operand is parsed at PREFIX precedence so that the operator binds
	// to the whole operand before any lower-precedence infix operator.
	return &ast.PrefixExpression{
		Token: operator,
		Right: p.parseExpression(PREFIX),
	}
}
// parseInfixExpression parses `<left> <operator> <right>` where the current
// token is the operator and left has already been parsed. The right operand
// is parsed at the operator's own precedence.
func (p *Parser) parseInfixExpression(left ast.Expression) ast.Expression {
	// defer untrace(trace("parseInfixExpression"))
	operator, level := p.curToken, p.curPrecedence()
	p.nextToken()
	return &ast.InfixExpression{
		Token: operator,
		Left:  left,
		Right: p.parseExpression(level),
	}
}
// parseBoolean builds a boolean node whose value is true exactly when the
// current token is the TRUE token.
func (p *Parser) parseBoolean() ast.Expression {
	isTrue := p.curTokenIs(token.TRUE)
	return &ast.Boolean{Token: p.curToken, Value: isTrue}
}
// parseGroupedExpression parses a parenthesised expression and returns the
// inner expression. It returns nil, after an error is recorded, when the
// closing parenthesis is missing.
func (p *Parser) parseGroupedExpression() ast.Expression {
	p.nextToken() // step past "("
	inner := p.parseExpression(LOWEST)
	if p.expectPeek(token.RPAREN) {
		return inner
	}
	return nil
}
// parseIfExpression parses `if (<condition>) { ... }` with an optional
// `else { ... }`. It returns nil as soon as an expected token is missing;
// the error is recorded by expectPeek.
func (p *Parser) parseIfExpression() ast.Expression {
	if !p.expectPeek(token.LPAREN) {
		return nil
	}
	p.nextToken()
	condition := p.parseExpression(LOWEST)

	// ")" must be followed by "{"; the second check only runs if the first passes.
	if !p.expectPeek(token.RPAREN) || !p.expectPeek(token.LBRACE) {
		return nil
	}
	ifExp := &ast.IfExpression{
		Condition:   condition,
		Consequence: p.parseBlockStatement(),
	}

	if !p.peekTokenIs(token.ELSE) {
		return ifExp
	}
	p.nextToken() // step onto `else`
	if !p.expectPeek(token.LBRACE) {
		return nil
	}
	ifExp.Alternative = p.parseBlockStatement()
	return ifExp
}
// parseBlockStatement parses the statements of a block. The current token is
// expected to be the opening brace; parsing stops at the closing brace or at
// EOF, whichever comes first.
func (p *Parser) parseBlockStatement() *ast.BlockStatement {
	block := &ast.BlockStatement{Statements: []ast.Statement{}}
	for p.nextToken(); !(p.curTokenIs(token.RBRACE) || p.curTokenIs(token.EOF)); p.nextToken() {
		if stmt := p.parseStatement(); stmt != nil {
			block.Statements = append(block.Statements, stmt)
		}
	}
	return block
}
// parseFunctionLiteral parses `fn(<params>) { <body> }`. It returns nil when
// the opening parenthesis or the opening brace is missing.
func (p *Parser) parseFunctionLiteral() ast.Expression {
	if !p.expectPeek(token.LPAREN) {
		return nil
	}
	params := p.parseFunctionParameters()

	if !p.expectPeek(token.LBRACE) {
		return nil
	}
	return &ast.FunctionLiteral{
		Parameters: params,
		Body:       p.parseBlockStatement(),
	}
}
// parseFunctionParameters parses a comma-separated parameter list such as
// `(a, b, c)`. The current token is expected to be the opening parenthesis.
// It returns an empty slice for `()`, and nil when the closing parenthesis
// is missing.
func (p *Parser) parseFunctionParameters() []*ast.Identifier {
	params := []*ast.Identifier{}

	if p.peekTokenIs(token.RPAREN) {
		p.nextToken()
		return params
	}

	for {
		p.nextToken() // step onto the parameter name
		params = append(params, &ast.Identifier{Token: p.curToken})
		if !p.peekTokenIs(token.COMMA) {
			break
		}
		p.nextToken() // step onto the comma
	}

	if !p.expectPeek(token.RPAREN) {
		return nil
	}
	return params
}
// parseCallExpression parses the argument list of a call such as add(2, 3).
// function is the already-parsed expression being called.
func (p *Parser) parseCallExpression(function ast.Expression) ast.Expression {
	return &ast.CallExpression{
		Function:  function,
		Arguments: p.parseExpressionList(token.RPAREN),
	}
}
// parseExpressionList parses comma-separated expressions up to the closing
// token end, e.g. the arguments of add(2+3, minute(5, 3)) or the elements of
// [1+2, 3]. It returns an empty slice when the list is empty, and nil when
// the closing token is missing.
func (p *Parser) parseExpressionList(end token.TokenType) []ast.Expression {
	items := []ast.Expression{}

	if p.peekTokenIs(end) {
		p.nextToken()
		return items
	}

	for {
		p.nextToken() // step onto the first token of the next item
		items = append(items, p.parseExpression(LOWEST))
		if !p.peekTokenIs(token.COMMA) {
			break
		}
		p.nextToken() // step onto the comma
	}

	if !p.expectPeek(end) {
		return nil
	}
	return items
}
// parseArrayLiteral parses an array literal; the current token is the
// opening bracket.
func (p *Parser) parseArrayLiteral() ast.Expression {
	return &ast.ArrayLiteral{
		Elements: p.parseExpressionList(token.RBRACKET),
	}
}
// parseIndexExpression parses `<left>[<index>]` where left has already been
// parsed and the current token is the opening bracket. It returns nil when
// the closing bracket is missing.
func (p *Parser) parseIndexExpression(left ast.Expression) ast.Expression {
	p.nextToken() // step past "["
	index := p.parseExpression(LOWEST)
	if !p.expectPeek(token.RBRACKET) {
		return nil
	}
	return &ast.IndexExpression{Left: left, Index: index}
}
// parseHashLiteral parses `{<key> : <value>, ...}`; the current token is the
// opening brace. It returns nil when a colon, a separating comma or the
// closing brace is missing.
func (p *Parser) parseHashLiteral() ast.Expression {
	pairs := make(map[ast.Expression]ast.Expression)

	for !p.peekTokenIs(token.RBRACE) {
		p.nextToken()
		key := p.parseExpression(LOWEST)

		if !p.expectPeek(token.COLON) {
			return nil
		}
		p.nextToken()
		pairs[key] = p.parseExpression(LOWEST)

		// A pair is followed either by the closing brace or by a comma.
		if p.peekTokenIs(token.RBRACE) {
			continue
		}
		if !p.expectPeek(token.COMMA) {
			return nil
		}
	}

	if !p.expectPeek(token.RBRACE) {
		return nil
	}
	return &ast.HashLiteral{Pairs: pairs}
}
// parseMacroLiteral parses `macro(<params>) { <body> }`. It returns nil when
// the opening parenthesis or the opening brace is missing.
func (p *Parser) parseMacroLiteral() ast.Expression {
	if !p.expectPeek(token.LPAREN) {
		return nil
	}
	params := p.parseFunctionParameters()

	if !p.expectPeek(token.LBRACE) {
		return nil
	}
	return &ast.MacroLiteral{
		Parameters: params,
		Body:       p.parseBlockStatement(),
	}
}
// registerPrefix associates fn with tokenType for tokens in prefix position,
// replacing any function registered earlier for the same type.
func (p *Parser) registerPrefix(tokenType token.TokenType, fn prefixParseFn) {
	table := p.prefixParseFns
	table[tokenType] = fn
}
// registerInfix associates fn with tokenType for tokens in infix position,
// replacing any function registered earlier for the same type.
func (p *Parser) registerInfix(tokenType token.TokenType, fn infixParseFn) {
	table := p.infixParseFns
	table[tokenType] = fn
}
monkey/parser/parser_test.go
package parser
import (
"fmt"
"monkey/ast"
"monkey/lexer"
"testing"
)
// checkParserErrors fails the test immediately if the parser recorded any
// errors, reporting each of them first.
func checkParserErrors(t *testing.T, p *Parser) {
	errs := p.Errors()
	if n := len(errs); n > 0 {
		t.Errorf("parser has %d errors", n)
		for i := range errs {
			t.Errorf("parser error: %q", errs[i])
		}
		t.FailNow()
	}
}
// testLetStatement reports whether s is a let statement binding name.
// The first failed check is reported through t.
func testLetStatement(t *testing.T, s ast.Statement, name string) bool {
	letStmt, ok := s.(*ast.LetStatement)
	switch {
	case !ok:
		t.Errorf("s not *ast.LetStatement. got=%T", s)
	case letStmt.Name.String() != name:
		t.Errorf("letStmt.Name.String() not '%s'. got=%s", name, letStmt.Name.String())
	case letStmt.Name.Token.Literal != name:
		t.Errorf("letStmt.Name.Token.Literal not '%s'. got=%s", name, letStmt.Name.Token.Literal)
	default:
		return true
	}
	return false
}
// testIntegerLiteral reports whether il is an integer literal with the given
// value and the matching decimal string form.
func testIntegerLiteral(t *testing.T, il ast.Expression, value int64) bool {
	integ, ok := il.(*ast.IntegerLiteral)
	switch {
	case !ok:
		t.Errorf("il not *ast.IntegerLiteral. got=%T", il)
	case integ.Value != value:
		t.Errorf("integ.Value not %d. got=%d", value, integ.Value)
	case integ.String() != fmt.Sprintf("%d", value):
		t.Errorf("integ.String() not %d. got=%s", value, integ.String())
	default:
		return true
	}
	return false
}
// testBooleanLiteral reports whether exp is a boolean node with the given
// value and the matching string form.
func testBooleanLiteral(t *testing.T, exp ast.Expression, value bool) bool {
	bo, ok := exp.(*ast.Boolean)
	switch {
	case !ok:
		t.Errorf("exp not *ast.Boolean. got=%T", exp)
	case bo.Value != value:
		t.Errorf("bo.Value not %t. got=%t", value, bo.Value)
	case bo.String() != fmt.Sprintf("%t", value):
		t.Errorf("bo.String() not %t. got=%s", value, bo.String())
	default:
		return true
	}
	return false
}
// testIdentifier reports whether exp is an identifier whose string form
// equals value.
func testIdentifier(t *testing.T, exp ast.Expression, value string) bool {
	ident, ok := exp.(*ast.Identifier)
	switch {
	case !ok:
		t.Errorf("exp not *ast.Identifier. got=%T", exp)
	case ident.String() != value:
		t.Errorf("ident.String() not %s. got=%s", value, ident.String())
	default:
		return true
	}
	return false
}
// testLiteralExpression checks exp against expected, picking the comparison
// from the Go type of expected: int and int64 are compared as integer
// literals, string as an identifier name, and bool as a boolean literal.
// Any other type of expected is reported as a test error.
func testLiteralExpression(
	t *testing.T,
	exp ast.Expression,
	expected interface{},
) bool {
	switch v := expected.(type) {
	case int:
		return testIntegerLiteral(t, exp, int64(v))
	case int64:
		return testIntegerLiteral(t, exp, v)
	case string:
		return testIdentifier(t, exp, v)
	case bool:
		return testBooleanLiteral(t, exp, v)
	}
	// The switch is on expected, so that is the value whose type was not
	// handled; report its type rather than the type of exp.
	t.Errorf("type of expected not handled. got=%T", expected)
	return false
}
// testInfixExpression reports whether exp is an infix expression with the
// given left operand, operator and right operand. Checks run in that order
// and stop at the first failure.
func testInfixExpression(t *testing.T, exp ast.Expression, left interface{},
	operator string, right interface{}) bool {
	infix, isInfix := exp.(*ast.InfixExpression)
	if !isInfix {
		t.Errorf("exp is not ast.InfixExpression. got=%T(%s)", exp, exp)
		return false
	}

	if leftOK := testLiteralExpression(t, infix.Left, left); !leftOK {
		return false
	}

	if got := infix.Token.Literal; got != operator {
		t.Errorf("exp.Token.Literal is not '%s'. got=%q", operator, got)
		return false
	}

	return testLiteralExpression(t, infix.Right, right)
}
func TestLetStatements(t *testing.T) {
tests := []struct {
input string
expectedIdentifier string
expectedValue interface{}
}{
{"let x = 5;", "x", 5},
{"let y = true;", "y", true},
{"let foobar = y;", "foobar", "y"},
}
for _, tt := range tests {
l := lexer.New(tt.input)
p := New(l)
program := p.ParseProgram()
checkParserErrors(t, p)
if len(program.Statements) != 1 {
t.Fatalf("program.Statements does not contain 1 statements. got=%d",
len(program.Statements))
}
stmt := program.Statements[0]
if !testLetStatement(t, stmt, tt.expectedIdentifier) {
return
}
val := stmt.(*ast.LetStatement).Value
if !testLiteralExpression(t, val, tt.expectedValue) {
return
}
}
}
// TestReturnStatements parses single return statements and checks the
// literal value each one returns.
func TestReturnStatements(t *testing.T) {
	tests := []struct {
		input         string
		expectedValue interface{}
	}{
		{"return 5;", 5},
		{"return true;", true},
		{"return foobar;", "foobar"},
	}
	for _, tt := range tests {
		l := lexer.New(tt.input)
		p := New(l)
		program := p.ParseProgram()
		checkParserErrors(t, p)
		if len(program.Statements) != 1 {
			t.Fatalf("program.Statements does not contain 1 statements. got=%d",
				len(program.Statements))
		}
		stmt := program.Statements[0]
		returnStmt, ok := stmt.(*ast.ReturnStatement)
		if !ok {
			t.Fatalf("stmt not *ast.ReturnStatement. got=%T", stmt)
		}
		// Stop only when the check fails. Returning on success would end the
		// test after the first case and leave the remaining cases unchecked.
		if !testLiteralExpression(t, returnStmt.ReturnValue, tt.expectedValue) {
			return
		}
	}
}
func TestIdentifierExpression(t *testing.T) {
input := "foobar;"
l := lexer.New(input)
p := New(l)
program := p.ParseProgram()
checkParserErrors(t, p)
if len(program.Statements) != 1 {
t.Fatalf("program has not enough statements. got=%d",
len(program.Statements))
}
stmt, ok := program.Statements[0].(*ast.ExpressionStatement)
if !ok {
t.Fatalf("program.Statements[0] is not ast.ExpressionStatement. got=%T",
program.Statements[0])
}
ident, ok := stmt.Expression.(*ast.Identifier)
if !ok {
t.Fatalf("exp not *ast.Identifier. got=%T", stmt.Expression)
}
if ident.String() != "foobar" {
t.Errorf("ident.String() not %s. got=%s", "foobar",
ident.String())
}
}
func TestIntegerLiteralExpression(t *testing.T) {
input := "5;"
l := lexer.New(input)
p := New(l)
program := p.ParseProgram()
checkParserErrors(t, p)
if len(program.Statements) != 1 {
t.Fatalf("program has not enough statements. got=%d",
len(program.Statements))
}
stmt, ok := program.Statements[0].(*ast.ExpressionStatement)
if !ok {
t.Fatalf("program.Statements[0] is not ast.ExpressionStatement. got=%T",
program.Statements[0])
}
literal, ok := stmt.Expression.(*ast.IntegerLiteral)
if !ok {
t.Fatalf("exp not *ast.IntegerLiteral. got=%T", stmt.Expression)
}
if literal.Value != 5 {
t.Errorf("literal.Value not %d. got=%d", 5, literal.Value)
}
if literal.String() != "5" {
t.Errorf("literal.String() not %s. got=%s", "5",
literal.String())
}
}
func TestParsingPrefixExpressions(t *testing.T) {
prefixTests := []struct {
input string
operator string
value interface{}
}{
{"!5;", "!", 5},
{"-15;", "-", 15},
{"!foobar;", "!", "foobar"},
{"-foobar;", "-", "foobar"},
{"!true;", "!", true},
{"!false;", "!", false},
}
for _, tt := range prefixTests {
l := lexer.New(tt.input)
p := New(l)
program := p.ParseProgram()
checkParserErrors(t, p)
if len(program.Statements) != 1 {
t.Fatalf("program.Statements does not contain %d statements. got=%d\n",
1, len(program.Statements))
}
stmt, ok := program.Statements[0].(*ast.ExpressionStatement)
if !ok {
t.Fatalf("program.Statements[0] is not ast.ExpressionStatement. got=%T",
program.Statements[0])
}
exp, ok := stmt.Expression.(*ast.PrefixExpression)
if !ok {
t.Fatalf("stmt is not ast.PrefixExpression. got=%T", stmt.Expression)
}
if exp.Token.Literal != tt.operator {
t.Fatalf("exp.Token.Literal is not '%s'. got=%s",
tt.operator, exp.Token.Literal)
}
if !testLiteralExpression(t, exp.Right, tt.value) {
return
}
}
}
func TestParsingInfixExpressions(t *testing.T) {
infixTests := []struct {
input string
leftValue interface{}
operator string
rightValue interface{}
}{
{"5 + 5;", 5, "+", 5},
{"5 - 5;", 5, "-", 5},
{"5 * 5;", 5, "*", 5},
{"5 / 5;", 5, "/", 5},
{"5 > 5;", 5, ">", 5},
{"5 < 5;", 5, "<", 5},
{"5 == 5;", 5, "==", 5},
{"5 != 5;", 5, "!=", 5},
{"foobar + barfoo;", "foobar", "+", "barfoo"},
{"foobar - barfoo;", "foobar", "-", "barfoo"},
{"foobar * barfoo;", "foobar", "*", "barfoo"},
{"foobar / barfoo;", "foobar", "/", "barfoo"},
{"foobar > barfoo;", "foobar", ">", "barfoo"},
{"foobar < barfoo;", "foobar", "<", "barfoo"},
{"foobar == barfoo;", "foobar", "==", "barfoo"},
{"foobar != barfoo;", "foobar", "!=", "barfoo"},
{"true == true", true, "==", true},
{"true != false", true, "!=", false},
{"false == false", false, "==", false},
}
for _, tt := range infixTests {
l := lexer.New(tt.input)
p := New(l)
program := p.ParseProgram()
checkParserErrors(t, p)
if len(program.Statements) != 1 {
t.Fatalf("program.Statements does not contain %d statements. got=%d\n",
1, len(program.Statements))
}
stmt, ok := program.Statements[0].(*ast.ExpressionStatement)
if !ok {
t.Fatalf("program.Statements[0] is not ast.ExpressionStatement. got=%T",
program.Statements[0])
}
if !testInfixExpression(t, stmt.Expression, tt.leftValue,
tt.operator, tt.rightValue) {
return
}
}
}
func TestOperatorPrecedenceParsing(t *testing.T) {
tests := []struct {
input string
expected string
}{
{
"-a * b",
"((-a) * b);\n",
},
{
"!-a",
"(!(-a));\n",
},
{
"a + b + c",
"((a + b) + c);\n",
},
{
"a + b - c",
"((a + b) - c);\n",
},
{
"a * b * c",
"((a * b) * c);\n",
},
{
"a * b / c",
"((a * b) / c);\n",
},
{
"a + b / c",
"(a + (b / c));\n",
},
{
"a + b * c + d / e - f",
"(((a + (b * c)) + (d / e)) - f);\n",
},
{
"3 + 4; -5 * 5",
"(3 + 4);\n((-5) * 5);\n",
},
{
"5 > 4 == 3 < 4",
"((5 > 4) == (3 < 4));\n",
},
{
"5 < 4 != 3 > 4",
"((5 < 4) != (3 > 4));\n",
},
{
"3 + 4 * 5 == 3 * 1 + 4 * 5",
"((3 + (4 * 5)) == ((3 * 1) + (4 * 5)));\n",
},
{
"true",
"true;\n",
},
{
"false",
"false;\n",
},
{
"3 > 5 == false",
"((3 > 5) == false);\n",
},
{
"3 < 5 == true",
"((3 < 5) == true);\n",
},
{
"1 + (2 + 3) + 4",
"((1 + (2 + 3)) + 4);\n",
},
{
"(5 + 5) * 2",
"((5 + 5) * 2);\n",
},
{
"2 / (5 + 5)",
"(2 / (5 + 5));\n",
},
{
"(5 + 5) * 2 * (5 + 5)",
"(((5 + 5) * 2) * (5 + 5));\n",
},
{
"-(5 + 5)",
"(-(5 + 5));\n",
},
{
"!(true == true)",
"(!(true == true));\n",
},
{
"a + add(b * c) + d",
"((a + add((b * c))) + d);\n",
},
{
"add(a, b, 1, 2 * 3, 4 + 5, add(6, 7 * 8))",
"add(a, b, 1, (2 * 3), (4 + 5), add(6, (7 * 8)));\n",
},
{
"add(a + b + c * d / f + g)",
"add((((a + b) + ((c * d) / f)) + g));\n",
},
{
"a * [1, 2, 3, 4][b * c] * d",
"((a * ([1, 2, 3, 4][(b * c)])) * d);\n",
},
{
"add(a * b[2], b[1], 2 * [1, 2][1])",
"add((a * (b[2])), (b[1]), (2 * ([1, 2][1])));\n",
},
}
for _, tt := range tests {
l := lexer.New(tt.input)
p := New(l)
program := p.ParseProgram()
checkParserErrors(t, p)
actual := program.String()
if actual != tt.expected {
t.Errorf("expected=%q, got=%q", tt.expected, actual)
}
}
}
func TestBooleanExpression(t *testing.T) {
tests := []struct {
input string
expectedBoolean bool
}{
{"true;", true},
{"false;", false},
}
for _, tt := range tests {
l := lexer.New(tt.input)
p := New(l)
program := p.ParseProgram()
checkParserErrors(t, p)
if len(program.Statements) != 1 {
t.Fatalf("program has not enough statements. got=%d",
len(program.Statements))
}
stmt, ok := program.Statements[0].(*ast.ExpressionStatement)
if !ok {
t.Fatalf("program.Statements[0] is not ast.ExpressionStatement. got=%T",
program.Statements[0])
}
boolean, ok := stmt.Expression.(*ast.Boolean)
if !ok {
t.Fatalf("exp not *ast.Boolean. got=%T", stmt.Expression)
}
if boolean.Value != tt.expectedBoolean {
t.Errorf("boolean.Value not %t. got=%t", tt.expectedBoolean,
boolean.Value)
}
}
}
// TestIfExpression parses an if expression without an else branch and checks
// the condition, the single consequence statement and the absent alternative.
func TestIfExpression(t *testing.T) {
	input := `if (x < y) { x }`
	l := lexer.New(input)
	p := New(l)
	program := p.ParseProgram()
	checkParserErrors(t, p)
	if len(program.Statements) != 1 {
		t.Fatalf("program.Statements does not contain %d statements. got=%d\n",
			1, len(program.Statements))
	}
	stmt, ok := program.Statements[0].(*ast.ExpressionStatement)
	if !ok {
		t.Fatalf("program.Statements[0] is not ast.ExpressionStatement. got=%T",
			program.Statements[0])
	}
	exp, ok := stmt.Expression.(*ast.IfExpression)
	if !ok {
		t.Fatalf("stmt.Expression is not ast.IfExpression. got=%T",
			stmt.Expression)
	}
	if !testInfixExpression(t, exp.Condition, "x", "<", "y") {
		return
	}
	// Fatalf, not Errorf: the next step indexes Statements[0], which would
	// panic on an empty block instead of reporting a test failure.
	if len(exp.Consequence.Statements) != 1 {
		t.Fatalf("consequence is not 1 statements. got=%d\n",
			len(exp.Consequence.Statements))
	}
	consequence, ok := exp.Consequence.Statements[0].(*ast.ExpressionStatement)
	if !ok {
		t.Fatalf("Statements[0] is not ast.ExpressionStatement. got=%T",
			exp.Consequence.Statements[0])
	}
	if !testIdentifier(t, consequence.Expression, "x") {
		return
	}
	if exp.Alternative != nil {
		t.Errorf("exp.Alternative.Statements was not nil. got=%+v", exp.Alternative)
	}
}
// TestIfElseExpression parses an if expression with an else branch and checks
// the condition and the single statement in each branch.
func TestIfElseExpression(t *testing.T) {
	input := `if (x < y) { x } else { y }`
	l := lexer.New(input)
	p := New(l)
	program := p.ParseProgram()
	checkParserErrors(t, p)
	if len(program.Statements) != 1 {
		t.Fatalf("program.Statements does not contain %d statements. got=%d\n",
			1, len(program.Statements))
	}
	stmt, ok := program.Statements[0].(*ast.ExpressionStatement)
	if !ok {
		t.Fatalf("program.Statements[0] is not ast.ExpressionStatement. got=%T",
			program.Statements[0])
	}
	exp, ok := stmt.Expression.(*ast.IfExpression)
	if !ok {
		t.Fatalf("stmt.Expression is not ast.IfExpression. got=%T", stmt.Expression)
	}
	if !testInfixExpression(t, exp.Condition, "x", "<", "y") {
		return
	}
	// Fatalf, not Errorf: the next step indexes Statements[0], which would
	// panic on an empty block instead of reporting a test failure.
	if len(exp.Consequence.Statements) != 1 {
		t.Fatalf("consequence is not 1 statements. got=%d\n",
			len(exp.Consequence.Statements))
	}
	consequence, ok := exp.Consequence.Statements[0].(*ast.ExpressionStatement)
	if !ok {
		t.Fatalf("Statements[0] is not ast.ExpressionStatement. got=%T",
			exp.Consequence.Statements[0])
	}
	if !testIdentifier(t, consequence.Expression, "x") {
		return
	}
	// A missing else branch must be a test failure, not a nil dereference.
	if exp.Alternative == nil {
		t.Fatalf("exp.Alternative was nil")
	}
	if len(exp.Alternative.Statements) != 1 {
		t.Fatalf("exp.Alternative.Statements does not contain 1 statements. got=%d\n",
			len(exp.Alternative.Statements))
	}
	alternative, ok := exp.Alternative.Statements[0].(*ast.ExpressionStatement)
	if !ok {
		t.Fatalf("Statements[0] is not ast.ExpressionStatement. got=%T",
			exp.Alternative.Statements[0])
	}
	if !testIdentifier(t, alternative.Expression, "y") {
		return
	}
}
func TestFunctionLiteralParsing(t *testing.T) {
input := `fn(x, y) { x + y; }`
l := lexer.New(input)
p := New(l)
program := p.ParseProgram()
checkParserErrors(t, p)
if len(program.Statements) != 1 {
t.Fatalf("program.Statements does not contain %d statements. got=%d\n",
1, len(program.Statements))
}
stmt, ok := program.Statements[0].(*ast.ExpressionStatement)
if !ok {
t.Fatalf("program.Statements[0] is not ast.ExpressionStatement. got=%T",
program.Statements[0])
}
function, ok := stmt.Expression.(*ast.FunctionLiteral)
if !ok {
t.Fatalf("stmt.Expression is not ast.FunctionLiteral. got=%T",
stmt.Expression)
}
if len(function.Parameters) != 2 {
t.Fatalf("function literal parameters wrong. want 2, got=%d\n",
len(function.Parameters))
}
testLiteralExpression(t, function.Parameters[0], "x")
testLiteralExpression(t, function.Parameters[1], "y")
if len(function.Body.Statements) != 1 {
t.Fatalf("function.Body.Statements has not 1 statements. got=%d\n",
len(function.Body.Statements))
}
bodyStmt, ok := function.Body.Statements[0].(*ast.ExpressionStatement)
if !ok {
t.Fatalf("function body stmt is not ast.ExpressionStatement. got=%T",
function.Body.Statements[0])
}
testInfixExpression(t, bodyStmt.Expression, "x", "+", "y")
}
// TestFunctionParameterParsing checks that function literals with zero, one
// and several parameters yield the expected parameter identifiers.
//
// Every structural assumption (at least one statement, statement and
// expression types, parameter count) is verified before it is relied on, so a
// parser regression is reported as a test failure for that input instead of a
// panic that aborts the whole test binary.
func TestFunctionParameterParsing(t *testing.T) {
	tests := []struct {
		input          string
		expectedParams []string
	}{
		{input: "fn() {};", expectedParams: []string{}},
		{input: "fn(x) {};", expectedParams: []string{"x"}},
		{input: "fn(x, y, z) {};", expectedParams: []string{"x", "y", "z"}},
	}
	for _, tt := range tests {
		l := lexer.New(tt.input)
		p := New(l)
		program := p.ParseProgram()
		checkParserErrors(t, p)

		if len(program.Statements) == 0 {
			t.Errorf("input %q: program has no statements", tt.input)
			continue
		}
		stmt, ok := program.Statements[0].(*ast.ExpressionStatement)
		if !ok {
			t.Errorf("input %q: program.Statements[0] is not ast.ExpressionStatement. got=%T",
				tt.input, program.Statements[0])
			continue
		}
		function, ok := stmt.Expression.(*ast.FunctionLiteral)
		if !ok {
			t.Errorf("input %q: stmt.Expression is not ast.FunctionLiteral. got=%T",
				tt.input, stmt.Expression)
			continue
		}

		if len(function.Parameters) != len(tt.expectedParams) {
			t.Errorf("length parameters wrong. want %d, got=%d\n",
				len(tt.expectedParams), len(function.Parameters))
			// Skip the per-parameter checks: indexing function.Parameters
			// with the expected indices could run out of range.
			continue
		}
		for i, ident := range tt.expectedParams {
			testLiteralExpression(t, function.Parameters[i], ident)
		}
	}
}
func TestCallExpressionParsing(t *testing.T) {
input := "add(1, 2 * 3, 4 + 5);"
l := lexer.New(input)
p := New(l)
program := p.ParseProgram()
checkParserErrors(t, p)
if len(program.Statements) != 1 {
t.Fatalf("program.Statements does not contain %d statements. got=%d\n",
1, len(program.Statements))
}
stmt, ok := program.Statements[0].(*ast.ExpressionStatement)
if !ok {
t.Fatalf("stmt is not ast.ExpressionStatement. got=%T",
program.Statements[0])
}
exp, ok := stmt.Expression.(*ast.CallExpression)
if !ok {
t.Fatalf("stmt.Expression is not ast.CallExpression. got=%T",
stmt.Expression)
}
if !testIdentifier(t, exp.Function, "add") {
return
}
if len(exp.Arguments) != 3 {
t.Fatalf("wrong length of arguments. got=%d", len(exp.Arguments))
}
testLiteralExpression(t, exp.Arguments[0], 1)
testInfixExpression(t, exp.Arguments[1], 2, "*", 3)
testInfixExpression(t, exp.Arguments[2], 4, "+", 5)
}
func TestCallExpressionParameterParsing(t *testing.T) {
tests := []struct {
input string
expectedIdent string
expectedArgs []string
}{
{
input: "add();",
expectedIdent: "add",
expectedArgs: []string{},
},
{
input: "add(1);",
expectedIdent: "add",
expectedArgs: []string{"1"},
},
{
input: "add(1, 2 * 3, 4 + 5);",
expectedIdent: "add",
expectedArgs: []string{"1", "(2 * 3)", "(4 + 5)"},
},
}
for _, tt := range tests {
l := lexer.New(tt.input)
p := New(l)
program := p.ParseProgram()
checkParserErrors(t, p)
stmt := program.Statements[0].(*ast.ExpressionStatement)
exp, ok := stmt.Expression.(*ast.CallExpression)
if !ok {
t.Fatalf("stmt.Expression is not ast.CallExpression. got=%T",
stmt.Expression)
}
if !testIdentifier(t, exp.Function, tt.expectedIdent) {
return
}
if len(exp.Arguments) != len(tt.expectedArgs) {
t.Fatalf("wrong number of arguments. want=%d, got=%d",
len(tt.expectedArgs), len(exp.Arguments))
}
for i, arg := range tt.expectedArgs {
if exp.Arguments[i].String() != arg {
t.Errorf("argument %d wrong. want=%q, got=%q", i,
arg, exp.Arguments[i].String())
}
}
}
}
func TestStringLiteralExpression(t *testing.T) {
input := `"hello world";`
l := lexer.New(input)
p := New(l)
program := p.ParseProgram()
checkParserErrors(t, p)
stmt := program.Statements[0].(*ast.ExpressionStatement)
literal, ok := stmt.Expression.(*ast.StringLiteral)
if !ok {
t.Fatalf("exp not *ast.StringLiteral. got=%T", stmt.Expression)
}
if literal.Token.Literal != "hello world" {
t.Errorf("literal.Token.Literal not %q. got=%q", "hello world", literal.Token.Literal)
}
}
// TestParsingEmptyArrayLiterals parses `[]` and verifies that the result is an
// *ast.ArrayLiteral with no elements.
//
// The statement count and the statement type are checked before the statement
// is dereferenced, so a wrong parse is reported as a failure rather than a
// nil-pointer panic.
func TestParsingEmptyArrayLiterals(t *testing.T) {
	input := "[]"
	l := lexer.New(input)
	p := New(l)
	program := p.ParseProgram()
	checkParserErrors(t, p)

	if len(program.Statements) == 0 {
		t.Fatalf("program has no statements")
	}
	stmt, ok := program.Statements[0].(*ast.ExpressionStatement)
	if !ok {
		t.Fatalf("program.Statements[0] is not ast.ExpressionStatement. got=%T",
			program.Statements[0])
	}
	array, ok := stmt.Expression.(*ast.ArrayLiteral)
	if !ok {
		t.Fatalf("exp not ast.ArrayLiteral. got=%T", stmt.Expression)
	}
	if len(array.Elements) != 0 {
		t.Errorf("len(array.Elements) not 0. got=%d", len(array.Elements))
	}
}
// TestParsingArrayLiterals parses `[1, 2 * 2, 3 + 3]` and verifies the three
// element expressions of the resulting *ast.ArrayLiteral.
//
// The statement count and the statement type are checked before the statement
// is dereferenced, so a wrong parse is reported as a failure rather than a
// nil-pointer panic.
func TestParsingArrayLiterals(t *testing.T) {
	input := "[1, 2 * 2, 3 + 3]"
	l := lexer.New(input)
	p := New(l)
	program := p.ParseProgram()
	checkParserErrors(t, p)

	if len(program.Statements) == 0 {
		t.Fatalf("program has no statements")
	}
	stmt, ok := program.Statements[0].(*ast.ExpressionStatement)
	if !ok {
		t.Fatalf("program.Statements[0] is not ast.ExpressionStatement. got=%T",
			program.Statements[0])
	}
	array, ok := stmt.Expression.(*ast.ArrayLiteral)
	if !ok {
		t.Fatalf("exp not ast.ArrayLiteral. got=%T", stmt.Expression)
	}
	if len(array.Elements) != 3 {
		t.Fatalf("len(array.Elements) not 3. got=%d", len(array.Elements))
	}
	testIntegerLiteral(t, array.Elements[0], 1)
	testInfixExpression(t, array.Elements[1], 2, "*", 2)
	testInfixExpression(t, array.Elements[2], 3, "+", 3)
}
// TestParsingIndexExpressions parses `myArray[1 + 1]` and verifies that the
// result is an *ast.IndexExpression whose left side is the identifier myArray
// and whose index is the infix expression `1 + 1`.
//
// The statement count and the statement type are checked before the statement
// is dereferenced, so a wrong parse is reported as a failure rather than a
// nil-pointer panic.
func TestParsingIndexExpressions(t *testing.T) {
	input := "myArray[1 + 1]"
	l := lexer.New(input)
	p := New(l)
	program := p.ParseProgram()
	checkParserErrors(t, p)

	if len(program.Statements) == 0 {
		t.Fatalf("program has no statements")
	}
	stmt, ok := program.Statements[0].(*ast.ExpressionStatement)
	if !ok {
		t.Fatalf("program.Statements[0] is not ast.ExpressionStatement. got=%T",
			program.Statements[0])
	}
	indexExp, ok := stmt.Expression.(*ast.IndexExpression)
	if !ok {
		t.Fatalf("exp not *ast.IndexExpression. got=%T", stmt.Expression)
	}
	if !testIdentifier(t, indexExp.Left, "myArray") {
		return
	}
	if !testInfixExpression(t, indexExp.Index, 1, "+", 1) {
		return
	}
}
func TestParsingEmptyHashLiteral(t *testing.T) {
input := "{}"
l := lexer.New(input)
p := New(l)
program := p.ParseProgram()
checkParserErrors(t, p)
stmt := program.Statements[0].(*ast.ExpressionStatement)
hash, ok := stmt.Expression.(*ast.HashLiteral)
if !ok {
t.Fatalf("exp is not ast.HashLiteral. got=%T", stmt.Expression)
}
if len(hash.Pairs) != 0 {
t.Errorf("hash.Pairs has wrong length. got=%d", len(hash.Pairs))
}
}
func TestParsingHashLiteralsStringKeys(t *testing.T) {
input := `{"one": 1, "two": 2, "three": 3}`
l := lexer.New(input)
p := New(l)
program := p.ParseProgram()
checkParserErrors(t, p)
stmt := program.Statements[0].(*ast.ExpressionStatement)
hash, ok := stmt.Expression.(*ast.HashLiteral)
if !ok {
t.Fatalf("exp is not ast.HashLiteral. got=%T", stmt.Expression)
}
expected := map[string]int64{
"one": 1,
"two": 2,
"three": 3,
}
if len(hash.Pairs) != len(expected) {
t.Errorf("hash.Pairs has wrong length. got=%d", len(hash.Pairs))
}
for key, value := range hash.Pairs {
literal, ok := key.(*ast.StringLiteral)
if !ok {
t.Errorf("key is not ast.StringLiteral. got=%T", key)
continue
}
expectedValue := expected[literal.String()]
testIntegerLiteral(t, value, expectedValue)
}
}
func TestParsingHashLiteralsBooleanKeys(t *testing.T) {
input := `{true: 1, false: 2}`
l := lexer.New(input)
p := New(l)
program := p.ParseProgram()
checkParserErrors(t, p)
stmt := program.Statements[0].(*ast.ExpressionStatement)
hash, ok := stmt.Expression.(*ast.HashLiteral)
if !ok {
t.Fatalf("exp is not ast.HashLiteral. got=%T", stmt.Expression)
}
expected := map[string]int64{
"true": 1,
"false": 2,
}
if len(hash.Pairs) != len(expected) {
t.Errorf("hash.Pairs has wrong length. got=%d", len(hash.Pairs))
}
for key, value := range hash.Pairs {
boolean, ok := key.(*ast.Boolean)
if !ok {
t.Errorf("key is not ast.BooleanLiteral. got=%T", key)
continue
}
expectedValue := expected[boolean.String()]
testIntegerLiteral(t, value, expectedValue)
}
}
func TestParsingHashLiteralsIntegerKeys(t *testing.T) {
input := `{1: 1, 2: 2, 3: 3}`
l := lexer.New(input)
p := New(l)
program := p.ParseProgram()
checkParserErrors(t, p)
stmt := program.Statements[0].(*ast.ExpressionStatement)
hash, ok := stmt.Expression.(*ast.HashLiteral)
if !ok {
t.Fatalf("exp is not ast.HashLiteral. got=%T", stmt.Expression)
}
expected := map[string]int64{
"1": 1,
"2": 2,
"3": 3,
}
if len(hash.Pairs) != len(expected) {
t.Errorf("hash.Pairs has wrong length. got=%d", len(hash.Pairs))
}
for key, value := range hash.Pairs {
integer, ok := key.(*ast.IntegerLiteral)
if !ok {
t.Errorf("key is not ast.IntegerLiteral. got=%T", key)
continue
}
expectedValue := expected[integer.String()]
testIntegerLiteral(t, value, expectedValue)
}
}
func TestParsingHashLiteralsWithExpressions(t *testing.T) {
input := `{"one": 0 + 1, "two": 10 - 8, "three": 15 / 5}`
l := lexer.New(input)
p := New(l)
program := p.ParseProgram()
checkParserErrors(t, p)
stmt := program.Statements[0].(*ast.ExpressionStatement)
hash, ok := stmt.Expression.(*ast.HashLiteral)
if !ok {
t.Fatalf("exp is not ast.HashLiteral. got=%T", stmt.Expression)
}
if len(hash.Pairs) != 3 {
t.Errorf("hash.Pairs has wrong length. got=%d", len(hash.Pairs))
}
tests := map[string]func(ast.Expression){
"one": func(e ast.Expression) {
testInfixExpression(t, e, 0, "+", 1)
},
"two": func(e ast.Expression) {
testInfixExpression(t, e, 10, "-", 8)
},
"three": func(e ast.Expression) {
testInfixExpression(t, e, 15, "/", 5)
},
}
for key, value := range hash.Pairs {
literal, ok := key.(*ast.StringLiteral)
if !ok {
t.Errorf("key is not ast.StringLiteral. got=%T", key)
continue
}
testFunc, ok := tests[literal.String()]
if !ok {
t.Errorf("No test function for key %q found", literal.String())
continue
}
testFunc(value)
}
}
func TestMacroLiteralParsing(t *testing.T) {
input := `macro(x, y) { x + y; }`
l := lexer.New(input)
p := New(l)
program := p.ParseProgram()
checkParserErrors(t, p)
if len(program.Statements) != 1 {
t.Fatalf("program.Statements does not contain %d statements. got=%d\n",
1, len(program.Statements))
}
stmt, ok := program.Statements[0].(*ast.ExpressionStatement)
if !ok {
t.Fatalf("program.Statements[0] is not ast.ExpressionStatement. got=%T",
program.Statements[0])
}
macro, ok := stmt.Expression.(*ast.MacroLiteral)
if !ok {
t.Fatalf("stmt.Expression is not ast.MacroLiteral. got=%T",
stmt.Expression)
}
if len(macro.Parameters) != 2 {
t.Fatalf("macro literal parameters wrong. want 2, got=%d\n",
len(macro.Parameters))
}
testLiteralExpression(t, macro.Parameters[0], "x")
testLiteralExpression(t, macro.Parameters[1], "y")
if len(macro.Body.Statements) != 1 {
t.Fatalf("macro.Body.Statements has not 1 statements. got=%d\n",
len(macro.Body.Statements))
}
bodyStmt, ok := macro.Body.Statements[0].(*ast.ExpressionStatement)
if !ok {
t.Fatalf("macro body stmt is not ast.ExpressionStatement. got=%T",
macro.Body.Statements[0])
}
testInfixExpression(t, bodyStmt.Expression, "x", "+", "y")
}
monkey/parser/parser_tracing.go
package parser
import (
"fmt"
"strings"
)
// traceIdentPlaceholder is the text emitted once per indentation step of the
// trace output.
const traceIdentPlaceholder = "\t"

// traceLevel is the current trace nesting depth; incIdent and decIdent adjust it.
var traceLevel int

// identLevel returns the indentation prefix for the current trace depth: one
// placeholder per level beyond the first.
//
// Depths of 1 or less produce an empty prefix. Without that clamp a depth of 0
// (or a negative depth after unbalanced decIdent calls) would pass a negative
// count to strings.Repeat, which panics.
func identLevel() string {
	depth := traceLevel - 1
	if depth < 0 {
		depth = 0
	}
	return strings.Repeat(traceIdentPlaceholder, depth)
}
// tracePrint writes fs to standard output on its own line, prefixed with the
// indentation returned by identLevel.
func tracePrint(fs string) {
	line := identLevel() + fs
	fmt.Println(line)
}
// incIdent increases the trace nesting depth by one.
func incIdent() {
	traceLevel++
}

// decIdent decreases the trace nesting depth by one.
func decIdent() {
	traceLevel--
}
// trace opens a trace section: it deepens the indentation, prints a
// "BEGIN <msg>" line and returns msg unchanged so the result can be passed on
// to untrace.
func trace(msg string) string {
	incIdent()
	begin := "BEGIN " + msg
	tracePrint(begin)
	return msg
}
// untrace closes a trace section: it prints an "END <msg>" line at the current
// indentation and only then reduces the nesting depth.
func untrace(msg string) {
	end := "END " + msg
	tracePrint(end)
	decIdent()
}
monkey/object/object.go
package object
import (
"bytes"
"fmt"
"hash/fnv"
"monkey/ast"
"strings"
)
// Object type tags. Each constant is the string that the Type method of the
// matching object implementation in this file returns.
const (
NULL_OBJ = "NULL"
ERROR_OBJ = "ERROR"
INTEGER_OBJ = "INTEGER"
BOOLEAN_OBJ = "BOOLEAN"
STRING_OBJ = "STRING"
RETURN_VALUE_OBJ = "RETURN_VALUE"
FUNCTION_OBJ = "FUNCTION"
BUILTIN_OBJ = "BUILTIN"
ARRAY_OBJ = "ARRAY"
HASH_OBJ = "HASH"
QUOTE_OBJ = "QUOTE"
MACRO_OBJ = "MACRO"
)
// ObjectType is the string tag an Object reports through its Type method.
type ObjectType string
// HashKey is the comparable key under which a Hashable object is stored in a
// Hash. It pairs the object's type tag with a 64-bit value derived from the
// object's content.
type HashKey struct {
Type ObjectType
Value uint64
}
// Hashable is implemented by objects that can produce a HashKey. In this file
// those are Integer, Boolean and String.
type Hashable interface {
HashKey() HashKey
}
// Object is the interface every runtime value implements.
type Object interface {
// Type returns the object's type tag.
Type() ObjectType
// Inspect returns a printable representation of the object.
Inspect() string
}
// Integer is the runtime object wrapping a 64-bit signed integer.
type Integer struct {
	Value int64
}

// Type reports INTEGER_OBJ.
func (i *Integer) Type() ObjectType {
	return INTEGER_OBJ
}

// Inspect returns the value in decimal notation.
func (i *Integer) Inspect() string {
	return fmt.Sprint(i.Value)
}

// HashKey returns a key made of the integer type tag and the value converted
// to uint64.
func (i *Integer) HashKey() HashKey {
	key := HashKey{Type: i.Type()}
	key.Value = uint64(i.Value)
	return key
}
// Boolean is the runtime object wrapping a bool.
type Boolean struct {
	Value bool
}

// Type reports BOOLEAN_OBJ.
func (b *Boolean) Type() ObjectType {
	return BOOLEAN_OBJ
}

// Inspect returns "true" or "false".
func (b *Boolean) Inspect() string {
	if b.Value {
		return "true"
	}
	return "false"
}

// HashKey returns a key made of the boolean type tag and 1 for true or 0 for
// false.
func (b *Boolean) HashKey() HashKey {
	key := HashKey{Type: b.Type()}
	if b.Value {
		key.Value = 1
	}
	return key
}
// Null is the runtime object representing the absence of a value.
type Null struct{}

// Type reports NULL_OBJ.
func (n *Null) Type() ObjectType {
	return NULL_OBJ
}

// Inspect returns the literal text "null".
func (n *Null) Inspect() string {
	const text = "null"
	return text
}
// ReturnValue wraps the object produced by a return statement.
type ReturnValue struct {
	Value Object
}

// Type reports RETURN_VALUE_OBJ.
func (rv *ReturnValue) Type() ObjectType {
	return RETURN_VALUE_OBJ
}

// Inspect delegates to the wrapped object's Inspect.
func (rv *ReturnValue) Inspect() string {
	wrapped := rv.Value
	return wrapped.Inspect()
}
// Error is the runtime object carrying an evaluation error message.
type Error struct {
	Message string
}

// Type reports ERROR_OBJ.
func (e *Error) Type() ObjectType {
	return ERROR_OBJ
}

// Inspect returns the message prefixed with "ERROR: ".
func (e *Error) Inspect() string {
	const prefix = "ERROR: "
	return prefix + e.Message
}
// Function is the runtime object for a function value: its parameter
// identifiers, its body and an associated environment.
type Function struct {
	Parameters []*ast.Identifier
	Body       *ast.BlockStatement
	Env        *Environment
}

// Type reports FUNCTION_OBJ.
func (f *Function) Type() ObjectType {
	return FUNCTION_OBJ
}

// Inspect renders the function as "fn(<params>) <body>" followed by a newline,
// with the parameters joined by ", ".
func (f *Function) Inspect() string {
	names := make([]string, len(f.Parameters))
	for idx, param := range f.Parameters {
		names[idx] = param.String()
	}

	var out bytes.Buffer
	out.WriteString("fn(" + strings.Join(names, ", ") + ") ")
	out.WriteString(f.Body.String())
	out.WriteString("\n")
	return out.String()
}
// String is the runtime object wrapping a string.
type String struct {
	Value string
}

// Type reports STRING_OBJ.
func (s *String) Type() ObjectType {
	return STRING_OBJ
}

// Inspect returns the raw string value without quotes.
func (s *String) Inspect() string {
	return s.Value
}

// HashKey returns a key made of the string type tag and the 64-bit FNV-1a hash
// of the value's bytes.
func (s *String) HashKey() HashKey {
	hasher := fnv.New64a()
	hasher.Write([]byte(s.Value))
	sum := hasher.Sum64()
	return HashKey{Type: s.Type(), Value: sum}
}
// BuiltinFunction is the signature of a function implemented in Go: it
// receives the argument objects and returns a result object.
type BuiltinFunction func(args ...Object) Object

// Builtin is the runtime object wrapping a BuiltinFunction.
type Builtin struct {
	Fn BuiltinFunction
}

// Type reports BUILTIN_OBJ.
func (b *Builtin) Type() ObjectType {
	return BUILTIN_OBJ
}

// Inspect returns the fixed text "builtin function".
func (b *Builtin) Inspect() string {
	const text = "builtin function"
	return text
}
// Array is the runtime object holding an ordered list of objects.
type Array struct {
	Elements []Object
}

// Type reports ARRAY_OBJ.
func (ao *Array) Type() ObjectType {
	return ARRAY_OBJ
}

// Inspect renders the array as "[e1, e2, ...]" using each element's Inspect
// output.
func (ao *Array) Inspect() string {
	rendered := make([]string, len(ao.Elements))
	for idx := range ao.Elements {
		rendered[idx] = ao.Elements[idx].Inspect()
	}
	return "[" + strings.Join(rendered, ", ") + "]"
}
// HashPair is one entry of a Hash. It keeps the original key object next to
// the value so that both can be rendered by Hash.Inspect.
type HashPair struct {
Key Object
Value Object
}
// Hash is the runtime object for a hash map; entries are indexed by the
// HashKey of their key object.
type Hash struct {
	Pairs map[HashKey]HashPair
}

// Type reports HASH_OBJ.
func (h *Hash) Type() ObjectType {
	return HASH_OBJ
}

// Inspect renders the hash as "{k1: v1, k2: v2, ...}" using the Inspect output
// of each key and value. Entries appear in Go's map iteration order, which is
// not fixed.
func (h *Hash) Inspect() string {
	entries := []string{}
	for _, entry := range h.Pairs {
		entries = append(entries, entry.Key.Inspect()+": "+entry.Value.Inspect())
	}
	return "{" + strings.Join(entries, ", ") + "}"
}
// Quote is the runtime object holding an unevaluated AST node.
type Quote struct {
	Node ast.Node
}

// Type reports QUOTE_OBJ.
func (q *Quote) Type() ObjectType {
	return QUOTE_OBJ
}

// Inspect renders the quoted node as "QUOTE(<node source>)".
func (q *Quote) Inspect() string {
	source := q.Node.String()
	return "QUOTE(" + source + ")"
}
// Macro is the runtime object for a macro value: its parameter identifiers,
// its body and an associated environment.
type Macro struct {
	Parameters []*ast.Identifier
	Body       *ast.BlockStatement
	Env        *Environment
}

// Type reports MACRO_OBJ.
func (m *Macro) Type() ObjectType {
	return MACRO_OBJ
}

// Inspect renders the macro as "macro(<params>) <body>" followed by a newline,
// with the parameters joined by ", ".
func (m *Macro) Inspect() string {
	names := make([]string, len(m.Parameters))
	for idx, param := range m.Parameters {
		names[idx] = param.String()
	}

	var out bytes.Buffer
	out.WriteString("macro(" + strings.Join(names, ", ") + ") ")
	out.WriteString(m.Body.String())
	out.WriteString("\n")
	return out.String()
}
monkey/object/object_test.go
package object
import "testing"
// TestStringHashKey verifies that String objects with equal content share a
// hash key and that objects with different content do not.
func TestStringHashKey(t *testing.T) {
	helloA, helloB := &String{Value: "Hello World"}, &String{Value: "Hello World"}
	nameA, nameB := &String{Value: "My name is johnny"}, &String{Value: "My name is johnny"}

	if helloA.HashKey() != helloB.HashKey() {
		t.Errorf("strings with same content have different hash keys")
	}

	if nameA.HashKey() != nameB.HashKey() {
		t.Errorf("strings with same content have different hash keys")
	}

	if helloA.HashKey() == nameA.HashKey() {
		t.Errorf("strings with different content have same hash keys")
	}
}
// TestBooleanHashKey verifies that Boolean objects with the same value share a
// hash key and that true and false hash differently.
func TestBooleanHashKey(t *testing.T) {
	trueA, trueB := &Boolean{Value: true}, &Boolean{Value: true}
	falseA, falseB := &Boolean{Value: false}, &Boolean{Value: false}

	if trueA.HashKey() != trueB.HashKey() {
		t.Errorf("trues do not have same hash key")
	}

	if falseA.HashKey() != falseB.HashKey() {
		t.Errorf("falses do not have same hash key")
	}

	if trueA.HashKey() == falseA.HashKey() {
		t.Errorf("true has same hash key as false")
	}
}
// TestIntegerHashKey verifies that Integer objects with the same value share a
// hash key and that different values hash differently.
func TestIntegerHashKey(t *testing.T) {
	one1 := &Integer{Value: 1}
	one2 := &Integer{Value: 1}
	two1 := &Integer{Value: 2}
	two2 := &Integer{Value: 2}
	// The failure messages below previously read "twoerent", a leftover of
	// renaming the diff* variables; they now say "different".
	if one1.HashKey() != one2.HashKey() {
		t.Errorf("integers with same content have different hash keys")
	}
	if two1.HashKey() != two2.HashKey() {
		t.Errorf("integers with same content have different hash keys")
	}
	if one1.HashKey() == two1.HashKey() {
		t.Errorf("integers with different content have same hash keys")
	}
}
monkey/object/environment.go
package object
// Environment maps names to objects. An environment may enclose an outer one,
// which Get consults when a name is missing locally.
type Environment struct {
// store holds the bindings defined directly in this environment.
store map[string]Object
// outer is the enclosing environment, or nil for a top-level environment.
outer *Environment
}
// Get looks name up in this environment and then, outward, in each enclosing
// environment. It returns the first binding found and true, or nil and false
// when no environment in the chain defines name.
func (e *Environment) Get(name string) (Object, bool) {
	scope := e
	for {
		if obj, found := scope.store[name]; found {
			return obj, true
		}
		if scope.outer == nil {
			return nil, false
		}
		scope = scope.outer
	}
}
// Set binds name to val in this environment only (outer environments are not
// touched) and returns val.
func (e *Environment) Set(name string, val Object) Object {
	bindings := e.store
	bindings[name] = val
	return val
}
// NewEnvironment returns an empty environment with no enclosing environment.
func NewEnvironment() *Environment {
	env := &Environment{}
	env.store = make(map[string]Object)
	env.outer = nil
	return env
}
// NewEnclosedEnvironment returns an empty environment whose lookups fall back
// to outer.
func NewEnclosedEnvironment(outer *Environment) *Environment {
	inner := NewEnvironment()
	inner.outer = outer
	return inner
}
monkey/evaluator/builtins.go
package evaluator
import (
"fmt"
"monkey/object"
)
// builtins maps the name of each built-in function to its implementation.
// Every implementation validates its arguments and reports problems as error
// objects created with newError.
var builtins = map[string]*object.Builtin{
	// len returns the element count of an array or the byte length of a string.
	"len": {
		Fn: func(args ...object.Object) object.Object {
			if n := len(args); n != 1 {
				return newError("wrong number of arguments. got=%d, want=1",
					n)
			}

			switch value := args[0].(type) {
			case *object.Array:
				return &object.Integer{Value: int64(len(value.Elements))}
			case *object.String:
				return &object.Integer{Value: int64(len(value.Value))}
			}
			return newError("argument to `len` not supported, got %s",
				args[0].Type())
		},
	},

	// puts prints the Inspect form of every argument on its own line and
	// returns NULL.
	"puts": {
		Fn: func(args ...object.Object) object.Object {
			for idx := range args {
				fmt.Println(args[idx].Inspect())
			}
			return NULL
		},
	},

	// first returns the first element of an array, or NULL if it is empty.
	"first": {
		Fn: func(args ...object.Object) object.Object {
			if n := len(args); n != 1 {
				return newError("wrong number of arguments. got=%d, want=1",
					n)
			}

			list, isArray := args[0].(*object.Array)
			if !isArray {
				return newError("argument to `first` must be ARRAY, got %s",
					args[0].Type())
			}

			if len(list.Elements) == 0 {
				return NULL
			}
			return list.Elements[0]
		},
	},

	// last returns the last element of an array, or NULL if it is empty.
	"last": {
		Fn: func(args ...object.Object) object.Object {
			if n := len(args); n != 1 {
				return newError("wrong number of arguments. got=%d, want=1",
					n)
			}

			list, isArray := args[0].(*object.Array)
			if !isArray {
				return newError("argument to `last` must be ARRAY, got %s",
					args[0].Type())
			}

			count := len(list.Elements)
			if count == 0 {
				return NULL
			}
			return list.Elements[count-1]
		},
	},

	// rest returns a new array holding every element but the first, or NULL
	// if the array is empty.
	"rest": {
		Fn: func(args ...object.Object) object.Object {
			if n := len(args); n != 1 {
				return newError("wrong number of arguments. got=%d, want=1",
					n)
			}

			list, isArray := args[0].(*object.Array)
			if !isArray {
				return newError("argument to `rest` must be ARRAY, got %s",
					args[0].Type())
			}

			count := len(list.Elements)
			if count == 0 {
				return NULL
			}

			// Copy so the result does not share storage with the input.
			tail := make([]object.Object, count-1)
			copy(tail, list.Elements[1:])
			return &object.Array{Elements: tail}
		},
	},

	// push returns a new array holding the input's elements followed by the
	// second argument; the input array is left unchanged.
	"push": {
		Fn: func(args ...object.Object) object.Object {
			if n := len(args); n != 2 {
				return newError("wrong number of arguments. got=%d, want=2",
					n)
			}

			list, isArray := args[0].(*object.Array)
			if !isArray {
				return newError("argument to `push` must be ARRAY, got %s",
					args[0].Type())
			}

			count := len(list.Elements)
			extended := make([]object.Object, count+1)
			copy(extended, list.Elements)
			extended[count] = args[1]
			return &object.Array{Elements: extended}
		},
	},
}
monkey/evaluator/evaluator.go
package evaluator
import (
"fmt"
"monkey/ast"
"monkey/object"
)
// Shared singleton objects. The evaluator compares against these by identity
// (see evalBangOperatorExpression and isTruthy), so null and boolean results
// must always be one of these instances.
var (
NULL = &object.Null{}
TRUE = &object.Boolean{Value: true}
FALSE = &object.Boolean{Value: false}
)
// Eval evaluates node in env and returns the resulting object.
//
// Errors are returned as *object.Error values and are propagated as soon as a
// sub-evaluation produces one. Eval returns nil for let statements (after
// storing the binding in env) and for node types it has no case for.
//
// A call whose callee prints as "quote" is special-cased: its single argument
// is handed to quote unevaluated. Any other argument count yields an error
// object instead of indexing past the argument list.
func Eval(node ast.Node, env *object.Environment) object.Object {
	switch node := node.(type) {
	// Statements
	case *ast.Program:
		return evalProgram(node, env)
	case *ast.BlockStatement:
		return evalBlockStatement(node, env)
	case *ast.ExpressionStatement:
		return Eval(node.Expression, env)
	case *ast.ReturnStatement:
		val := Eval(node.ReturnValue, env)
		if isError(val) {
			return val
		}
		return &object.ReturnValue{Value: val}
	case *ast.LetStatement:
		val := Eval(node.Value, env)
		if isError(val) {
			return val
		}
		env.Set(node.Name.Token.Literal, val)
		// A binding yields no value: control leaves the switch and the
		// function returns nil below.

	// Expressions
	case *ast.IntegerLiteral:
		return &object.Integer{Value: node.Value}
	case *ast.StringLiteral:
		return &object.String{Value: node.Token.Literal}
	case *ast.Boolean:
		return nativeBoolToBooleanObject(node.Value)
	case *ast.PrefixExpression:
		right := Eval(node.Right, env)
		if isError(right) {
			return right
		}
		return evalPrefixExpression(node.Token.Literal, right)
	case *ast.InfixExpression:
		left := Eval(node.Left, env)
		if isError(left) {
			return left
		}
		right := Eval(node.Right, env)
		if isError(right) {
			return right
		}
		return evalInfixExpression(node.Token.Literal, left, right)
	case *ast.IfExpression:
		return evalIfExpression(node, env)
	case *ast.Identifier:
		return evalIdentifier(node, env)
	case *ast.FunctionLiteral:
		params := node.Parameters
		body := node.Body
		return &object.Function{Parameters: params, Env: env, Body: body}
	case *ast.CallExpression:
		if node.Function.String() == "quote" {
			// quote needs exactly one argument; without this check
			// `quote()` would index an empty slice and panic.
			if len(node.Arguments) != 1 {
				return newError("wrong number of arguments. got=%d, want=1",
					len(node.Arguments))
			}
			return quote(node.Arguments[0], env)
		}
		// Evaluate the callee to obtain the function object.
		function := Eval(node.Function, env)
		if isError(function) {
			return function
		}
		// Evaluate the argument expressions into objects.
		args := evalExpressions(node.Arguments, env)
		if len(args) == 1 && isError(args[0]) {
			return args[0]
		}
		return applyFunction(function, args)
	case *ast.ArrayLiteral:
		elements := evalExpressions(node.Elements, env)
		if len(elements) == 1 && isError(elements[0]) {
			return elements[0]
		}
		return &object.Array{Elements: elements}
	case *ast.IndexExpression:
		left := Eval(node.Left, env)
		if isError(left) {
			return left
		}
		index := Eval(node.Index, env)
		if isError(index) {
			return index
		}
		return evalIndexExpression(left, index)
	case *ast.HashLiteral:
		return evalHashLiteral(node, env)
	}
	return nil
}
// evalProgram evaluates the program's statements in order and returns the
// result of the last one. A return value stops evaluation and is returned
// unwrapped; an error stops evaluation and is returned as is.
func evalProgram(program *ast.Program, env *object.Environment) object.Object {
	var last object.Object

	for _, stmt := range program.Statements {
		last = Eval(stmt, env)

		if ret, isReturn := last.(*object.ReturnValue); isReturn {
			return ret.Value
		}
		if failure, isErr := last.(*object.Error); isErr {
			return failure
		}
	}

	return last
}
// evalBlockStatement evaluates the block's statements in order and returns the
// result of the last one. A return value or an error stops evaluation and is
// returned still wrapped, so that enclosing blocks and the program can react
// to it.
func evalBlockStatement(block *ast.BlockStatement, env *object.Environment) object.Object {
	var last object.Object

	for _, stmt := range block.Statements {
		last = Eval(stmt, env)
		if last == nil {
			continue
		}

		switch last.Type() {
		case object.RETURN_VALUE_OBJ, object.ERROR_OBJ:
			return last
		}
	}

	return last
}
// evalPrefixExpression applies the prefix operator "!" or "-" to right. Any
// other operator yields an "unknown operator" error.
func evalPrefixExpression(operator string, right object.Object) object.Object {
	if operator == "!" {
		return evalBangOperatorExpression(right)
	}
	if operator == "-" {
		return evalMinusPrefixOperatorExpression(right)
	}
	return newError("unknown operator: %s%s", operator, right.Type())
}
// evalBangOperatorExpression implements logical negation: it returns TRUE for
// FALSE and NULL, and FALSE for every other object.
func evalBangOperatorExpression(right object.Object) object.Object {
	if right == FALSE || right == NULL {
		return TRUE
	}
	return FALSE
}
// evalMinusPrefixOperatorExpression negates an integer operand. Any other
// operand type yields an "unknown operator" error.
func evalMinusPrefixOperatorExpression(right object.Object) object.Object {
	operand, isInteger := right.(*object.Integer)
	if !isInteger {
		return newError("unknown operator: -%s", right.Type())
	}
	return &object.Integer{Value: -operand.Value}
}
// evalInfixExpression applies operator to left and right.
//
// Integer pairs and string pairs are handled by their dedicated helpers. For
// all other operands "==" and "!=" compare object identity; remaining
// combinations produce a "type mismatch" error when the operand types differ
// and an "unknown operator" error otherwise.
func evalInfixExpression(
	operator string,
	left, right object.Object,
) object.Object {
	if left.Type() == object.INTEGER_OBJ && right.Type() == object.INTEGER_OBJ {
		return evalIntegerInfixExpression(operator, left, right)
	}
	if left.Type() == object.STRING_OBJ && right.Type() == object.STRING_OBJ {
		return evalStringInfixExpression(operator, left, right)
	}

	if operator == "==" {
		return nativeBoolToBooleanObject(left == right)
	}
	if operator == "!=" {
		return nativeBoolToBooleanObject(left != right)
	}

	if left.Type() != right.Type() {
		return newError("type mismatch: %s %s %s",
			left.Type(), operator, right.Type())
	}
	return newError("unknown operator: %s %s %s", left.Type(), operator, right.Type())
}
// evalStringInfixExpression concatenates two string objects for "+". Every
// other operator yields an "unknown operator" error.
func evalStringInfixExpression(
	operator string,
	left, right object.Object,
) object.Object {
	if operator == "+" {
		lhs := left.(*object.String)
		rhs := right.(*object.String)
		return &object.String{Value: lhs.Value + rhs.Value}
	}
	return newError("unknown operator: %s %s %s",
		left.Type(), operator, right.Type())
}
// evalIfExpression evaluates the condition and then the matching branch. It
// returns NULL when the condition is not truthy and there is no else branch,
// and propagates an error produced by the condition.
func evalIfExpression(ie *ast.IfExpression, env *object.Environment) object.Object {
	cond := Eval(ie.Condition, env)
	if isError(cond) {
		return cond
	}

	if isTruthy(cond) {
		return Eval(ie.Consequence, env)
	}
	if ie.Alternative == nil {
		return NULL
	}
	return Eval(ie.Alternative, env)
}
// evalIdentifier resolves an identifier: bindings in env take precedence over
// built-in functions of the same name. An unknown name yields an
// "identifier not found" error.
func evalIdentifier(node *ast.Identifier, env *object.Environment) object.Object {
	if val, ok := env.Get(node.Token.Literal); ok {
		return val
	}
	if builtin, ok := builtins[node.Token.Literal]; ok {
		return builtin
	}
	// Pass the name as an argument rather than splicing it into the format
	// string, so a '%' in the name can never be interpreted as a verb.
	return newError("identifier not found: %s", node.Token.Literal)
}
// evalIntegerInfixExpression applies an arithmetic or comparison operator to
// two integer objects. Both operands must be *object.Integer.
//
// Division by zero returns an error object; without the check Go's integer
// division would panic and take the whole interpreter down. Unsupported
// operators yield an "unknown operator" error.
func evalIntegerInfixExpression(
	operator string,
	left, right object.Object,
) object.Object {
	leftVal := left.(*object.Integer).Value
	rightVal := right.(*object.Integer).Value
	switch operator {
	case "+":
		return &object.Integer{Value: leftVal + rightVal}
	case "-":
		return &object.Integer{Value: leftVal - rightVal}
	case "*":
		return &object.Integer{Value: leftVal * rightVal}
	case "/":
		if rightVal == 0 {
			return newError("division by zero")
		}
		return &object.Integer{Value: leftVal / rightVal}
	case "<":
		return nativeBoolToBooleanObject(leftVal < rightVal)
	case ">":
		return nativeBoolToBooleanObject(leftVal > rightVal)
	case "==":
		return nativeBoolToBooleanObject(leftVal == rightVal)
	case "!=":
		return nativeBoolToBooleanObject(leftVal != rightVal)
	default:
		return newError("unknown operator: %s %s %s", left.Type(), operator, right.Type())
	}
}
// nativeBoolToBooleanObject maps a Go bool to the shared TRUE or FALSE object.
func nativeBoolToBooleanObject(input bool) *object.Boolean {
	result := FALSE
	if input {
		result = TRUE
	}
	return result
}
// isTruthy reports whether obj counts as true in a condition: everything
// except the shared NULL and FALSE objects does.
func isTruthy(obj object.Object) bool {
	return obj != NULL && obj != FALSE
}
// newError builds an error object whose message is format expanded with a,
// following fmt.Sprintf rules.
func newError(format string, a ...interface{}) *object.Error {
	message := fmt.Sprintf(format, a...)
	return &object.Error{Message: message}
}
// isError reports whether obj is a non-nil object of type ERROR_OBJ.
func isError(obj object.Object) bool {
	return obj != nil && obj.Type() == object.ERROR_OBJ
}
// evalExpressions evaluates exps from left to right and returns the resulting
// objects. If one evaluation produces an error, evaluation stops and a
// one-element slice holding just that error is returned.
func evalExpressions(exps []ast.Expression, env *object.Environment) []object.Object {
	var values []object.Object

	for idx := range exps {
		value := Eval(exps[idx], env)
		if isError(value) {
			return []object.Object{value}
		}
		values = append(values, value)
	}

	return values
}
// applyFunction calls fn with the already evaluated arguments args.
//
// For a user-defined function the body is evaluated in an environment that
// binds the parameters to args, and a return value is unwrapped before it is
// handed back. A built-in is invoked directly. Any other object yields a
// "not a function" error.
//
// Supplying fewer arguments than a user-defined function has parameters
// returns an error object; without the check extendFunctionEnv would index
// past the end of args and panic. Surplus arguments are ignored, as before.
func applyFunction(fn object.Object, args []object.Object) object.Object {
	switch fn := fn.(type) {
	case *object.Function:
		if len(args) < len(fn.Parameters) {
			return newError("wrong number of arguments. got=%d, want=%d",
				len(args), len(fn.Parameters))
		}
		extendedEnv := extendFunctionEnv(fn, args)
		evaluated := Eval(fn.Body, extendedEnv)
		return unwrapReturnValue(evaluated)
	case *object.Builtin:
		return fn.Fn(args...)
	default:
		return newError("not a function: %s", fn.Type())
	}
}
// extendFunctionEnv creates the environment for one call of fn: a new
// environment enclosed by the function's own environment in which each
// parameter name is bound to the argument at the same position.
func extendFunctionEnv(fn *object.Function, args []object.Object) *object.Environment {
	callEnv := object.NewEnclosedEnvironment(fn.Env)

	for idx := range fn.Parameters {
		name := fn.Parameters[idx].Token.Literal
		callEnv.Set(name, args[idx])
	}

	return callEnv
}
// unwrapReturnValue returns the object wrapped by a *object.ReturnValue and
// returns every other object unchanged.
func unwrapReturnValue(obj object.Object) object.Object {
	switch wrapped := obj.(type) {
	case *object.ReturnValue:
		return wrapped.Value
	default:
		return obj
	}
}
// evalIndexExpression evaluates left[index]. Arrays indexed by an integer and
// hashes are supported; everything else yields an
// "index operator not supported" error.
func evalIndexExpression(left, index object.Object) object.Object {
	if left.Type() == object.ARRAY_OBJ && index.Type() == object.INTEGER_OBJ {
		return evalArrayIndexExpression(left, index)
	}
	if left.Type() == object.HASH_OBJ {
		return evalHashIndexExpression(left, index)
	}
	return newError("index operator not supported: %s", left.Type())
}
// evalArrayIndexExpression returns the array element at the given integer
// index, or NULL when the index is negative or past the last element.
func evalArrayIndexExpression(array, index object.Object) object.Object {
	elements := array.(*object.Array).Elements
	pos := index.(*object.Integer).Value

	if pos < 0 || pos >= int64(len(elements)) {
		return NULL
	}
	return elements[pos]
}
// evalHashLiteral evaluates every key and value expression of a hash literal
// and builds the hash object from them. It returns the first error produced
// while evaluating a key or value, and an "unusable as hash key" error for a
// key object that is not Hashable.
func evalHashLiteral(
	node *ast.HashLiteral,
	env *object.Environment,
) object.Object {
	entries := make(map[object.HashKey]object.HashPair)

	for keyExpr, valueExpr := range node.Pairs {
		keyObj := Eval(keyExpr, env)
		if isError(keyObj) {
			return keyObj
		}

		hashable, canHash := keyObj.(object.Hashable)
		if !canHash {
			return newError("unusable as hash key: %s", keyObj.Type())
		}

		valueObj := Eval(valueExpr, env)
		if isError(valueObj) {
			return valueObj
		}

		entries[hashable.HashKey()] = object.HashPair{Key: keyObj, Value: valueObj}
	}

	return &object.Hash{Pairs: entries}
}
// evalHashIndexExpression looks index up in hash. It returns the stored value,
// NULL when the key is absent, or an "unusable as hash key" error when index
// is not Hashable.
func evalHashIndexExpression(hash, index object.Object) object.Object {
	table := hash.(*object.Hash)

	hashable, canHash := index.(object.Hashable)
	if !canHash {
		return newError("unusable as hash key: %s", index.Type())
	}

	if entry, found := table.Pairs[hashable.HashKey()]; found {
		return entry.Value
	}
	return NULL
}
monkey/evaluator/evaluator_test.go
package evaluator
import (
"monkey/lexer"
"monkey/object"
"monkey/parser"
"testing"
)
// testEval lexes, parses and evaluates input in a fresh environment and
// returns the resulting object. Parser errors are not inspected here.
func testEval(input string) object.Object {
	p := parser.New(lexer.New(input))
	program := p.ParseProgram()
	return Eval(program, object.NewEnvironment())
}
// testIntegerObject reports a test error and returns false unless obj is an
// *object.Integer holding expected.
func testIntegerObject(t *testing.T, obj object.Object, expected int64) bool {
	integer, isInteger := obj.(*object.Integer)
	switch {
	case !isInteger:
		t.Errorf("object is not Integer. got=%T (%+v)", obj, obj)
		return false
	case integer.Value != expected:
		t.Errorf("object has wrong value. got=%d, want=%d",
			integer.Value, expected)
		return false
	}
	return true
}
// testBooleanObject reports a test error and returns false unless obj is an
// *object.Boolean holding expected.
func testBooleanObject(t *testing.T, obj object.Object, expected bool) bool {
	boolean, isBoolean := obj.(*object.Boolean)
	switch {
	case !isBoolean:
		t.Errorf("object is not Boolean. got=%T (%+v)", obj, obj)
		return false
	case boolean.Value != expected:
		t.Errorf("object has wrong value. got=%t, want=%t",
			boolean.Value, expected)
		return false
	}
	return true
}
// testNullObject reports a test error and returns false unless obj is the
// shared NULL object.
func testNullObject(t *testing.T, obj object.Object) bool {
	if obj == NULL {
		return true
	}
	t.Errorf("object is not NULL. got=%T (%+v)", obj, obj)
	return false
}
func TestEvalIntegerExpression(t *testing.T) {
tests := []struct {
input string
expected int64
}{
{"5", 5},
{"10", 10},
{"-5", -5},
{"-10", -10},
{"5 + 5 + 5 + 5 - 10", 10},
{"2 * 2 * 2 * 2 * 2", 32},
{"-50 + 100 + -50", 0},
{"5 * 2 + 10", 20},
{"5 + 2 * 10", 25},
{"20 + 2 * -10", 0},
{"50 / 2 * 2 + 10", 60},
{"2 * (5 + 10)", 30},
{"3 * 3 * 3 + 10", 37},
{"3 * (3 * 3) + 10", 37},
{"(5 + 10 * 2 + 15 / 3) * 2 + -10", 50},
}
for _, tt := range tests {
evaluated := testEval(tt.input)
testIntegerObject(t, evaluated, tt.expected)
}
}
func TestEvalBooleanExpression(t *testing.T) {
tests := []struct {
input string
expected bool
}{
{"true", true},
{"false", false},
{"1 < 2", true},
{"1 > 2", false},
{"1 < 1", false},
{"1 > 1", false},
{"1 == 1", true},
{"1 != 1", false},
{"1 == 2", false},
{"1 != 2", true},
{"true == true", true},
{"false == false", true},
{"true == false", false},
{"true != false", true},
{"false != true", true},
{"(1 < 2) == true", true},
{"(1 < 2) == false", false},
{"(1 > 2) == true", false},
{"(1 > 2) == false", true},
}
for _, tt := range tests {
evaluated := testEval(tt.input)
testBooleanObject(t, evaluated, tt.expected)
}
}
func TestBangOperator(t *testing.T) {
tests := []struct {
input string
expected bool
}{
{"!true", false},
{"!false", true},
{"!5", false},
{"!!true", true},
{"!!false", false},
{"!!5", true},
}
for _, tt := range tests {
evaluated := testEval(tt.input)
testBooleanObject(t, evaluated, tt.expected)
}
}
// TestIfElseExpressions evaluates if and if/else expressions. A case expecting
// an int must produce that integer; a case expecting nil must produce NULL.
func TestIfElseExpressions(t *testing.T) {
	cases := []struct {
		src  string
		want interface{}
	}{
		{"if (true) { 10 }", 10},
		{"if (false) { 10 }", nil},
		{"if (1) { 10 }", 10},
		{"if (1 < 2) { 10 }", 10},
		{"if (1 > 2) { 10 }", nil},
		{"if (1 > 2) { 10 } else { 20 }", 20},
		{"if (1 < 2) { 10 } else { 20 }", 10},
	}

	for _, tc := range cases {
		got := testEval(tc.src)

		switch want := tc.want.(type) {
		case int:
			testIntegerObject(t, got, int64(want))
		default:
			testNullObject(t, got)
		}
	}
}
func TestReturnStatements(t *testing.T) {
tests := []struct {
input string
expected int64
}{
{"return 10;", 10},
{"return 10; 9;", 10},
{"return 2 * 5; 9;", 10},
{"9; return 2 * 5; 9;", 10},
{"if (10 > 1) { return 10; }", 10},
{`if (10 > 1) {
if (10 > 1) {
return 10;
}
return 1;
}`, 10},
{`let f = fn(x) {
return x;
x + 10;
}; f(10);`, 10},
{`let f = fn(x) {
let result = x + 10;
return result;
return 10;
}; f(10);`, 20},
}
for _, tt := range tests {
evaluated := testEval(tt.input)
testIntegerObject(t, evaluated, tt.expected)
}
}
// TestErrorHandling evaluates programs that are expected to produce an
// *object.Error and compares the error's Message with the expected text.
// A case that does not yield an error object is reported and its message
// comparison is skipped.
func TestErrorHandling(t *testing.T) {
	cases := []struct {
		src     string
		wantMsg string
	}{
		{"5 + true;", "type mismatch: INTEGER + BOOLEAN"},
		{"5 + true; 5;", "type mismatch: INTEGER + BOOLEAN"},
		{"-true", "unknown operator: -BOOLEAN"},
		{"true + false;", "unknown operator: BOOLEAN + BOOLEAN"},
		{"true + false + true + false;", "unknown operator: BOOLEAN + BOOLEAN"},
		{"5; true + false; 5", "unknown operator: BOOLEAN + BOOLEAN"},
		{`"Hello" - "World"`, "unknown operator: STRING - STRING"},
		{"if (10 > 1) { true + false; }", "unknown operator: BOOLEAN + BOOLEAN"},
		{`if (10 > 1) {
if (10 > 1) {
return true + false;
}
return 1;
}`, "unknown operator: BOOLEAN + BOOLEAN"},
		{"foobar", "identifier not found: foobar"},
		{`{"name": "Monkey"}[fn(x) { x }];`, "unusable as hash key: FUNCTION"},
		{`999[1]`, "index operator not supported: INTEGER"},
	}
	for _, tc := range cases {
		got := testEval(tc.src)
		errObj, isErr := got.(*object.Error)
		switch {
		case !isErr:
			t.Errorf("no error object returned. got=%T(%+v)",
				got, got)
		case errObj.Message != tc.wantMsg:
			t.Errorf("wrong error message. expected=%q, got=%q",
				tc.wantMsg, errObj.Message)
		}
	}
}
func TestLetStatements(t *testing.T) {
tests := []struct {
input string
expected int64
}{
{"let a = 5; a;", 5},
{"let a = 5 * 5; a;", 25},
{"let a = 5; let b = a; b;", 5},
{"let a = 5; let b = a; let c = a + b + 5; c;", 15},
}
for _, tt := range tests {
testIntegerObject(t, testEval(tt.input), tt.expected)
}
}
// TestFunctionObject evaluates a function literal and verifies that the
// result is an *object.Function with exactly one parameter named "x" and
// the expected string rendering of its body.
func TestFunctionObject(t *testing.T) {
	const (
		source   = "fn(x) { x + 2; };"
		wantBody = "{\n\t(x + 2);\n}"
	)
	result := testEval(source)
	fn, isFn := result.(*object.Function)
	if !isFn {
		t.Fatalf("object is not Function. got=%T (%+v)", result, result)
	}
	if count := len(fn.Parameters); count != 1 {
		t.Fatalf("function has wrong parameters. Parameters=%+v",
			fn.Parameters)
	}
	if first := fn.Parameters[0]; first.String() != "x" {
		t.Fatalf("parameter is not 'x'. got=%q", first)
	}
	if body := fn.Body.String(); body != wantBody {
		t.Fatalf("body is not %q. got=%q", wantBody, body)
	}
}
func TestFunctionApplication(t *testing.T) {
tests := []struct {
input string
expected int64
}{
{"let identity = fn(x) { x; }; identity(5);", 5},
{"let identity = fn(x) { return x; }; identity(5);", 5},
{"let double = fn(x) { x * 2; }; double(5);", 10},
{"let add = fn(x, y) { x + y; }; add(5, 5);", 10},
{"let add = fn(x, y) { x + y; }; add(5 + 5, add(5, 5));", 20},
{"fn(x) { x; }(5)", 5},
}
for _, tt := range tests {
testIntegerObject(t, testEval(tt.input), tt.expected)
}
}
func TestEnclosingEnvironments(t *testing.T) {
input := `
let first = 10;
let second = 10;
let third = 10;
let ourFunction = fn(first) {
let second = 20;
first + second + third;
};
ourFunction(20) + first + second;`
testIntegerObject(t, testEval(input), 70)
}
// TestStringLiteral evaluates a string literal and verifies that the result
// is an *object.String carrying the literal's text.
func TestStringLiteral(t *testing.T) {
	result := testEval(`"Hello World!"`)
	str, isString := result.(*object.String)
	if !isString {
		t.Fatalf("object is not String. got=%T (%+v)", result, result)
	}
	if got := str.Value; got != "Hello World!" {
		t.Errorf("String has wrong value. got=%q", got)
	}
}
// TestStringConcatenation evaluates a chain of `+` on three string literals
// and verifies that the result is an *object.String holding the joined text.
func TestStringConcatenation(t *testing.T) {
	result := testEval(`"Hello" + " " + "World!"`)
	str, isString := result.(*object.String)
	if !isString {
		t.Fatalf("object is not String. got=%T (%+v)", result, result)
	}
	if got := str.Value; got != "Hello World!" {
		t.Errorf("String has wrong value. got=%q", got)
	}
}
// TestBuiltinFunctions evaluates calls to built-in functions and checks the
// result according to the dynamic type of the expected value:
//   - int:    the result must be an integer with that value
//   - nil:    the result must be the null object
//   - string: the result must be an *object.Error with that message
//   - []int:  the result must be an *object.Array with those integer elements
func TestBuiltinFunctions(t *testing.T) {
	cases := []struct {
		src  string
		want interface{}
	}{
		{`len("")`, 0},
		{`len("four")`, 4},
		{`len("hello world")`, 11},
		{`len(1)`, "argument to `len` not supported, got INTEGER"},
		{`len("one", "two")`, "wrong number of arguments. got=2, want=1"},
		{`len([1, 2, 3])`, 3},
		{`len([])`, 0},
		// NOTE(review): the cases below are commented out in the source; the
		// reason is not visible here.
		/*
			{`puts("hello", "world!")`, nil},
			{`first([1, 2, 3])`, 1},
			{`first([])`, nil},
			{`first(1)`, "argument to `first` must be ARRAY, got INTEGER"},
			{`last([1, 2, 3])`, 3},
			{`last([])`, nil},
			{`last(1)`, "argument to `last` must be ARRAY, got INTEGER"},
			{`rest([1, 2, 3])`, []int{2, 3}},
			{`rest([])`, nil},
			{`push([], 1)`, []int{1}},
			{`push(1, 1)`, "argument to `push` must be ARRAY, got INTEGER"},
		*/
	}
	for _, tc := range cases {
		got := testEval(tc.src)
		switch want := tc.want.(type) {
		case nil:
			testNullObject(t, got)
		case int:
			testIntegerObject(t, got, int64(want))
		case string:
			if errObj, isErr := got.(*object.Error); !isErr {
				t.Errorf("object is not Error. got=%T (%+v)",
					got, got)
			} else if errObj.Message != want {
				t.Errorf("wrong error message. expected=%q, got=%q",
					want, errObj.Message)
			}
		case []int:
			array, isArray := got.(*object.Array)
			if !isArray {
				t.Errorf("obj not Array. got=%T (%+v)", got, got)
				continue
			}
			if len(array.Elements) != len(want) {
				t.Errorf("wrong num of elements. want=%d, got=%d",
					len(want), len(array.Elements))
				continue
			}
			for idx := range want {
				testIntegerObject(t, array.Elements[idx], int64(want[idx]))
			}
		}
	}
}
// TestArrayLiterals evaluates an array literal whose elements are
// expressions and verifies that the result is an *object.Array with three
// elements holding the evaluated integers 1, 4 and 6.
func TestArrayLiterals(t *testing.T) {
	result := testEval("[1, 2 * 2, 3 + 3]")
	array, isArray := result.(*object.Array)
	if !isArray {
		t.Fatalf("object is not Array. got=%T (%+v)", result, result)
	}
	wantElems := []int64{1, 4, 6}
	if len(array.Elements) != len(wantElems) {
		t.Fatalf("array has wrong num of elements. got=%d",
			len(array.Elements))
	}
	for idx, want := range wantElems {
		testIntegerObject(t, array.Elements[idx], want)
	}
}
// TestArrayIndexExpressions evaluates array index expressions with literal,
// computed and variable indices. Cases expecting an int are checked with
// testIntegerObject; the out-of-range and negative-index cases expect nil
// and are checked with testNullObject.
func TestArrayIndexExpressions(t *testing.T) {
	cases := []struct {
		src  string
		want interface{}
	}{
		{"[1, 2, 3][0]", 1},
		{"[1, 2, 3][1]", 2},
		{"[1, 2, 3][2]", 3},
		{"let i = 0; [1][i];", 1},
		{"[1, 2, 3][1 + 1];", 3},
		{"let myArray = [1, 2, 3]; myArray[2];", 3},
		{"let myArray = [1, 2, 3]; myArray[0] + myArray[1] + myArray[2];", 6},
		{"let myArray = [1, 2, 3]; let i = myArray[0]; myArray[i]", 2},
		{"[1, 2, 3][3]", nil},
		{"[1, 2, 3][-1]", nil},
	}
	for _, tc := range cases {
		got := testEval(tc.src)
		if want, isInt := tc.want.(int); isInt {
			testIntegerObject(t, got, int64(want))
			continue
		}
		testNullObject(t, got)
	}
}
// TestHashLiterals evaluates a hash literal whose keys are a string literal,
// an identifier bound to a string, a string concatenation, an integer and the
// two booleans, and whose values are arithmetic expressions. It verifies that
// the result is an *object.Hash containing exactly the expected key/value
// pairs.
func TestHashLiterals(t *testing.T) {
	input := `let two = "two";
{
"one": 10 - 9,
two: 1 + 1,
"thr" + "ee": 6 / 2,
4: 4,
true: 5,
false: 6
}`
	evaluated := testEval(input)
	result, ok := evaluated.(*object.Hash)
	if !ok {
		t.Fatalf("Eval didn't return Hash. got=%T (%+v)", evaluated, evaluated)
	}
	// Expected values are keyed by the HashKey of the evaluated key object.
	expected := map[object.HashKey]int64{
		(&object.String{Value: "one"}).HashKey():   1,
		(&object.String{Value: "two"}).HashKey():   2,
		(&object.String{Value: "three"}).HashKey(): 3,
		(&object.Integer{Value: 4}).HashKey():      4,
		TRUE.HashKey():                             5,
		FALSE.HashKey():                            6,
	}
	if len(result.Pairs) != len(expected) {
		t.Fatalf("Hash has wrong num of pairs. got=%d", len(result.Pairs))
	}
	for expectedKey, expectedValue := range expected {
		pair, ok := result.Pairs[expectedKey]
		if !ok {
			t.Errorf("no pair for given key in Pairs")
			// The lookup failed, so pair is only the zero value; checking its
			// Value would just add a second, misleading failure.
			continue
		}
		testIntegerObject(t, pair.Value, expectedValue)
	}
}
// TestHashIndexExpressions evaluates index expressions on hash literals with
// string, integer and boolean keys. Cases expecting an int are checked with
// testIntegerObject; lookups of absent keys expect nil and are checked with
// testNullObject.
func TestHashIndexExpressions(t *testing.T) {
	cases := []struct {
		src  string
		want interface{}
	}{
		{`{"foo": 5}["foo"]`, 5},
		{`{"foo": 5}["bar"]`, nil},
		{`let key = "foo"; {"foo": 5}[key]`, 5},
		{`{}["foo"]`, nil},
		{`{5: 5}[5]`, 5},
		{`{true: 5}[true]`, 5},
		{`{false: 5}[false]`, 5},
	}
	for _, tc := range cases {
		got := testEval(tc.src)
		switch want := tc.want.(type) {
		case int:
			testIntegerObject(t, got, int64(want))
		default:
			testNullObject(t, got)
		}
	}
}
monkey/evaluator/quote_unquote.go
package evaluator
import (
"monkey/ast"
"monkey/object"
"monkey/token"
"strconv"
)
// quote wraps node in an *object.Quote without evaluating it, after first
// letting evalUnquoteCalls process the node with env.
func quote(node ast.Node, env *object.Environment) object.Object {
	return &object.Quote{Node: evalUnquoteCalls(node, env)}
}
// evalUnquoteCalls runs ast.Modify over quoted and replaces every node that
// isUnquoteCall accepts with the AST form of its evaluated argument: the
// single argument is evaluated in env and the resulting object is converted
// back to a node by convertObjectToASTNode. All other nodes are returned
// unchanged.
func evalUnquoteCalls(quoted ast.Node, env *object.Environment) ast.Node {
	replaceUnquote := func(n ast.Node) ast.Node {
		if !isUnquoteCall(n) {
			return n
		}
		// isUnquoteCall has already established that n is a call expression
		// with exactly one argument.
		call := n.(*ast.CallExpression)
		result := Eval(call.Arguments[0], env)
		return convertObjectToASTNode(result)
	}
	return ast.Modify(quoted, replaceUnquote)
}
// isUnquoteCall reports whether node is a call expression whose function
// renders as "unquote" and which has exactly one argument.
func isUnquoteCall(node ast.Node) bool {
	call, isCall := node.(*ast.CallExpression)
	if !isCall {
		return false
	}
	callee := call.Function.String()
	return callee == "unquote" && len(call.Arguments) == 1
}
// convertObjectToASTNode turns an evaluated object back into an AST node so
// that it can be spliced into quoted code.
//
//   - *object.Integer becomes an *ast.IntegerLiteral with an INT token
//   - *object.Boolean becomes an *ast.Boolean with a TRUE or FALSE token
//   - *object.Quote yields the node it wraps
//
// Any other object type yields nil.
func convertObjectToASTNode(obj object.Object) ast.Node {
	switch obj := obj.(type) {
	case *object.Integer:
		// Format the value as int64: going through int (as strconv.Itoa
		// requires) would truncate large values where int is 32 bits wide.
		literal := strconv.FormatInt(int64(obj.Value), 10)
		return &ast.IntegerLiteral{
			Token: token.Token{Type: token.INT, Literal: literal},
			Value: obj.Value,
		}
	case *object.Boolean:
		tok := token.Token{Type: token.FALSE, Literal: "false"}
		if obj.Value {
			tok = token.Token{Type: token.TRUE, Literal: "true"}
		}
		return &ast.Boolean{Token: tok, Value: obj.Value}
	case *object.Quote:
		return obj.Node
	default:
		return nil
	}
}
monkey/evaluator/quote_unquote_test.go
package evaluator
import (
"monkey/object"
"testing"
)
// TestQuoteUnquote evaluates quote(...) expressions, some containing
// unquote(...) calls, and compares the string rendering of the quoted node
// with the expected source text. Every case must evaluate to an
// *object.Quote with a non-nil Node.
func TestQuoteUnquote(t *testing.T) {
	cases := []struct {
		src  string
		want string
	}{
		{`quote(5)`, `5`},
		{`quote(5+8)`, `(5 + 8)`},
		{`quote(foobar)`, `foobar`},
		{`quote(foobar+barfoo)`, `(foobar + barfoo)`},
		{`let foobar=8;
quote(foobar)`, `foobar`},
		{`let foobar=8;
quote(unquote(foobar))`, `8`},
		{`quote(unquote(true))`, `true`},
		{`quote(unquote(true==false))`, `false`},
		{`quote(quote(123))`, `quote(123)`},
		{`quote(unquote(quote(4+4)))`, `(4 + 4)`},
	}
	for _, tc := range cases {
		result := testEval(tc.src)
		quoted, isQuote := result.(*object.Quote)
		if !isQuote {
			t.Fatalf("expected *object.Quote. got=%T (%+v)", result, result)
		}
		if quoted.Node == nil {
			t.Fatalf("quote.Node is nil")
		}
		if got := quoted.Node.String(); got != tc.want {
			t.Errorf("not equal. got=%q, want=%q",
				got, tc.want)
		}
	}
}
monkey/evaluator/macro_expansion.go
package evaluator
import (
"monkey/ast"
"monkey/object"
)
// DefineMacros scans the top-level statements of program. Each statement
// that isMacroDefinition accepts is registered in env via addMacro and
// removed from program.Statements; all other statements are kept in their
// original order.
func DefineMacros(program *ast.Program, env *object.Environment) {
	// Filter in place: the write position never runs ahead of the read
	// position, so no statement is overwritten before it has been visited.
	kept := program.Statements[:0]
	for _, stmt := range program.Statements {
		if isMacroDefinition(stmt) {
			addMacro(stmt, env)
			continue
		}
		kept = append(kept, stmt)
	}
	program.Statements = kept
}
// isMacroDefinition reports whether node is a let statement whose value is
// a macro literal.
func isMacroDefinition(node ast.Statement) bool {
	let, isLet := node.(*ast.LetStatement)
	if !isLet {
		return false
	}
	_, isMacro := let.Value.(*ast.MacroLiteral)
	return isMacro
}
// addMacro builds an *object.Macro from the macro literal bound by stmt and
// stores it in env under the let statement's name. The macro keeps env as
// its own environment. The type assertions are unchecked; in the visible
// code the only caller, DefineMacros, passes statements that
// isMacroDefinition has accepted.
func addMacro(stmt ast.Statement, env *object.Environment) {
	let, _ := stmt.(*ast.LetStatement)
	literal, _ := let.Value.(*ast.MacroLiteral)

	definition := new(object.Macro)
	definition.Parameters = literal.Parameters
	definition.Env = env
	definition.Body = literal.Body

	name := let.Name.String()
	env.Set(name, definition)
}
// ExpandMacros runs ast.Modify over program and replaces every call
// expression that isMacroCall resolves to a macro in env with the node the
// macro produces. The call's arguments are passed to the macro body as
// quoted, unevaluated nodes. Nodes that are not macro calls are returned
// unchanged.
//
// It panics if a macro body evaluates to anything other than an
// *object.Quote.
func ExpandMacros(program ast.Node, env *object.Environment) ast.Node {
	expand := func(node ast.Node) ast.Node {
		call, isCall := node.(*ast.CallExpression)
		if !isCall {
			return node
		}
		macro, isMacro := isMacroCall(call, env)
		if !isMacro {
			return node
		}
		bodyEnv := extendMacroEnv(macro, quoteArgs(call))
		result := Eval(macro.Body, bodyEnv)
		quoted, isQuote := result.(*object.Quote)
		if !isQuote {
			panic("we only support returning AST-nodes from macros")
		}
		return quoted.Node
	}
	return ast.Modify(program, expand)
}
// isMacroCall looks up the function of exp in env. It returns the macro and
// true when the function is an identifier that is bound in env to an
// *object.Macro; otherwise it returns nil and false.
func isMacroCall(exp *ast.CallExpression, env *object.Environment) (*object.Macro, bool) {
	ident, isIdent := exp.Function.(*ast.Identifier)
	if !isIdent {
		return nil, false
	}
	bound, found := env.Get(ident.String())
	if !found {
		return nil, false
	}
	// A failed assertion leaves macro nil, matching the not-a-macro result.
	macro, isMacro := bound.(*object.Macro)
	return macro, isMacro
}
// quoteArgs wraps each argument node of exp in an *object.Quote, preserving
// argument order. The arguments are not evaluated.
func quoteArgs(exp *ast.CallExpression) []*object.Quote {
	quoted := make([]*object.Quote, len(exp.Arguments))
	for idx := range exp.Arguments {
		quoted[idx] = &object.Quote{Node: exp.Arguments[idx]}
	}
	return quoted
}
// extendMacroEnv creates an environment enclosed by the macro's own
// environment and binds each macro parameter name to the quoted argument at
// the same position.
//
// If fewer arguments than parameters are supplied, only the leading
// parameters are bound; the remaining ones are left unbound instead of
// indexing past the end of args (which would panic). Surplus arguments are
// ignored.
func extendMacroEnv(macro *object.Macro, args []*object.Quote) *object.Environment {
	extended := object.NewEnclosedEnvironment(macro.Env)
	for paramIdx, param := range macro.Parameters {
		if paramIdx >= len(args) {
			// No argument left for this or any later parameter.
			break
		}
		extended.Set(param.String(), args[paramIdx])
	}
	return extended
}
monkey/evaluator/macro_expansion_test.go
package evaluator
import (
"monkey/ast"
"monkey/lexer"
"monkey/object"
"monkey/parser"
"testing"
)
// testParseProgram lexes and parses input and returns the resulting program.
// Parser errors are not inspected.
func testParseProgram(input string) *ast.Program {
	return parser.New(lexer.New(input)).ParseProgram()
}
// TestDefineMacros runs DefineMacros on a program with one integer binding,
// one function binding and one macro binding. It verifies that only the
// macro definition is removed from the program, that neither `number` nor
// `function` ends up in the environment, and that `mymacro` is stored as an
// *object.Macro with parameters x and y and the expected body.
func TestDefineMacros(t *testing.T) {
	input := `let number=1;
let function=fn(x,y){x+y};
let mymacro=macro(x,y){x+y;};`
	env := object.NewEnvironment()
	program := testParseProgram(input)
	DefineMacros(program, env)
	// The macro definition must have been stripped from the program.
	if len(program.Statements) != 2 {
		t.Fatalf("Wrong number of statements. got=%d",
			len(program.Statements))
	}
	// DefineMacros only registers macros; it must not evaluate other lets.
	_, ok := env.Get("number")
	if ok {
		t.Fatalf("number should not be defined")
	}
	_, ok = env.Get("function")
	if ok {
		t.Fatalf("function should not be defined")
	}
	obj, ok := env.Get("mymacro")
	if !ok {
		t.Fatalf("macro not in environment.")
	}
	macro, ok := obj.(*object.Macro)
	if !ok {
		// The name is bound, but to the wrong kind of object; report that
		// rather than repeating the "not in environment" message.
		t.Fatalf("object is not Macro. got=%T (%+v)", obj, obj)
	}
	if len(macro.Parameters) != 2 {
		t.Fatalf("Wrong number of macro parameters. got=%d",
			len(macro.Parameters))
	}
	if macro.Parameters[0].String() != "x" {
		t.Fatalf("parameter is not 'x'. got=%q",
			macro.Parameters[0])
	}
	if macro.Parameters[1].String() != "y" {
		t.Fatalf("parameter is not 'y'. got=%q",
			macro.Parameters[1])
	}
	expectedBody := "{\n\t(x + y);\n}"
	if macro.Body.String() != expectedBody {
		t.Fatalf("body is not %q. got=%q",
			expectedBody, macro.Body.String())
	}
}
// TestExpandMacros defines macros from each input program, expands the
// macro calls, and compares the string rendering of the expanded program
// with the rendering of the separately parsed expected source.
func TestExpandMacros(t *testing.T) {
	cases := []struct {
		src  string
		want string
	}{
		{
			`let infixExpression=macro(){quote(1+2);};
infixExpression()`,
			`(1 + 2)`,
		},
		{
			`let reverse=macro(a,b){quote(unquote(b) - unquote(a));};
reverse(2+2,10-5);`,
			`(10 - 5) - (2 + 2)`,
		},
		{
			// NOTE(review): despite its name, this macro body does not
			// negate the condition; the expected output matches that.
			`let unless = macro(condition, consequence, alternative) {
quote(if(unquote(condition)) {
unquote(consequence);
} else {
unquote(alternative);
});
};
unless(10>5, puts("greater"), puts("not greater"));`,
			`if(10 > 5) { puts("greater") } else { puts("not greater") }`,
		},
	}
	for _, tc := range cases {
		wantProgram := testParseProgram(tc.want)
		program := testParseProgram(tc.src)

		macroEnv := object.NewEnvironment()
		DefineMacros(program, macroEnv)
		got := ExpandMacros(program, macroEnv)

		if got.String() == wantProgram.String() {
			continue
		}
		t.Errorf("not equal. want=%q, got=%q", wantProgram.String(), got.String())
	}
}
monkey/repl/repl.go
package repl
import (
"bufio"
"fmt"
"io"
"monkey/evaluator"
"monkey/lexer"
"monkey/object"
"monkey/parser"
)
// PROMPT is the text Start writes to the output before reading each line.
const PROMPT = ">>"
// Start runs the read-eval-print loop: it writes PROMPT to out, reads one
// line from in, parses it, and either prints the parser errors or defines
// and expands macros and evaluates the result, printing the Inspect() text
// of any non-nil value. It returns when no further line can be read from in.
//
// Two environments live for the whole session: one for ordinary evaluation
// and a separate one holding macro definitions.
func Start(in io.Reader, out io.Writer) {
	lines := bufio.NewScanner(in)
	env, macroEnv := object.NewEnvironment(), object.NewEnvironment()
	for {
		fmt.Fprint(out, PROMPT)
		if !lines.Scan() {
			return
		}
		p := parser.New(lexer.New(lines.Text()))
		program := p.ParseProgram()
		if errs := p.Errors(); len(errs) != 0 {
			printParserErrors(out, errs)
			continue
		}
		// Macro definitions are stripped from the program before expansion.
		evaluator.DefineMacros(program, macroEnv)
		expanded := evaluator.ExpandMacros(program, macroEnv)
		result := evaluator.Eval(expanded, env)
		if result == nil {
			continue
		}
		io.WriteString(out, result.Inspect()+"\n")
	}
}
// MONKEY_FACE is the ASCII-art banner that printParserErrors writes ahead of
// the parser error list.
// NOTE(review): the leading whitespace of the art looks collapsed in this
// listing; confirm against the original source before relying on its layout.
const MONKEY_FACE = ` __,__
.--. .-" "-. .--.
/ .. \/ .-. .-. \/ .. \
| | '| / Y \ |' | |
| \ \ \ 0 | 0 / / / |
\ '- ,\.-"""""""-./, -' /
''-' /_ ^ ^ _\ '-''
| \._ _./ |
\ \ '~' / /
'._ '-=-' _.'
'-----'
`
// printParserErrors writes the MONKEY_FACE banner and two header lines to
// out, followed by each parser error message on its own tab-indented line.
// Write errors are ignored.
func printParserErrors(out io.Writer, errors []string) {
	header := [...]string{
		MONKEY_FACE,
		"Woops! We ran into some monkey business here!\n",
		"parser errors:\n",
	}
	for _, part := range header {
		io.WriteString(out, part)
	}
	for idx := range errors {
		io.WriteString(out, "\t"+errors[idx]+"\n")
	}
}
monkey/main.go
package main
import (
"fmt"
"monkey/repl"
"os"
"os/user"
)
// main greets the current OS user by username and starts the Monkey REPL on
// standard input and output. It panics if the current user cannot be
// determined.
func main() {
	current, err := user.Current()
	if err != nil {
		panic(err)
	}
	fmt.Printf("Hello %s! This is the Monkey programming language!\n", current.Username)
	fmt.Print("Feel free to type in commands\n")
	repl.Start(os.Stdin, os.Stdout)
}
monkey/run.txt
cd monkey
go mod init monkey
go test .\lexer -count=1
go test .\ast -count=1
go test .\parser -count=1
go test .\object -count=1
go test .\evaluator -count=1
go run main.go