study/Web-Gee/0day/gee/trie/trie.go
package trie
import "strings"
// Node is one node of the routing trie. Each node matches one "/"-separated
// segment of a route; a node whose Pattern is non-empty terminates a
// registered route.
type Node struct {
	Pattern string  // full registered route ending at this node; empty for intermediate nodes
	part    string  // segment matched by this node, e.g. "hello", ":name" or "*filepath"
	child   []*Node // child nodes, in insertion order
}

// Print dumps the structure of the trie rooted at n using the built-in
// println (which writes to standard error). dep is the depth of n: nodes at
// depth >= 2 are printed with the label dep-1 followed by their segment, and
// every node that ends a registered route also prints "PATH:" and its
// Pattern. Descent continues below such nodes, so nested routes such as
// "/hello" and "/hello/:name" are both shown.
func (n *Node) Print(dep int) {
	if dep >= 2 {
		println(dep-1, strings.Repeat(" ", dep), n.part)
	}
	if n.Pattern != "" {
		println("PATH:", n.Pattern)
	}
	for _, c := range n.child {
		c.Print(dep + 1)
	}
}

// clear splits route pattern s on "/" and truncates the result right after
// the first wildcard segment (one starting with '*'), because a wildcard
// consumes the rest of the path: "/*a/b/c" becomes ["", "*a"], i.e. "/*a".
// A pattern without a wildcard is returned fully split.
func clear(s string) []string {
	parts := strings.Split(s, "/")
	for i, p := range parts {
		if strings.HasPrefix(p, "*") {
			// Keep the wildcard segment itself; drop everything after it.
			return parts[:i+1]
		}
	}
	return parts
}

// Insert registers pattern in the trie and returns the normalized pattern
// that was actually stored (anything after the first "*" segment is
// dropped). Callers should key their handlers by the returned string.
func (n *Node) Insert(pattern string) string {
	parts := clear(pattern)
	// strings.Join reassembles the segments with "/" between them.
	normalized := strings.Join(parts, "/")
	n.insert(normalized, parts, 0)
	return normalized
}

// insert walks (creating nodes as needed) the path for parts[dep:] below n
// and records pattern on the final node.
func (n *Node) insert(pattern string, parts []string, dep int) {
	if dep == len(parts) {
		n.Pattern = pattern
		return
	}
	for _, c := range n.child {
		if c.part == parts[dep] {
			// Reuse the existing child and continue one level below it.
			c.insert(pattern, parts, dep+1)
			return
		}
	}
	c := &Node{part: parts[dep]}
	n.child = append(n.child, c)
	c.insert(pattern, parts, dep+1)
}

// Check looks up the registered route matching parts (a request path split
// on "/") and returns the node holding its Pattern, or nil if none matches.
func (n *Node) Check(parts []string) *Node {
	return n.check(parts, 0)
}

// check matches parts[dep:] against the children of n. A static segment must
// match exactly, a ":param" segment matches any single segment, and a
// "*wildcard" segment matches all remaining segments. Children are tried in
// insertion order; the first successful match wins.
func (n *Node) check(parts []string, dep int) *Node {
	if dep == len(parts) {
		if n.Pattern != "" {
			return n
		}
		return nil
	}
	for _, c := range n.child {
		switch {
		case strings.HasPrefix(c.part, "*"):
			// The wildcard node itself carries the Pattern (Insert always
			// makes it the last segment); its parent usually does not.
			return c
		case c.part == parts[dep] || strings.HasPrefix(c.part, ":"):
			if found := c.check(parts, dep+1); found != nil {
				return found
			}
		}
	}
	return nil
}
首先实现了一棵 Trie 树，并提供了两个方法：`Insert` 用于插入新的路由，`Check` 用于进行路由匹配。
然后基于这棵 Trie 树重写 router.go：
package gee
import (
"fmt"
"net/http"
"strings"
"study/Web-Gee/0day/gee/trie"
)
type router struct {
//存放路由对应的响应方法
handlers map[string]HandlerFunc
//用来存放GET,POST等请求的trie
trie map[string]*trie.Node
}
func newRouter() *router {
return &router{
handlers: make(map[string]HandlerFunc),
trie: make(map[string]*trie.Node),
}
}
func (r *router) addRouter(method string, pattern string, handlerFunc HandlerFunc) {
if _, ok := r.trie[method]; !ok {
r.trie[method] = &trie.Node{}
}
pattern = r.trie[method].Insert(pattern)
key := method + "-" + pattern
r.handlers[key] = handlerFunc
}
// 匹配路由,返回路由中携带的参数
func (r *router) getRouter(method, pattern string) (*trie.Node, map[string]string) {
//根据‘/‘对路由进行分割
parts := strings.Split(pattern, "/")
if _, ok := r.trie[method]; !ok {
return nil, nil
}
//查找对应请求
node := r.trie[method].Check(parts)
params := make(map[string]string)
if node != nil {
split := strings.Split(node.Pattern, "/")
for inx, val := range split {
if len(val) == 0 {
continue
}
switch val[0] {
case ':':
params[val[1:]] = parts[inx]
case '*':
params[val[1:]] = fmt.Sprint(parts[inx:])
break
}
}
}
return node, params
}
func (r *router) handler(ctx *Context) {
node, params := r.getRouter(ctx.method, ctx.path)
if node == nil {
ctx.String(http.StatusNotFound, "404 NOT FOUND: %s \n", ctx.path)
return
}
ctx.params = params
key := ctx.method + "-" + node.Pattern
handlerFunc := r.handlers[key]
handlerFunc(ctx)
}
`getRouter` 匹配成功时，会返回匹配到的路由节点以及从请求路径中提取出的参数；
`handler` 再根据匹配到的路由查找对应绑定的处理函数，完成请求的响应。