golang 实战 简易聊天室服务(TCP)

golang 实战 简易聊天室服务(TCP)

首先,我们对整个服务进行简单划分: 通讯模组;消息储存模组;连接与连接之间的逻辑关系

通讯模组

通信流程

  1. 客户端->注册发信服务连接请求->服务端
  2. 客户端<-返回生成的客户端请求id+房间号<-服务端
  3. 客户端->注册收信服务连接请求->服务端
  4. 客户端<->并发收信发信<->服务端
  5. (客户端,服务端)任意一端程序停止,断开连接

报文组成(总报文=操作码+分报文)

  • 总报文=[操作码 1 byte]+[正文]
  • 注册收信报文=[客户端id 3 byte]+[房间号 2 byte]
  • 注册发信报文=[房间号 2 byte]
  • 消息响应报文=[消息id 1 byte]+[房间号 2 byte]+[消息来源客户端id 3 byte]
  • 消息通知报文=[消息数量 1 byte]
  • 消息发送报文=[正文]

消息储存模组

这里我们为了方便,使用的sqlite3作为数据库(驱动使用的是 github.com/mattn/go-sqlite3),当然也可以自行选者mysql,postgre, oracle或者你喜欢用的数据库。
如果你使用的也是sqlite,这里可能有个点需要注意。sqlite由于写入时不能同时写入,读取可以同时读的特性。所以使用时需要加锁,否则多线程执行非查询操作时会出现database is locked的错误

数据储存层代码(本演示全部数据库操作都经过此层)

func writeTodatabase(mod uint8, query string, args ...any) (any, error) {
	defer fmt.Printf("write to database finished,mod %d\n", mod)
	fmt.Println("start exec sql...")
	db, err := sql.Open("sqlite3", "data.db")
	if err == nil {
		defer db.Close()
		//如果是查询,加读锁。如果是写入或者修改,加写锁
		if mod == QUERY {
			dbmutex.RLock()
			defer dbmutex.RUnlock()
			return db.QueryContext(context.Background(), query, args...)
		} else if mod == QUERYROW {
			dbmutex.RLock()
			defer dbmutex.RUnlock()
			return db.QueryRowContext(context.Background(), query, args...), nil
		} else if mod == EXEC {
			dbmutex.Lock()
			defer dbmutex.Unlock()
			return db.ExecContext(context.Background(), query, args...)
		}
	}
	return nil, err

}

业务逻辑关系

type config struct {//整个通讯系统的运行信息
	Port int `yaml:"Port"`
}
type client struct {//用户连接实例
	roomid          uint32//此客户所属房间id
	last_reply_time time.Time
	last_verion     uint32//此客户端最新消息版本
	mutex           sync.RWMutex
	response_writer *response_Writer
}
type response_Writer struct {//用户收信实例,必须注册了收信实例才能接受信息
	con  net.Conn
	done chan struct{}
}
type room struct {//房间
	latestversion uint32
	mutex         sync.RWMutex
}
type message struct {//消息实例
	id             int//消息id,自增且唯一
	roomid, sendid uint32 //msgid,roomid,send userid
	content        string
}

room和client的关系
room实例本身就是一个房间最新消息版本号+读写锁。客户端读取信息只需要用room的锁上个读锁,获得最新消息后再解锁。写消息也一样(上写锁)。版本号由于向上增长,所以客户端版本号低于房间版本号就会拉取两个版本号之间的信息。
请添加图片描述

上代码

前置任务

const (
	TIMEOUTMAX = time.Duration(1000) * time.Millisecond
	MSG_SYN    = 10
	MSG_ACK    = 11
	MSG_CTR    = 12
)
const (
	DATA_OUT_OF_RANGE = "data out of range"
)
const (
	QUERY    = 90
	QUERYROW = 91
	EXEC     = 92
)

var (
	cnf      config
	errorlog = log.New(os.Stderr, "[error]", log.Ltime|log.Llongfile)//调试利器,快速定位到出错点
	debuglog = log.New(os.Stdout, "[debug]", log.Ltime|log.Llongfile)
	clientlist = make(map[uint32]*client)
	roomlist   = make(map[uint32]*room)
	dbmutex    sync.RWMutex //database mutex
)

func init() {
	flag.IntVar(&cnf.Port, "Port", 9000, "--cnf.Port xxx")
	flag.Parse()
}

主函数

func main() {
	if cnf.Port < 0 {
		errorlog.Println("Port", cnf.Port, " is invalid")
		os.Exit(1)
	}
	rand.Seed(time.Now().UnixNano()) //set rand seed
	listener, err := net.Listen("tcp", ":"+strconv.Itoa(cnf.Port))
	if err == nil {
		var con net.Conn
		for {
			con, err = listener.Accept()
			if err == nil {
				go readconse(con)
			} else {
				errorlog.Println(err.Error())
			}
		}
	} else {
		errorlog.Println(err.Error())
	}
}

