打造先进的内存KV数据库-5 TCP侦听

TCP侦听

作为支持集群的数据库,必定要与多个客户端交互信息,不可能让数据库与所有客户共享地址空间(虽然这样性能好),所以需要使用TCP协议进行交互数据,(UDP协议不可靠。。。弃用),C语言的TCP库其实还好,但是对于高并发和并行的处理不如Go,而且并发锁机制比较难写,所以使用Go写了服务器和客户端调用C的库,目前版本没有什么身份验证,之后会加上。

代码实现

//server.go
package main
// #cgo LDFLAGS: -L ./lib -lmonkeyS
// #include "./lib/core.h"
// #include <stdlib.h>
import "C"
import (
    "unsafe"
    _"fmt"
    "net"
    "strings"
)

func main() {
    str := []byte("monkey")
    str = append(str,0)
    C.CreateDB((*C.char)(unsafe.Pointer(&str[0])))  //创建基础数据库
    servicePort := ":1517"
    tcpAddr,err := net.ResolveTCPAddr("tcp4",servicePort)
    if err != nil {
        panic(err)
    }
    l,err := net.ListenTCP("tcp",tcpAddr)   //侦听TCP
    if err != nil {
        panic(err)
    }
    for{
        conn,err := l.Accept()
        if err != nil {
            panic(err)
        }
        go Handler(conn)
    }
}

func Handler(conn net.Conn) {

    str := []byte("monkey")                         //环境变量-当前数据库
    db := C.SwitchDB((*C.char)(unsafe.Pointer(&str[0])))
    for {               
        buff := []byte{}
        buf := make([]byte,1024)
        length,err := conn.Read(buf)
        total := uint32(0); //前4个字节保存消息长度
        for i := 0;i < 4;i++ {
            total <<= 8;
            total += uint32(buf[i]);
        }
        //fmt.Println("Message length:",total)
        buff = append(buff,buf[4:]...)
        total -= uint32(length)
        for total > 0 {
            length,err = conn.Read(buf)
            total -= uint32(length)
            buff = append(buff,buf...)
        }
        if err != nil {
            conn.Close()
            break
        }
        TranslateMessage(conn,&db,buff)                     //解析消息
    }

}

func TranslateMessage(conn net.Conn,db **C.Database,message []byte) {
    command := string(message)
    params := strings.Split(command," ")
    //fmt.Println(params)
    response := []byte{}
    if params[0] == "set" {
        r := C.Set(&(*db).tIndex,(*C.char)(unsafe.Pointer((&([]byte(params[1]))[0]))),(unsafe.Pointer(&([]byte(params[2]))[0])))
        for i := 0;;i++ {
            response = append(response,byte(r.msg[i]))
            if response[i] == 0 { break; }
        }

    }else if params[0] == "get" {
        r := C.Get(&(*db).tIndex,(*C.char)(unsafe.Pointer((&([]byte(params[1]))[0]))))
        // for i := 0;;i++ {
        //  response = append(response,byte(r.msg[i]))
        //  if response[i] == 0 { break; }
        // }
        if(int(r.code) == 0) {
            for i := 0;;i++ {
                response = append(response,byte(*(*C.char)(unsafe.Pointer((uintptr(r.pData)+uintptr(i))))))
                if response[i] == 0 { break; }
            }
        }else {
            // for i := 0;;i++ {
            // response = append(response,byte(r.msg[i]))
            // if response[i] == 0 { break; }
            // }
        }

    }else if params[0] == "delete" || params[0] == "remove" {
        r := C.Delete(&(*db).tIndex,(*C.char)(unsafe.Pointer((&([]byte(params[1]))[0]))))
        for i := 0;;i++ {
            response = append(response,byte(r.msg[i]))
            if response[i] == 0 { break; }
        }

    }else if params[0] == "createdb" {
        d := C.CreateDB((*C.char)(unsafe.Pointer((&([]byte(params[1]))[0]))))
        if d != nil {
            *db = d
            response = []byte("Already exist,switched\n")
        }else {
            response = []byte("Created\n")
        }
    }else if params[0] == "switchdb" {
        d := C.SwitchDB((*C.char)(unsafe.Pointer((&([]byte(params[1]))[0]))))
        if d != nil {
            *db = d
            response = []byte("ok\n")
        }else {
            response = []byte("fail\n")
        }
    }else if params[0] == "dropdb" {
        *db = C.DropDB((*C.char)(unsafe.Pointer((&([]byte(params[1]))[0]))))
    }else if strings.EqualFold("listdb",params[0]) {
        r := C.ListDB()
        for i := 0;i < 1024;i++ {
            b := byte(*(*C.char)(unsafe.Pointer(uintptr(unsafe.Pointer(r))+uintptr(i))))
            response = append(response,b)
            if(b == 0){ break; }
        }
        C.free(unsafe.Pointer(r))
    }else {
        //fmt.Println("unkown command:",params[0])
    }
    total := len(response) + 4
    header := make([]byte,4)
    i := 0
    for total > 0 {
        header[3-i] = byte(total % 256)
        total /= 256
        i++
    }
    response = append(header,response...)
    conn.Write(response)
}
//Client.go
package main
import "net"
import "fmt"
func main() {
    tcpAddr, err := net.ResolveTCPAddr("tcp4", "127.0.0.1:1517")  
    if err != nil {
        panic(err)
    }
    conn, err := net.DialTCP("tcp", nil, tcpAddr)  
    if err != nil {
        panic(err)
    }

    for {
        buf1 := ""
        buf2 := ""
        buf3 := ""
        buf := ""
        fmt.Print("monkey>")
        fmt.Scanf("%s",&buf1)
        if buf1 == "set" {
            fmt.Scanf("%s",&buf2)
            fmt.Scanf("%s",&buf3)
            buf = buf1 + " " + buf2 + " " + buf3
        }else if buf1 == "get"{
            fmt.Scanf("%s",&buf2)
            buf = buf1 + " " + buf2
        }else if buf1 == "remove" || buf1 == "delete" {
            fmt.Scanf("%s",&buf2)
            buf = buf1 + " " + buf2
        }else if buf1 == "createdb"{
            fmt.Scanf("%s",&buf2)
            buf = buf1 + " " + buf2
        }else if buf1 == "switchdb"{
            fmt.Scanf("%s",&buf2)
            buf = buf1 + " " + buf2
        }else if buf1 == "dropdb"{
            fmt.Scanf("%s",&buf2)
            buf = buf1 + " " + buf2
        }else if buf1 == "listdb"{
            buf = buf1 + " "
        }else if buf1 == "exit"{
            fmt.Println("Bye!")
            break;
        }
        total := uint32(0)
        total = uint32(len(buf) + 4)
        header := make([]byte,4)
        i := 0
        for total > 0 {
            header[3-i] = byte(total % 256)
            total /= 256
            i++
        } 
        conn.Write(append(header,([]byte(buf))...))

        buff := []byte{}
        buff2 := make([]byte,1024)
        length,_ := conn.Read(buff2)
        total = uint32(0);  //前4个字节保存消息长度
        for i := 0;i < 4;i++ {
            total <<= 8;
            total += uint32(buff2[i]);
        }
        buff = append(buff,buff2[4:]...)
        total -= uint32(length)
        for total > 0 {
            length,_ = conn.Read(buff2)
            total -= uint32(length)
            buff = append(buff,buff2...)
        }
        for i := 0;i < 1024;i++ {
            if buff[i] == 0 { break; }
            fmt.Printf("%c",buff[i])
        }
        fmt.Print("\n")
    }
}

