[手写系列]Go手写db
ZiyiDB是一个简单的内存数据库实现,支持基本的SQL操作,包含create、insert、delete、select、update、drop。目前一期暂支持int类型以及字符类型数据,后续会支持更多数据结构以及能力。本项目基于https://github.com/eatonphil/gosql灵感进行开发。
- 项目Github地址:
https://github.com/ziyifast/ZiyiDB
请大家多多支持,也欢迎大家star⭐️和共同维护这个项目~
项目结构
// 项目创建
mkdir ZiyiDB
cd ZiyiDB/
go mod init ziyi.db.com
ZiyiDB/
├── cmd/
│ └── main.go # 主程序入口
├── internal/
│ ├── ast/
│ │ └── ast.go # 抽象语法树定义
│ ├── lexer/
│ │ ├── lexer.go # 词法分析器实现
│ │ └── token.go # 词法单元定义
│ ├── parser/
│ │ └── parser.go # 语法分析器实现
│ └── storage/
│ └── memory.go # 内存存储引擎实现
├── go.mod # Go模块定义
└── go.sum # 依赖版本锁定
原理介绍
流程图:
主要包含几大模块:
- cmd/main.go:
- 程序入口点
- 实现交互式命令行界面
- 处理用户输入
- 显示执行结果
- internal/ast/ast.go:
- 定义抽象语法树节点
- 定义 SQL 语句结构
- 定义表达式结构
- internal/lexer/token.go:
- 定义词法单元类型
- 定义 SQL 关键字
- 定义运算符和分隔符
- internal/lexer/lexer.go:
- 实现词法分析器
- 将输入文本转换为标记序列
- 处理标识符和字面量
- internal/parser/parser.go:
- 实现语法分析器
- 将标记序列转换为抽象语法树
- 处理各种 SQL 语句
- internal/storage/memory.go:
- 实现内存存储引擎
- 处理数据存储和检索
- 实现索引和约束
具体实现
模块一: 词法分析器 (Lexer)
词法分析器 (Lexer):SQL转token序列
①定义标记类型token.go
思路
新建ziyi-db/internal/lexer/token.go文件,完成词法分析器(Lexer)中的标记(Token)定义部分,用于将 SQL 语句分解成基本的语法单元。
定义词法单元以及关键字:
- 包含常见的SQL关键字,如:select、update等
- 包含符号关键字:=、>、<
- 包含字段类型:INT、字符型(TEXT)
- 包含标识符:INDENT,解析出来的SQL列名、表名
type TokenType string
const (
SELECT TokenType = "SELECT"
FROM TokenType = "FROM"
IDENT TokenType = "IDENT" // 标识符(如列名、表名)
INT_LIT TokenType = "INT" // 整数字面量
STRING TokenType = "STRING" // 字符串字面量
EQ TokenType = "=" // 等于
GT TokenType = ">" // 大于
LT TokenType = "<" // 小于
....
)
// Token 词法单元
// Type:标记的类型(如 SELECT、IDENT 等)
// Literal:标记的实际值(如具体的列名、数字等)
type Token struct {
Type TokenType // 标记类型
Literal string // 标记的实际值
}
示例:
SELECT id, name FROM users WHERE age > 18;
该SQL 语句会被下面的词法分析器lexer.go分解成以下标记序列:
{Type: SELECT, Literal: "SELECT"}
{Type: IDENT, Literal: "id"}
{Type: COMMA, Literal: ","}
{Type: IDENT, Literal: "name"}
{Type: FROM, Literal: "FROM"}
{Type: IDENT, Literal: "users"}
{Type: WHERE, Literal: "WHERE"}
{Type: IDENT, Literal: "age"}
{Type: GT, Literal: ">"}
{Type: INT_LIT, Literal: "18"}
{Type: SEMI, Literal: ";"}
解析后的标记随后会被传递给语法分析器(Parser)进行进一步处理,构建抽象语法树(AST)。
全部代码
// internal/lexer/token.go
package lexer
// TokenType 表示词法单元类型
type TokenType string
const (
// 特殊标记
EOF TokenType = "EOF" // 文件结束标记
ERROR TokenType = "ERROR" // 错误标记
// 关键字
SELECT TokenType = "SELECT"
FROM TokenType = "FROM"
WHERE TokenType = "WHERE"
CREATE TokenType = "CREATE"
TABLE TokenType = "TABLE"
INSERT TokenType = "INSERT"
INTO TokenType = "INTO"
VALUES TokenType = "VALUES"
UPDATE TokenType = "UPDATE"
SET TokenType = "SET"
DELETE TokenType = "DELETE"
DROP TokenType = "DROP"
PRIMARY TokenType = "PRIMARY"
KEY TokenType = "KEY"
INT TokenType = "INT"
TEXT TokenType = "TEXT"
LIKE TokenType = "LIKE"
// 标识符和字面量
IDENT TokenType = "IDENT" // 标识符(如列名、表名)
INT_LIT TokenType = "INT" // 整数字面量
STRING TokenType = "STRING" // 字符串字面量
// 运算符
EQ TokenType = "="
GT TokenType = ">"
LT TokenType = "<"
// 标识符
COMMA TokenType = ","
SEMI TokenType = ";"
LPAREN TokenType = "("
RPAREN TokenType = ")"
ASTERISK TokenType = "*"
)
// Token 词法单元
// Type:标记的类型(如 SELECT、IDENT 等)
// Literal:标记的实际值(如具体的列名、数字等)
type Token struct {
Type TokenType // 标记类型
Literal string // 标记的实际值
}
② 实现词法分析器lexer.go
思路
新建ziyi-db/internal/lexer/lexer.go文件,这是词法分析器(Lexer)的核心实现,负责将输入的 SQL 语句分解成标记(Token)序列。
词法分析器lexer.go:读取SQL到内存中并进行解析,将字符转换为对应关键字
示例:
SELECT id, name FROM users WHERE age > 18;
处理过程:
跳过空白字符
读取 "SELECT" 并识别为关键字
读取 "id" 并识别为标识符
读取 "," 并识别为分隔符
读取 "name" 并识别为标识符
读取 "FROM" 并识别为关键字
读取 "users" 并识别为标识符
读取 "WHERE" 并识别为关键字
读取 "age" 并识别为标识符
读取 ">" 并识别为运算符
读取 "18" 并识别为数字
读取 ";" 并识别为分隔符
这个词法分析器是 SQL 解析器的第一步,它将输入的 SQL 语句分解成标记序列,为后续的语法分析提供基础
该SQL 语句会被词法分析器分解成以下标记序列:
{Type: SELECT, Literal: "SELECT"}
{Type: IDENT, Literal: "id"}
{Type: COMMA, Literal: ","}
{Type: IDENT, Literal: "name"}
{Type: FROM, Literal: "FROM"}
{Type: IDENT, Literal: "users"}
{Type: WHERE, Literal: "WHERE"}
{Type: IDENT, Literal: "age"}
{Type: GT, Literal: ">"}
{Type: INT_LIT, Literal: "18"}
{Type: SEMI, Literal: ";"}
解析后的标记随后会被传递给语法分析器(Parser)进行进一步处理,构建抽象语法树(AST)。
全部代码
// internal/lexer/lexer.go
package lexer
import (
"bufio"
"bytes"
"io"
"strings"
"unicode"
)
// Lexer 词法分析器
// reader:使用 bufio.Reader 进行高效的字符读取
// ch:存储当前正在处理的字符
type Lexer struct {
reader *bufio.Reader // 用于读取输入
ch rune // 当前字符
}
// NewLexer 创建一个新的 词法分析器
// 初始化 reader 并读取第一个字符
func NewLexer(r io.Reader) *Lexer {
l := &Lexer{
reader: bufio.NewReader(r),
}
l.readChar()
return l
}
// 读取字符
func (l *Lexer) readChar() {
ch, _, err := l.reader.ReadRune()
if err != nil {
l.ch = 0 // 遇到错误或EOF时设置为0
} else {
l.ch = ch
}
}
// NextToken 获取下一个词法单元
// 识别并返回下一个标记
// 处理各种类型的标记:运算符、分隔符、标识符、数字、字符串等
func (l *Lexer) NextToken() Token {
var tok Token
// 跳过空白字符
l.skipWhitespace()
switch l.ch {
case '=':
tok = Token{Type: EQ, Literal: "="}
case '>':
tok = Token{Type: GT, Literal: ">"}
case '<':
tok = Token{Type: LT, Literal: "<"}
case ',':
tok = Token{Type: COMMA, Literal: ","}
case ';':
tok = Token{Type: SEMI, Literal: ";"}
case '(':
tok = Token{Type: LPAREN, Literal: "("}
case ')':
tok = Token{Type: RPAREN, Literal: ")"}
case '*':
tok = Token{Type: ASTERISK, Literal: "*"}
case '\'':
tok.Type = STRING
// 读取字符串字面量
tok.Literal = l.readString()
return tok
case 0:
tok = Token{Type: EOF, Literal: ""}
default:
if isLetter(l.ch) {
// 读取标识符(表名、列名等)
tok.Literal = l.readIdentifier()
// 将读取到的标识符转换为对应的标记类型(转换为对应tokenType)
tok.Type = l.lookupIdentifier(tok.Literal)
return tok
} else if isDigit(l.ch) {
tok.Type = INT_LIT
// 读取数字
tok.Literal = l.readNumber()
return tok
} else {
tok = Token{Type: ERROR, Literal: string(l.ch)}
}
}
l.readChar()
return tok
}
func (l *Lexer) skipWhitespace() {
for unicode.IsSpace(l.ch) {
l.readChar()
}
}
// 读取标识符,如:列名、表名
func (l *Lexer) readIdentifier() string {
var ident bytes.Buffer
for isLetter(l.ch) || isDigit(l.ch) {
ident.WriteRune(l.ch)
l.readChar()
}
return ident.String()
}
func (l *Lexer) readNumber() string {
var num bytes.Buffer
for isDigit(l.ch) {
num.WriteRune(l.ch)
l.readChar()
}
return num.String()
}
// 读取字符串字面量
func (l *Lexer) readString() string {
var str bytes.Buffer
l.readChar() // 跳过开始的引号
for l.ch != '\'' && l.ch != 0 {
str.WriteRune(l.ch)
l.readChar()
}
l.readChar() // 跳过结束的引号
return str.String()
}
func (l *Lexer) peekChar() rune {
ch, _, err := l.reader.ReadRune()
if err != nil {
return 0
}
l.reader.UnreadRune()
return ch
}
// lookupIdentifier 查找标识符类型
// 将标识符转换为对应的标记类型
// 识别 SQL 关键字
func (l *Lexer) lookupIdentifier(ident string) TokenType {
switch strings.ToUpper(ident) {
case "SELECT":
return SELECT
case "FROM":
return FROM
case "WHERE":
return WHERE
case "CREATE":
return CREATE
case "TABLE":
return TABLE
case "INSERT":
return INSERT
case "INTO":
return INTO
case "VALUES":
return VALUES
case "UPDATE":
return UPDATE
case "SET":
return SET
case "DELETE":
return DELETE
case "DROP":
return DROP
case "PRIMARY":
return PRIMARY
case "KEY":
return KEY
case "INT":
return INT
case "TEXT":
return TEXT
case "LIKE":
return LIKE
default:
return IDENT
}
}
// 判断字符是否为字母或下划线
func isLetter(ch rune) bool {
return unicode.IsLetter(ch) || ch == '_'
}
// 判断字符是否为数字
func isDigit(ch rune) bool {
return unicode.IsDigit(ch)
}
模块二:抽象语法树 (AST)
思路
抽象语法树用于表示 SQL 语句的语法结构。我们需要为每种 SQL 语句定义相应的节点类型。
我们新建internal/ast/ast.go。
ast.go构建不同SQL语句的结构,以及查询结果等。
这个 AST 定义文件是 SQL 解析器的核心部分,它:
- 定义了所有 SQL 语句的语法结构
- 提供了类型安全的方式来表示 SQL 语句
- 支持复杂的表达式和条件
- 便于后续的语义分析和执行
通过这个 AST,我们可以:
- 验证 SQL 语句的语法正确性
- 进行语义分析
- 生成执行计划
- 执行 SQL 语句
示例:
SELECT id, name FROM users WHERE age > 18;
交给语法分析器parser解析后的AST结构为:
SelectStatement
├── Fields
│ ├── Identifier{Value: "id"}
│ └── Identifier{Value: "name"}
├── TableName: "users"
└── Where
└── BinaryExpression
├── Left: Identifier{Value: "age"}
├── Operator: ">"
└── Right: IntegerLiteral{Value: "18"}
全部代码
package ast
import (
"cursor-db/internal/lexer"
"fmt"
)
// Node 表示AST中的节点
type Node interface {
TokenLiteral() string
}
// Statement 表示SQL语句
type Statement interface {
Node
statementNode()
}
// Expression 表示表达式
type Expression interface {
Node
expressionNode()
}
// Program 表示整个SQL程序
type Program struct {
Statements []Statement
}
// SelectStatement 表示SELECT语句
type SelectStatement struct {
Token lexer.Token
Fields []Expression
TableName string
Where Expression
}
func (ss *SelectStatement) statementNode() {}
func (ss *SelectStatement) TokenLiteral() string { return ss.Token.Literal }
// CreateTableStatement 表示CREATE TABLE语句
type CreateTableStatement struct {
Token lexer.Token
TableName string
Columns []ColumnDefinition
}
func (cts *CreateTableStatement) statementNode() {}
func (cts *CreateTableStatement) TokenLiteral() string { return cts.Token.Literal }
// InsertStatement 表示INSERT语句
type InsertStatement struct {
Token lexer.Token
TableName string
Values []Expression
}
func (is *InsertStatement) statementNode() {}
func (is *InsertStatement) TokenLiteral() string { return is.Token.Literal }
// ColumnDefinition 表示列定义
type ColumnDefinition struct {
Name string
Type string
Primary bool
Nullable bool
}
// Cell 表示数据单元格
type Cell struct {
Type CellType
IntValue int32
TextValue string
}
// CellType 表示单元格类型
type CellType int
const (
CellTypeInt CellType = iota
CellTypeText
)
// AsText 返回单元格的文本值
func (c *Cell) AsText() string {
switch c.Type {
case CellTypeInt:
s := fmt.Sprintf("%d", c.IntValue)
return s
case CellTypeText:
return c.TextValue
default:
return "NULL"
}
}
// AsInt 返回单元格的整数值
func (c *Cell) AsInt() int32 {
if c.Type == CellTypeInt {
return c.IntValue
}
return 0
}
// String 返回单元格的字符串表示
func (c Cell) String() string {
switch c.Type {
case CellTypeInt:
return fmt.Sprintf("%d", c.IntValue)
case CellTypeText:
return c.TextValue
default:
return "NULL"
}
}
// Results 表示查询结果
type Results struct {
Columns []ResultColumn
Rows [][]Cell
}
// ResultColumn 表示结果列
type ResultColumn struct {
Name string
Type string
}
// StarExpression 表示星号表达式,如:select * from users;
type StarExpression struct{}
func (se *StarExpression) expressionNode() {}
func (se *StarExpression) TokenLiteral() string { return "*" }
// LikeExpression 表示LIKE表达式, 如 LIKE '%b'
type LikeExpression struct {
Token lexer.Token
Left Expression
Pattern string
}
func (le *LikeExpression) expressionNode() {}
func (le *LikeExpression) TokenLiteral() string { return le.Token.Literal }
// BinaryExpression 表示二元表达式,如比较运算,大于小于比较等
type BinaryExpression struct {
Token lexer.Token
Left Expression
Operator string
Right Expression
}
func (be *BinaryExpression) expressionNode() {}
func (be *BinaryExpression) TokenLiteral() string { return be.Token.Literal }
// IntegerLiteral 表示整数字面量
type IntegerLiteral struct {
Token lexer.Token
Value string
}
func (il *IntegerLiteral) expressionNode() {}
func (il *IntegerLiteral) TokenLiteral() string { return il.Token.Literal }
// StringLiteral 表示字符串字面量
type StringLiteral struct {
Token lexer.Token
Value string
}
func (sl *StringLiteral) expressionNode() {}
func (sl *StringLiteral) TokenLiteral() string { return sl.Token.Literal }
// Identifier 表示标识符(如列名)
type Identifier struct {
Token lexer.Token
Value string
}
func (i *Identifier) expressionNode() {}
func (i *Identifier) TokenLiteral() string { return i.Token.Literal }
// UpdateStatement 表示UPDATE语句
type UpdateStatement struct {
Token lexer.Token
TableName string
Set []SetClause
Where Expression
}
func (us *UpdateStatement) statementNode() {}
func (us *UpdateStatement) TokenLiteral() string { return us.Token.Literal }
// SetClause 表示SET子句
type SetClause struct {
Column string
Value Expression
}
// DeleteStatement 表示DELETE语句
type DeleteStatement struct {
Token lexer.Token
TableName string
Where Expression
}
func (ds *DeleteStatement) statementNode() {}
func (ds *DeleteStatement) TokenLiteral() string { return ds.Token.Literal }
// DropTableStatement 表示DROP TABLE语句
type DropTableStatement struct {
Token lexer.Token
TableName string
}
func (ds *DropTableStatement) statementNode() {}
func (ds *DropTableStatement) TokenLiteral() string { return ds.Token.Literal }
模块三:语法分析器 (Parser)
思路
语法分析器负责将词法分析器生成的标记序列转换为抽象语法树。将token序列构建成ast。
SQL 解析器(Parser)的实现,负责将词法分析器(Lexer)产生的标记(Token)序列转换为抽象语法树(AST)。
语法分析器SQL 数据库系统的关键组件,负责:
- 验证 SQL 语句的语法正确性
- 构建抽象语法树
- 为后续的语义分析和执行提供基础
我们新建internal/parser/parser.go。
示例:
CREATE TABLE users (
id INT PRIMARY KEY,
name TEXT
);
解析过程:
1. 识别 CREATE 关键字
2. 解析 TABLE 关键字
3. 解析表名 "users"
4. 解析列定义:
列名 "id",类型 INT,主键
列名 "name",类型 TEXT
5. 生成 CREATE TABLE 语句的 AST
全部代码
package parser
import (
"fmt"
"ziyi.db.com/internal/ast"
"ziyi.db.com/internal/lexer"
)
// Parser 表示语法分析器
// 维护当前和下一个标记,实现向前查看(lookahead)
// 记录解析过程中的错误
type Parser struct {
l *lexer.Lexer // 词法分析器
curToken lexer.Token // 当前标记
peekToken lexer.Token // 下一个标记
errors []string // 错误信息
}
// NewParser 创建新的语法分析器
// 初始化解析器
// 预读两个标记
func NewParser(l *lexer.Lexer) *Parser {
p := &Parser{
l: l,
errors: []string{},
}
// 读取两个token,设置curToken和peekToken
p.nextToken()
p.nextToken()
return p
}
// nextToken 移动到下一个词法单元
func (p *Parser) nextToken() {
p.curToken = p.peekToken
p.peekToken = p.l.NextToken()
}
// ParseProgram 解析整个程序
// 解析整个 SQL 程序
// 循环解析每个语句直到结束
func (p *Parser) ParseProgram() (*ast.Program, error) {
program := &ast.Program{
Statements: []ast.Statement{},
}
for p.curToken.Type != lexer.EOF {
stmt, err := p.parseStatement()
if err != nil {
return nil, err
}
if stmt != nil {
program.Statements = append(program.Statements, stmt)
}
p.nextToken()
}
return program, nil
}
// parseStatement 解析语句
// 根据当前标记类型选择相应的解析方法
func (p *Parser) parseStatement() (ast.Statement, error) {
switch p.curToken.Type {
case lexer.CREATE:
return p.parseCreateTableStatement()
case lexer.INSERT:
return p.parseInsertStatement()
case lexer.SELECT:
return p.parseSelectStatement()
case lexer.UPDATE:
return p.parseUpdateStatement()
case lexer.DELETE:
return p.parseDeleteStatement()
case lexer.DROP:
return p.parseDropTableStatement()
case lexer.SEMI:
return nil, nil
default:
return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.curToken.Type)
}
}
// parseCreateTableStatement 解析CREATE TABLE语句
// 解析表名
// 解析列定义
// 处理主键约束
func (p *Parser) parseCreateTableStatement() (*ast.CreateTableStatement, error) {
stmt := &ast.CreateTableStatement{Token: p.curToken}
if !p.expectPeek(lexer.TABLE) {
return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.peekToken.Literal)
}
if !p.expectPeek(lexer.IDENT) {
return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.peekToken.Literal)
}
stmt.TableName = p.curToken.Literal
if !p.expectPeek(lexer.LPAREN) {
return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.peekToken.Literal)
}
// 解析列定义
for !p.peekTokenIs(lexer.RPAREN) {
p.nextToken()
if !p.curTokenIs(lexer.IDENT) {
return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.curToken.Literal)
}
col := ast.ColumnDefinition{
Name: p.curToken.Literal,
}
if !p.expectPeek(lexer.INT) && !p.expectPeek(lexer.TEXT) {
return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.peekToken.Literal)
}
col.Type = string(p.curToken.Type)
if p.peekTokenIs(lexer.PRIMARY) {
p.nextToken()
if !p.expectPeek(lexer.KEY) {
return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.peekToken.Literal)
}
col.Primary = true
}
stmt.Columns = append(stmt.Columns, col)
if p.peekTokenIs(lexer.COMMA) {
p.nextToken()
}
}
if !p.expectPeek(lexer.RPAREN) {
return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.peekToken.Literal)
}
return stmt, nil
}
// parseInsertStatement 解析INSERT语句
// 解析表名
// 解析 VALUES 子句
// 解析插入的值
func (p *Parser) parseInsertStatement() (*ast.InsertStatement, error) {
stmt := &ast.InsertStatement{Token: p.curToken}
if !p.expectPeek(lexer.INTO) {
return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.peekToken.Literal)
}
if !p.expectPeek(lexer.IDENT) {
return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.peekToken.Literal)
}
stmt.TableName = p.curToken.Literal
if !p.expectPeek(lexer.VALUES) {
return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.peekToken.Literal)
}
if !p.expectPeek(lexer.LPAREN) {
return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.peekToken.Literal)
}
// 解析值列表
for !p.peekTokenIs(lexer.RPAREN) {
p.nextToken()
expr, err := p.parseExpression()
if err != nil {
return nil, err
}
stmt.Values = append(stmt.Values, expr)
if p.peekTokenIs(lexer.COMMA) {
p.nextToken()
}
}
if !p.expectPeek(lexer.RPAREN) {
return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.peekToken.Literal)
}
return stmt, nil
}
// parseSelectStatement 解析SELECT语句
// 解析选择列表
// 解析 FROM 子句
// 解析 WHERE 子句
func (p *Parser) parseSelectStatement() (*ast.SelectStatement, error) {
stmt := &ast.SelectStatement{Token: p.curToken}
// 解析选择列表
for !p.peekTokenIs(lexer.FROM) {
p.nextToken()
if p.curToken.Type == lexer.ASTERISK {
stmt.Fields = append(stmt.Fields, &ast.StarExpression{})
break
}
expr, err := p.parseExpression()
if err != nil {
return nil, err
}
stmt.Fields = append(stmt.Fields, expr)
if p.peekTokenIs(lexer.COMMA) {
p.nextToken()
}
}
if !p.expectPeek(lexer.FROM) {
return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.peekToken.Literal)
}
if !p.expectPeek(lexer.IDENT) {
return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.peekToken.Literal)
}
stmt.TableName = p.curToken.Literal
// 解析WHERE子句
if p.peekTokenIs(lexer.WHERE) {
p.nextToken()
p.nextToken()
// 解析左操作数(列名)
if !p.curTokenIs(lexer.IDENT) {
return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.curToken.Literal)
}
left := &ast.Identifier{
Token: p.curToken,
Value: p.curToken.Literal,
}
// 解析操作符
p.nextToken()
operator := p.curToken
// 处理LIKE操作符
if p.curTokenIs(lexer.LIKE) {
p.nextToken()
if !p.curTokenIs(lexer.STRING) {
return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.curToken.Literal)
}
// 移除字符串字面量的引号
pattern := p.curToken.Literal
if len(pattern) >= 2 && (pattern[0] == '\'' || pattern[0] == '"') {
pattern = pattern[1 : len(pattern)-1]
}
stmt.Where = &ast.LikeExpression{
Token: operator,
Left: left,
Pattern: pattern,
}
return stmt, nil
}
// 处理其他操作符
if !p.curTokenIs(lexer.EQ) && !p.curTokenIs(lexer.GT) && !p.curTokenIs(lexer.LT) {
return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", operator.Type)
}
// 解析右操作数
p.nextToken()
right, err := p.parseExpression()
if err != nil {
return nil, err
}
stmt.Where = &ast.BinaryExpression{
Token: operator,
Left: left,
Operator: operator.Literal,
Right: right,
}
}
return stmt, nil
}
// parseUpdateStatement 解析UPDATE语句
// 解析表名
// 解析 SET 子句
// 解析 WHERE 子句
func (p *Parser) parseUpdateStatement() (*ast.UpdateStatement, error) {
stmt := &ast.UpdateStatement{Token: p.curToken}
if !p.expectPeek(lexer.IDENT) {
return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.peekToken.Literal)
}
stmt.TableName = p.curToken.Literal
if !p.expectPeek(lexer.SET) {
return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.peekToken.Literal)
}
// 解析SET子句
for {
p.nextToken()
if !p.curTokenIs(lexer.IDENT) {
return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.curToken.Literal)
}
column := p.curToken.Literal
if !p.expectPeek(lexer.EQ) {
return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.peekToken.Literal)
}
p.nextToken()
value, err := p.parseExpression()
if err != nil {
return nil, err
}
stmt.Set = append(stmt.Set, ast.SetClause{
Column: column,
Value: value,
})
if !p.peekTokenIs(lexer.COMMA) {
break
}
p.nextToken()
}
// 解析WHERE子句
if p.peekTokenIs(lexer.WHERE) {
p.nextToken()
p.nextToken()
// 解析左操作数(列名)
if !p.curTokenIs(lexer.IDENT) {
return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.curToken.Literal)
}
left := &ast.Identifier{
Token: p.curToken,
Value: p.curToken.Literal,
}
// 解析操作符
p.nextToken()
operator := p.curToken
if !p.curTokenIs(lexer.EQ) && !p.curTokenIs(lexer.GT) && !p.curTokenIs(lexer.LT) {
return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", operator.Type)
}
// 解析右操作数
p.nextToken()
right, err := p.parseExpression()
if err != nil {
return nil, err
}
stmt.Where = &ast.BinaryExpression{
Token: operator,
Left: left,
Operator: operator.Literal,
Right: right,
}
}
return stmt, nil
}
// parseDeleteStatement 解析DELETE语句
// 解析表名
// 解析 WHERE 子句
func (p *Parser) parseDeleteStatement() (*ast.DeleteStatement, error) {
stmt := &ast.DeleteStatement{Token: p.curToken}
if !p.expectPeek(lexer.FROM) {
return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.peekToken.Literal)
}
if !p.expectPeek(lexer.IDENT) {
return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.peekToken.Literal)
}
stmt.TableName = p.curToken.Literal
// 解析WHERE子句
if p.peekTokenIs(lexer.WHERE) {
p.nextToken()
p.nextToken()
// 解析左操作数(列名)
if !p.curTokenIs(lexer.IDENT) {
return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.curToken.Literal)
}
left := &ast.Identifier{
Token: p.curToken,
Value: p.curToken.Literal,
}
// 解析操作符
p.nextToken()
operator := p.curToken
if !p.curTokenIs(lexer.EQ) && !p.curTokenIs(lexer.GT) && !p.curTokenIs(lexer.LT) {
return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", operator.Type)
}
// 解析右操作数
p.nextToken()
right, err := p.parseExpression()
if err != nil {
return nil, err
}
stmt.Where = &ast.BinaryExpression{
Token: operator,
Left: left,
Operator: operator.Literal,
Right: right,
}
}
return stmt, nil
}
// parseDropTableStatement 解析DROP TABLE语句
func (p *Parser) parseDropTableStatement() (*ast.DropTableStatement, error) {
stmt := &ast.DropTableStatement{Token: p.curToken}
if !p.expectPeek(lexer.TABLE) {
return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.peekToken.Literal)
}
if !p.expectPeek(lexer.IDENT) {
return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.peekToken.Literal)
}
stmt.TableName = p.curToken.Literal
return stmt, nil
}
// parseExpression 解析表达式(字面量int、string类型,标识符列名、表名等)
// 解析各种类型的表达式
// 支持字面量、标识符等
func (p *Parser) parseExpression() (ast.Expression, error) {
switch p.curToken.Type {
case lexer.INT_LIT:
return &ast.IntegerLiteral{
Token: p.curToken,
Value: p.curToken.Literal,
}, nil
case lexer.STRING:
return &ast.StringLiteral{
Token: p.curToken,
Value: p.curToken.Literal,
}, nil
case lexer.IDENT:
return &ast.Identifier{
Token: p.curToken,
Value: p.curToken.Literal,
}, nil
default:
return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.curToken.Type)
}
}
// curTokenIs 检查当前token是否为指定类型
func (p *Parser) curTokenIs(t lexer.TokenType) bool {
return p.curToken.Type == t
}
// peekTokenIs 检查下一个token是否为指定类型
func (p *Parser) peekTokenIs(t lexer.TokenType) bool {
return p.peekToken.Type == t
}
// expectPeek 检查下一个词法单元是否为预期类型
func (p *Parser) expectPeek(t lexer.TokenType) bool {
if p.peekTokenIs(t) {
p.nextToken()
return true
}
return false
}
// parseWhereClause 解析WHERE子句
func (p *Parser) parseWhereClause() (ast.Expression, error) {
p.nextToken()
// 解析左操作数(列名)
if !p.curTokenIs(lexer.IDENT) {
return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.curToken.Literal)
}
left := &ast.Identifier{
Token: p.curToken,
Value: p.curToken.Literal,
}
// 解析操作符
p.nextToken()
operator := p.curToken
// 处理LIKE操作符
if p.curTokenIs(lexer.LIKE) {
p.nextToken()
if !p.curTokenIs(lexer.STRING) {
return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", p.curToken.Literal)
}
// 移除字符串字面量的引号
pattern := p.curToken.Literal
if len(pattern) >= 2 && (pattern[0] == '\'' || pattern[0] == '"') {
pattern = pattern[1 : len(pattern)-1]
}
return &ast.LikeExpression{
Token: operator,
Left: left,
Pattern: pattern,
}, nil
}
// 处理其他操作符
if !p.curTokenIs(lexer.EQ) && !p.curTokenIs(lexer.GT) && !p.curTokenIs(lexer.LT) {
return nil, fmt.Errorf("You have an error in your SQL syntax; check the manual that corresponds to your db server version for the right syntax to use near '%s'", operator.Type)
}
// 解析右操作数
p.nextToken()
right, err := p.parseExpression()
if err != nil {
return nil, err
}
return &ast.BinaryExpression{
Token: operator,
Left: left,
Operator: operator.Literal,
Right: right,
}, nil
}
模块四:存储引擎 (Storage)
思路
存储引擎负责实际的数据存储和检索操作,执行引擎中的数据操作CURD。
我们需要新建internal/storage/memory.go文件。
这是内存存储引擎的实现,负责处理 SQL 语句的实际执行和数据存储。
本期存储引擎实现了:
- 完整的数据操作(CRUD)
- 主键约束
- 索引支持
- 类型检查
- 条件评估
- 模式匹配
它是 SQL 数据库系统的核心组件,负责:
- 数据存储和管理
- 查询执行
- 数据完整性维护
- 性能优化(通过索引)
原理解析:
-- 创建表
CREATE TABLE users (
id INT PRIMARY KEY,
name TEXT
);
-- 插入数据
INSERT INTO users VALUES (1, 'Alice');
-- 查询数据
SELECT * FROM users WHERE name LIKE 'A%';
-- 更新数据
UPDATE users SET name = 'Bob' WHERE id = 1;
-- 删除数据
DELETE FROM users WHERE id = 1;
存储引擎会根据解析后的语法分析器,创建出对应的数据结构(如:在内存中),以及对外暴露对该数据的操作(CRUD)
全部代码
// internal/storage/memory.go
package storage
import (
"fmt"
"regexp"
"strconv"
"strings"
"ziyi.db.com/internal/ast"
)
// MemoryBackend 内存存储引擎,管理所有表
type MemoryBackend struct {
tables map[string]*Table
}
// Table 数据表,包含列定义、数据行和索引
type Table struct {
Name string
Columns []ast.ColumnDefinition
Rows [][]ast.Cell
Indexes map[string]*Index // 值到行索引的映射
}
// Index 索引,用于加速查询
type Index struct {
Column string
Values map[string][]int // 值到行索引的映射
}
// NewMemoryBackend 创建新的内存存储引擎
func NewMemoryBackend() *MemoryBackend {
return &MemoryBackend{
tables: make(map[string]*Table),
}
}
// CreateTable 创建表
// 验证表名唯一性
// 创建表结构
// 为主键列创建索引
func (b *MemoryBackend) CreateTable(stmt *ast.CreateTableStatement) error {
if _, exists := b.tables[stmt.TableName]; exists {
return fmt.Errorf("Table '%s' already exists", stmt.TableName)
}
table := &Table{
Name: stmt.TableName,
Columns: stmt.Columns,
Rows: make([][]ast.Cell, 0),
Indexes: make(map[string]*Index),
}
// 为主键创建索引
for _, col := range stmt.Columns {
if col.Primary {
table.Indexes[col.Name] = &Index{
Column: col.Name,
Values: make(map[string][]int),
}
}
}
b.tables[stmt.TableName] = table
return nil
}
// Insert 插入数据
// 验证表存在性
// 检查数据完整性
// 处理主键约束
// 维护索引
func (b *MemoryBackend) Insert(stmt *ast.InsertStatement) error {
table, exists := b.tables[stmt.TableName]
if !exists {
return fmt.Errorf("Table '%s' doesn't exist", stmt.TableName)
}
if len(stmt.Values) != len(table.Columns) {
return fmt.Errorf("Column count doesn't match value count at row 1")
}
// 转换值
row := make([]ast.Cell, len(stmt.Values))
for i, expr := range stmt.Values {
value, err := evaluateExpression(expr)
if err != nil {
return err
}
switch v := value.(type) {
case string:
if table.Columns[i].Type == "INT" {
// 尝试将字符串转换为整数
intVal, err := strconv.ParseInt(v, 10, 32)
if err != nil {
return fmt.Errorf("Incorrect integer value: '%s' for column '%s'", v, table.Columns[i].Name)
}
row[i] = ast.Cell{Type: ast.CellTypeInt, IntValue: int32(intVal)}
} else {
row[i] = ast.Cell{Type: ast.CellTypeText, TextValue: v}
}
case int32:
row[i] = ast.Cell{Type: ast.CellTypeInt, IntValue: v}
default:
return fmt.Errorf("Unsupported value type: %T for column '%s'", value, table.Columns[i].Name)
}
}
// 检查主键约束
for i, col := range table.Columns {
if col.Primary {
key := row[i].String()
if _, exists := table.Indexes[col.Name].Values[key]; exists {
return fmt.Errorf("Duplicate entry '%s' for key '%s'", key, col.Name)
}
}
}
// 插入数据
rowIndex := len(table.Rows)
table.Rows = append(table.Rows, row)
// 更新索引
for i, col := range table.Columns {
if col.Primary {
key := row[i].String()
table.Indexes[col.Name].Values[key] = append(table.Indexes[col.Name].Values[key], rowIndex)
}
}
return nil
}
// Select 查询数据
// 支持 SELECT * 和指定列
// 处理 WHERE 条件
// 返回查询结果
func (b *MemoryBackend) Select(stmt *ast.SelectStatement) (*ast.Results, error) {
table, exists := b.tables[stmt.TableName]
if !exists {
return nil, fmt.Errorf("Table '%s' doesn't exist", stmt.TableName)
}
results := &ast.Results{
Columns: make([]ast.ResultColumn, 0),
Rows: make([][]ast.Cell, 0),
}
// 处理选择列表
if len(stmt.Fields) == 1 && stmt.Fields[0].(*ast.StarExpression) != nil {
// SELECT *
for _, col := range table.Columns {
results.Columns = append(results.Columns, ast.ResultColumn{
Name: col.Name,
Type: col.Type,
})
}
} else {
// 处理指定的列
for _, expr := range stmt.Fields {
switch e := expr.(type) {
case *ast.Identifier:
// 查找列
found := false
for _, col := range table.Columns {
if col.Name == e.Value {
results.Columns = append(results.Columns, ast.ResultColumn{
Name: col.Name,
Type: col.Type,
})
found = true
break
}
}
if !found {
return nil, fmt.Errorf("Unknown column '%s' in 'field list'", e.Value)
}
default:
return nil, fmt.Errorf("Unsupported select expression type")
}
}
}
// 处理WHERE子句
for _, row := range table.Rows {
if stmt.Where != nil {
match, err := evaluateWhereCondition(stmt.Where, row, table.Columns)
if err != nil {
return nil, err
}
if !match {
continue
}
}
// 构建结果行
resultRow := make([]ast.Cell, len(results.Columns))
for j, col := range results.Columns {
// 查找列在原始行中的位置
for k, tableCol := range table.Columns {
if tableCol.Name == col.Name {
resultRow[j] = row[k]
break
}
}
}
results.Rows = append(results.Rows, resultRow)
}
return results, nil
}
// Update 执行UPDATE操作
// 验证表和列存在性
// 处理 WHERE 条件
// 更新符合条件的行
func (mb *MemoryBackend) Update(stmt *ast.UpdateStatement) error {
table, ok := mb.tables[stmt.TableName]
if !ok {
return fmt.Errorf("Table '%s' doesn't exist", stmt.TableName)
}
// 获取列索引
columnIndices := make(map[string]int)
for i, col := range table.Columns {
columnIndices[col.Name] = i
}
// 验证所有要更新的列是否存在
for _, set := range stmt.Set {
if _, ok := columnIndices[set.Column]; !ok {
return fmt.Errorf("Unknown column '%s' in 'field list'", set.Column)
}
}
// 更新符合条件的行
for i := range table.Rows {
if stmt.Where != nil {
// 评估WHERE条件
result, err := evaluateWhereCondition(stmt.Where, table.Rows[i], table.Columns)
if err != nil {
return err
}
if !result {
continue
}
}
// 更新行
for _, set := range stmt.Set {
colIndex := columnIndices[set.Column]
value, err := evaluateExpression(set.Value)
if err != nil {
return err
}
switch v := value.(type) {
case int32:
table.Rows[i][colIndex] = ast.Cell{Type: ast.CellTypeInt, IntValue: v}
case string:
table.Rows[i][colIndex] = ast.Cell{Type: ast.CellTypeText, TextValue: v}
default:
return fmt.Errorf("Unsupported value type: %T for column '%s'", value, set.Column)
}
}
}
return nil
}
// Delete 执行DELETE操作
// 验证表存在性
// 处理 WHERE 条件
// 删除符合条件的行
func (mb *MemoryBackend) Delete(stmt *ast.DeleteStatement) error {
table, ok := mb.tables[stmt.TableName]
if !ok {
return fmt.Errorf("Table '%s' doesn't exist", stmt.TableName)
}
// 找出要删除的行
rowsToDelete := make([]int, 0)
for i := range table.Rows {
if stmt.Where != nil {
// 评估WHERE条件
result, err := evaluateWhereCondition(stmt.Where, table.Rows[i], table.Columns)
if err != nil {
return err
}
if !result {
continue
}
}
rowsToDelete = append(rowsToDelete, i)
}
// 从后向前删除行,以避免索引变化
for i := len(rowsToDelete) - 1; i >= 0; i-- {
rowIndex := rowsToDelete[i]
table.Rows = append(table.Rows[:rowIndex], table.Rows[rowIndex+1:]...)
}
return nil
}
// DropTable 删除表
// 验证表是否存在
// 从存储引擎中删除表
func (mb *MemoryBackend) DropTable(stmt *ast.DropTableStatement) error {
if _, exists := mb.tables[stmt.TableName]; !exists {
return fmt.Errorf("Unknown table '%s'", stmt.TableName)
}
delete(mb.tables, stmt.TableName)
return nil
}
// evaluateExpression 评估表达式的值
// 计算表达式的值
// 处理不同类型的数据
func evaluateExpression(expr ast.Expression) (interface{}, error) {
switch e := expr.(type) {
case *ast.IntegerLiteral:
val, err := strconv.ParseInt(e.Value, 10, 32)
if err != nil {
return nil, fmt.Errorf("Incorrect integer value: '%s'", e.Value)
}
return int32(val), nil
case *ast.StringLiteral:
return e.Value, nil
case *ast.Identifier:
return nil, fmt.Errorf("Cannot evaluate identifier: '%s'", e.Value)
default:
return nil, fmt.Errorf("Unknown expression type: %T", expr)
}
}
// matchLikePattern 检查字符串是否匹配LIKE模式
func matchLikePattern(str, pattern string) bool {
// 将SQL LIKE模式转换为正则表达式
regexPattern := "^"
for i := 0; i < len(pattern); i++ {
switch pattern[i] {
case '%':
regexPattern += ".*"
case '_':
regexPattern += "."
case '\\':
if i+1 < len(pattern) {
regexPattern += "\\" + string(pattern[i+1])
i++
}
default:
// 转义正则表达式特殊字符
if strings.ContainsAny(string(pattern[i]), ".+*?^$()[]{}|") {
regexPattern += "\\" + string(pattern[i])
} else {
regexPattern += string(pattern[i])
}
}
}
regexPattern += "$"
// 编译正则表达式
re, err := regexp.Compile(regexPattern)
if err != nil {
return false
}
// 执行匹配
return re.MatchString(str)
}
// evaluateWhereCondition 评估WHERE条件
// 评估 WHERE 条件
// 支持比较运算符和 LIKE 操作符
func evaluateWhereCondition(expr ast.Expression, row []ast.Cell, columns []ast.ColumnDefinition) (bool, error) {
switch e := expr.(type) {
case *ast.BinaryExpression:
// 获取左操作数的值
leftValue, err := getColumnValue(e.Left, row, columns)
if err != nil {
return false, err
}
// 获取右操作数的值
rightValue, err := getColumnValue(e.Right, row, columns)
if err != nil {
return false, err
}
// 根据操作符比较值
switch e.Operator {
case "=":
return compareValues(leftValue, rightValue, "=")
case ">":
return compareValues(leftValue, rightValue, ">")
case "<":
return compareValues(leftValue, rightValue, "<")
default:
return false, fmt.Errorf("Unknown operator: '%s'", e.Operator)
}
case *ast.LikeExpression:
// 获取左操作数的值
leftValue, err := getColumnValue(e.Left, row, columns)
if err != nil {
return false, err
}
// 确保左操作数是字符串类型
strValue, ok := leftValue.(string)
if !ok {
return false, fmt.Errorf("LIKE operator requires string operand")
}
// 执行LIKE匹配
return matchLikePattern(strValue, e.Pattern), nil
default:
return false, fmt.Errorf("Unknown expression type: %T", expr)
}
}
// compareValues 比较两个值
func compareValues(left, right interface{}, operator string) (bool, error) {
switch l := left.(type) {
case int32:
if r, ok := right.(int32); ok {
switch operator {
case "=":
return l == r, nil
case ">":
return l > r, nil
case "<":
return l < r, nil
}
}
case string:
if r, ok := right.(string); ok {
switch operator {
case "=":
return l == r, nil
case ">":
return l > r, nil
case "<":
return l < r, nil
}
}
}
return false, fmt.Errorf("Cannot compare values of different types: %T and %T", left, right)
}
// getColumnValue 获取列的值
func getColumnValue(expr ast.Expression, row []ast.Cell, columns []ast.ColumnDefinition) (interface{}, error) {
switch e := expr.(type) {
case *ast.Identifier:
// 查找列索引
for i, col := range columns {
if col.Name == e.Value {
switch row[i].Type {
case ast.CellTypeInt:
return row[i].IntValue, nil
case ast.CellTypeText:
return row[i].TextValue, nil
default:
return nil, fmt.Errorf("Unknown cell type: %v", row[i].Type)
}
}
}
return nil, fmt.Errorf("Unknown column '%s' in 'where clause'", e.Value)
case *ast.IntegerLiteral:
val, err := strconv.ParseInt(e.Value, 10, 32)
if err != nil {
return nil, fmt.Errorf("Incorrect integer value: '%s'", e.Value)
}
return int32(val), nil
case *ast.StringLiteral:
return e.Value, nil
default:
return nil, fmt.Errorf("Unknown expression type: %T", expr)
}
}
//后续拓展新的存储引擎,如落地到文件...
模块五:REPL 交互界面
思路
最后,我们需要实现一个交互式的命令行界面,让用户可以输入 SQL 命令并查看结果。
这是 ZiyiDB 的主程序,实现了一个交互式的 SQL 命令行界面。
为了实现客户端可以上下翻找之前执行的命令以及cli客户端的美观,我们这里使用"github.com/c-bata/go-prompt"库
// 安装依赖
go get "github.com/c-bata/go-prompt"
我们需要新建cmd/main.go文件。
主要实现:
- 交互式命令行界面
- SQL 命令解析和执行
- 命令历史记录
- 查询结果格式化
- 错误处理和提示
全部代码
package main
import (
"fmt"
"github.com/c-bata/go-prompt"
"os"
"strings"
"ziyi.db.com/internal/ast"
"ziyi.db.com/internal/lexer"
"ziyi.db.com/internal/parser"
"ziyi.db.com/internal/storage"
)
var history []string // 存储命令历史
var backend *storage.MemoryBackend // 存储引擎实例
var historyIndex int // 当前历史记录索引
// 处理用户输入的命令
func executor(t string) {
t = strings.TrimSpace(t)
if t == "" {
return
}
// 添加到历史记录
history = append(history, t)
historyIndex = len(history) // 重置历史记录索引
// 处理退出命令
if strings.ToLower(t) == "exit" {
fmt.Println("Bye!")
os.Exit(0)
}
// 创建词法分析器
l := lexer.NewLexer(strings.NewReader(t))
// 创建语法分析器
p := parser.NewParser(l)
// 解析SQL语句
stmt, err := p.ParseProgram()
if err != nil {
fmt.Printf("Parse error: %v\n", err)
return
}
// 执行SQL语句
for _, statement := range stmt.Statements {
switch s := statement.(type) {
case *ast.CreateTableStatement:
if err := backend.CreateTable(s); err != nil {
fmt.Printf("Error: %v\n", err)
} else {
fmt.Println("Table created successfully")
}
case *ast.InsertStatement:
if err := backend.Insert(s); err != nil {
fmt.Printf("Error: %v\n", err)
} else {
fmt.Println("1 row inserted")
}
case *ast.SelectStatement:
results, err := backend.Select(s)
if err != nil {
fmt.Printf("Error: %v\n", err)
} else {
// 计算每列的最大宽度
colWidths := make([]int, len(results.Columns))
for i, col := range results.Columns {
colWidths[i] = len(col.Name)
}
for _, row := range results.Rows {
for i, cell := range row {
cellLen := len(cell.String())
if cellLen > colWidths[i] {
colWidths[i] = cellLen
}
}
}
// 打印表头
fmt.Print("+")
for _, width := range colWidths {
fmt.Print(strings.Repeat("-", width+2))
fmt.Print("+")
}
fmt.Println()
// 打印列名
fmt.Print("|")
for i, col := range results.Columns {
fmt.Printf(" %-*s |", colWidths[i], col.Name)
}
fmt.Println()
// 打印分隔线
fmt.Print("+")
for _, width := range colWidths {
fmt.Print(strings.Repeat("-", width+2))
fmt.Print("+")
}
fmt.Println()
// 打印数据行
for _, row := range results.Rows {
fmt.Print("|")
for i, cell := range row {
fmt.Printf(" %-*s |", colWidths[i], cell.String())
}
fmt.Println()
}
// 打印底部边框
fmt.Print("+")
for _, width := range colWidths {
fmt.Print(strings.Repeat("-", width+2))
fmt.Print("+")
}
fmt.Println()
// 打印行数统计
fmt.Printf("%d rows in set\n", len(results.Rows))
}
case *ast.UpdateStatement:
if err := backend.Update(s); err != nil {
fmt.Printf("Error: %v\n", err)
} else {
fmt.Println("Query OK, 1 row affected")
}
case *ast.DeleteStatement:
if err := backend.Delete(s); err != nil {
fmt.Printf("Error: %v\n", err)
} else {
fmt.Println("Query OK, 1 row affected")
}
case *ast.DropTableStatement:
if err := backend.DropTable(s); err != nil {
fmt.Printf("Error: %v\n", err)
} else {
fmt.Println("Table dropped successfully")
}
default:
fmt.Printf("Unsupported statement type: %T\n", s)
}
}
}
// 提供命令补全功能
func completer(d prompt.Document) []prompt.Suggest {
s := []prompt.Suggest{}
return prompt.FilterHasPrefix(s, d.GetWordBeforeCursor(), true)
}
func main() {
// 初始化存储引擎
backend = storage.NewMemoryBackend()
historyIndex = 0
fmt.Println("Welcome to ZiyiDB!")
fmt.Println("Type your SQL commands (type 'exit' to quit)")
p := prompt.New(
executor,
completer,
prompt.OptionTitle("ZiyiDB: A Simple SQL Database"),
prompt.OptionPrefix("ziyidb> "),
prompt.OptionHistory(history),
prompt.OptionLivePrefix(func() (string, bool) {
return "ziyidb> ", true
}),
//实现方向键上下翻阅历史命令
// 上键绑定
prompt.OptionAddKeyBind(prompt.KeyBind{
Key: prompt.Up,
Fn: func(buf *prompt.Buffer) {
if historyIndex > 0 {
historyIndex--
buf.DeleteBeforeCursor(len(buf.Text()))
buf.InsertText(history[historyIndex], false, true)
}
},
}),
// 下键绑定
prompt.OptionAddKeyBind(prompt.KeyBind{
Key: prompt.Down,
Fn: func(buf *prompt.Buffer) {
if historyIndex < len(history)-1 {
historyIndex++
buf.DeleteBeforeCursor(len(buf.Text()))
buf.InsertText(history[historyIndex], false, true)
} else if historyIndex == len(history)-1 {
historyIndex++
buf.DeleteBeforeCursor(len(buf.Text()))
}
},
}),
)
p.Run()
}
整体测试
编写完第一版后,现在我们来整体测试一下。
测试脚本
test_cast.sql:
-- 1. 创建表
CREATE TABLE users (id INT PRIMARY KEY,name TEXT ,age INT);
-- 2. 插入用户数据
INSERT INTO users VALUES (1, 'Alice', 20);
INSERT INTO users VALUES (2, 'Bob', 25);
INSERT INTO users VALUES (3, 'Charlie', 30);
INSERT INTO users VALUES (4, 'David', 35);
INSERT INTO users VALUES (5, 'Eve', 40);
-- 3. 测试主键冲突
INSERT INTO users VALUES (1, 'Tomas', 21);
-- 4. 基本查询测试
-- 4.1 查询所有数据
SELECT * FROM users;
-- 4.2 查询特定列
SELECT id, name FROM users;
-- 5. WHERE 子句测试
SELECT * FROM users WHERE age > 25;
SELECT * FROM users WHERE age < 30;
-- 6. LIKE 操作符测试
-- 6.1 基本模式匹配
SELECT * FROM users WHERE name LIKE 'A%'; -- 以 A 开头
SELECT * FROM users WHERE name LIKE '%e'; -- 以 e 结尾
-- 6.2 转义字符测试
INSERT INTO users VALUES (6, 'Bob%Smith', 45);
SELECT * FROM users WHERE name LIKE 'Bob\%Smith';
-- 7. 更新操作测试
-- 7.1 更新单个字段
UPDATE users SET age = 21 WHERE name = 'Alice';
-- 7.2 更新多个字段
UPDATE users SET name = 'Robert', age = 8 WHERE id = 2;
-- 8. 删除操作测试
DELETE FROM users WHERE age > 30;
-- 9. 清理测试数据
DROP TABLE users;
-- 10. 验证表已删除
SELECT * FROM users; -- 应该失败
todo::
1. 实现!= >= <=等运算符
2. 支持更多数据类型
3. 支持更多函数
4. 优化查询结果展示
5. 支持更多索引类型
6. 支持null值等
7. 支持数据落地本地文件
8. 支持事务操作等
运行效果
cd ZiyiDB
go run cmd/main.go
参考文章:
https://notes.eatonphil.com/database-basics.html