连接处理

func readconse(con net.Conn) {
	defer con.Close()

	buffer := make([]byte, 1024)
	var (
		lang      int
		err       error
		resp      []byte
		id        uint32
		cli       *client
		is_accept bool = false//是否为收信连接
	)
	defer func() {
		if !is_accept {
			debuglog.Println("start deregister main", id, "process")
		} else {
			debuglog.Println("start deregister accept", id, "process")
		}
		if !is_accept && cli != nil && cli.response_writer != nil {
			debuglog.Println(id, "send end signal to accept process")
			cli.response_writer.done <- struct{}{}
		}
		if _, ok := clientlist[id]; ok {
			if !is_accept {
				debuglog.Println(id, "send server deregistered")
				delete(clientlist, id)
			} else {
				debuglog.Println(id, "accept server deregistered")
			}

		}
	}()
	for {
		lang, err = con.Read(buffer)
		if err == nil {
			switch buffer[0] {
			case MSG_SYN:
				if lang == 4 { //注册客户端收信连接
					is_accept = true
					id = uint32(buffer[1])*256*256 + uint32(buffer[2])*256 + uint32(buffer[3])
					var ok bool
					if cli, ok = clientlist[id]; ok {
						debuglog.Println("register client", id, "accept server")
						cli.response_writer = &response_Writer{con: con, done: make(chan struct{})}
						var msgarr []message
						tck := time.NewTicker(1 * time.Second)//每秒进行一波版本比较,可自行根据需求更改间隔时间
						for {
							select {
							case <-cli.response_writer.done:
								debuglog.Println(id, "accept done signal")
								return
							case <-tck.C:
								roomlist[cli.roomid].mutex.RLock()
								if cli.last_verion < roomlist[cli.roomid].latestversion {
									debuglog.Printf("client %v find new message", id)
									roomlist[cli.roomid].mutex.RUnlock()
									msgarr = getmessage(cli.last_verion, cli.roomid)
									debuglog.Printf("client %v find %d new messages", id, len(msgarr))
									if len(msgarr) > 0 {
										err = sendtocli(&msgarr, con)
										if err != nil {
											errorlog.Printf("cli accept server %v closed by client", id)
											return
										} else {
											cli.mutex.Lock()
											cli.last_verion = roomlist[cli.roomid].latestversion
											cli.mutex.Unlock()
											debuglog.Printf("client %v latest version update finished", id)
										}
									}
								} else {
									roomlist[cli.roomid].mutex.RUnlock()
								}
							}
						}
					}
				} else if lang == 3 { //注册客户端发信连接
					id = rand.Uint32() % (256*256*256 + 256*256 + 256)
					resp = make([]byte, 5)
					resp[0] = byte((id / (256 * 256)) % 256)
					resp[1] = byte((id / 256) % 256)
					resp[2] = byte(id % 256)
					copy(resp[3:5], buffer[1:3])
					_, err = con.Write(resp)
					if err != nil {
						debuglog.Println("connection closed by client")
						break
					}
					clientlist[id] = &client{roomid: uint32(buffer[1])*256 + uint32(buffer[2]), last_reply_time: time.Now()}
					cli = clientlist[id]
					if _, ok := roomlist[cli.roomid]; !ok {
						roomlist[cli.roomid] = &room{}
						debuglog.Println("create new room", cli.roomid)
					}
					debuglog.Printf("register client %v send server", id)
				}
			case MSG_ACK: //客户端向服务端发送信息
				if !writetomessage(buffer[1:lang], cli.roomid, id) {
					err = cli.write(MSG_CTR, []byte("send message failed unknown error"))
					if err != nil {
						errorlog.Println(err.Error())
						break
					}
				}
			case MSG_CTR: //切换房间等操作

			}
		} else {
			debuglog.Println("client", id, "closed connection")
			return
		}
	}
}

获取最新消息发往客户端

func getmessage(version, roomid uint32) []message {
	var msgarr = []message{}
	var msg *message
	debuglog.Println("room list read locked")
	roomlist[roomid].mutex.RLock()//挂房间读锁,函数结束时会自动解锁
	defer roomlist[roomid].mutex.RUnlock()
	defer fmt.Println("room list read unlocked")
	rwsany, err := writeTodatabase(QUERY, "select id,sendid,content from `"+strconv.FormatUint(uint64(roomid), 10)+"` where id>?", version)
	if err == nil {
		debuglog.Println("write to database finished")
		rws := rwsany.(*sql.Rows)
		for rws.Next() {
			debuglog.Println("read message...")
			msgarr = append(msgarr, message{roomid: roomid})
			msg = &msgarr[len(msgarr)-1]
			err = rws.Scan(&msg.id, &msg.sendid, &msg.content)
			if err != nil {
				debuglog.Printf("select id,sendid,content from `"+strconv.FormatUint(uint64(roomid), 10)+"` where id>%v", version)
				errorlog.Printf("read room %v msg failed,version %v, err %s", roomid, version, err.Error())
				return msgarr
			}
		}
	}
	return msgarr
}
func sendtocli(src *[]message, con net.Conn) error {
	if len(*src) == 0 {
		return nil
	}
	_, err := con.Write([]byte{MSG_SYN, byte(len(*src))})
	if err == nil {
		debuglog.Println("prepare send", len(*src))
		for _, ele := range *src {
			time.Sleep(10 * time.Millisecond) //sleep 10 milliseconds
			_, err = con.Write(ele.tobytes())
			if err != nil {
				return err
			} else {
				debuglog.Println("send message", string(ele.tobytes()[7:]))
			}
		}
	}
	return err
}