修正:上述代码存在严重问题:
发送1K以上数据会无法正确接收
改进代码如下:

//tcp.go
package tcp
import "net"
import "fmt"

func ok(bytes []byte) bool {
    return bytes[0] == 111 && bytes[1] == 107 && bytes[2] == 0;
}

func bytes4uint(bytes []byte) uint32 {
    total := uint32(0); 
    for i := 0;i < 4;i++ {
        total <<= 8;
        total += uint32(bytes[i]);
    }
    return total
}

func uint32bytes(n uint32) []byte {
    header := make([]byte,4)
    i := 0
    for n > 0 {
        header[3-i] = byte(n % 256)
        n /= 256
        i++
    }
    return header
}


type TCPSession struct {
    Conn *net.TCPConn
    ToSend chan interface{} //要发送的数据
    Received chan interface{}   //接受到的数据
    Closed bool //是否已经关闭
}

func (s *TCPSession) Init() {
    s.ToSend = make(chan interface{})
    s.Received = make(chan interface{})
    go s.Send()
    go s.Recv()
}

func (s *TCPSession) Send() {
    for {
        if s.Closed {
            return
        }
        buf0 := <- s.ToSend //取出要发送的数据
        buf := buf0.([]byte)

        _,err := s.Conn.Write(buf)  //发送掉   
        //fmt.Println("send,",buf)
        if err != nil {
            s.Closed = true
            return
        }
    }

}

func (s *TCPSession) Recv() {
    for {
        if s.Closed {
            return
        }
        buf := make([]byte,1024)
        _,err := s.Conn.Read(buf)
        if err != nil {
            s.Closed = true
            return
        }
        s.Received <- buf
        //fmt.Println("read,",buf)
        }

}

func (s *TCPSession) SendMessage(bytes []byte) {
    total := len(bytes) / 1024
    if len(bytes) % 1024 != 0 {
        total++
    }
    header := uint32bytes(uint32(total))    //计算条数
    s.ToSend <- header
    //fmt.Println(header)
    for i := 0;i < total-1;i++ {
        buf := bytes[0:1024]    //发送这一段
        bytes = bytes[1024:]
        s.ToSend <- buf
        continue
    }
    //发送最后一段
    if total == 0 {
        return
    }
    buf := bytes[0:]    //发送这一段
    s.ToSend <- buf
}

func (s *TCPSession) ReadMessage() []byte {
    buf0 := <- s.Received
    buf := buf0.([]byte)
    //fmt.Println(buf)
    total := bytes4uint(buf)
    var buff []byte
    if buf[4] != 0 {    //两份报表被合并
        buff = buf[4:]
        total--
    } else {
        buff = []byte{}     
    }

    for i := uint32(0);i < total;i++ {
        buf0 := <- s.Received
        buf := buf0.([]byte)
        buff = append(buff,buf...)
    }
    return buff
}
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值