我们可能注意到发往客户端时有10毫秒休眠间隔,这个休眠间隔可以设置的小,但不能没有。计算机处理速度很快的现在,别看这里循环有那么几行代码,在计算机面前,嗖的一下就全没了。然后由于太快了,最后接收端缓冲区都还没读,我们就把东西全给别人塞进去了,最后几个独立的包可能被粘成一个了。
请添加图片描述
请添加图片描述

数据量大的时候又有另一个坑,接收端缓冲区设置的也许是1MB,但是我们发送端发的包是800kb,连发四包,最后接收端又读出问题,我们第二个包可能被接收端读出来就是身首异处。我们第二个包自己设置的控制层直接被读到第一个包末尾处了,第二个包我们解析的时候数据就是有问题的
请添加图片描述

消息结构体转byte数组

// message = id[1]+roomid[2]+sendcliid[3]+content
func (s *message) tobytes() []byte {
	ans := make([]byte, 7+len(s.content))
	ans[0] = MSG_ACK
	ans[1] = byte(s.id)
	ans[2] = byte((s.roomid / 256) % 256)
	ans[3] = byte(s.roomid % 256)
	ans[4] = byte((s.sendid / (256 * 256)) % 256)
	ans[5] = byte((s.sendid / 256) % 256)
	ans[6] = byte(s.sendid % 256)
	copy(ans[7:], []byte(s.content))
	return ans
}
func (s *client) write(code uint8, v []byte) error {
	ans := make([]byte, 1+len(v))
	ans[0] = code
	copy(ans[1:], v)
	if s.response_writer == nil {
		return fmt.Errorf("not set response writer")
	}
	_, err := s.response_writer.con.Write(ans)
	return err
}

客户端储存消息至房间数据库,并更新消息版本号

func writetomessage(message []byte, roomid, sendid uint32) bool {
	debuglog.Printf("client %v locked room %v", sendid, roomid)
	roomlist[roomid].mutex.Lock()
	defer roomlist[roomid].mutex.Unlock()
	defer debuglog.Printf("client %v unlocked room %v", sendid, roomid)
	_, err := writeTodatabase(EXEC, "insert into `"+strconv.FormatUint(uint64(roomid), 10)+"` (sendid,content)values(?,?)", sendid, string(message))
	if err == nil {
		debuglog.Printf("insert into `%v` (sendid,content)values(%v,'%v')", roomid, sendid, string(message))
		rowany, _ := writeTodatabase(QUERYROW, "select MAX(id) from `"+strconv.FormatUint(uint64(roomid), 10)+"`")//我们这里上的写锁,所以最大消息id就是刚才我们插入的那条消息
		var maxversion uint32
		row := rowany.(*sql.Row)
		if row.Scan(&maxversion) == nil {
			roomlist[roomid].latestversion = maxversion
			clientlist[sendid].last_verion = maxversion
		} else {
			errorlog.Println("get maxversion failed")
		}
		return true
	} else {
		errorlog.Println("insert data error", err.Error())
		_, err = writeTodatabase(EXEC, "create table `"+strconv.FormatUint(uint64(roomid), 10)+"` (id INTEGER PRIMARY KEY autoincrement NOT NULL,sendid INTEGER NOT NULL,content VARCHAR(1000))")
		if err == nil {
			_, err = writeTodatabase(EXEC, "insert into `"+strconv.FormatUint(uint64(roomid), 10)+"` (sendid,content)values(?,?)", sendid, string(message))
			if err == nil {
				rowany, _ := writeTodatabase(QUERYROW, "select MAX(id) from `"+strconv.FormatUint(uint64(roomid), 10)+"`")
				var maxversion uint32
				row := rowany.(*sql.Row)
				if row.Scan(&maxversion) == nil {
					roomlist[roomid].latestversion = maxversion
				} else {
					errorlog.Println("get maxversion failed")
				}
				return true
			} else {
				errorlog.Println(err.Error())
			}
		} else {
			errorlog.Printf("insert data to room %v failed,err=%s", roomid, err.Error())
		}

	}
	return false
}

实机演示

在这里插入图片描述
完整代码地址 https://github.com/oswaldoooo/mini_im

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值