1.思路
- 将 TCP 和 WebSocket 的连接(connection)统一抽象成同一个接口(Session interface),便于在框架层统一管理
- 连接读取(read)到的数据统一交给同一个函数处理,保证 WebSocket 和 TCP 请求走同一套业务代码
- 写数据(write)同理,也通过该接口统一完成
2. Session 接口(所有连接的公共抽象,相当于基类)
// Session is the transport-neutral view of one client connection.
// handlePacket is called with both *TcpSession and *WsSession, so both
// implement it. Framework code and the shared packet handling can
// therefore treat TCP and WebSocket clients identically.
type Session interface {
	// Lifecycle.
	Init()
	Close() error

	// Identity and the per-session collaborators used by packet handling.
	GetId() uint64
	GetDecoder() *codec.Decoder
	GetRouter() *router.Router
	// GetSend returns the channel that handlePacket posts data messages to.
	GetSend() chan *PostMsg

	// Raw I/O on the underlying connection.
	Write(data []byte) (int, error)
	LocalAddr() net.Addr
	RemoteAddr() net.Addr
	SetReadDeadline(t time.Time) error
	SetWriteDeadline(t time.Time) error

	// Per-session settings and user data.
	GetRecordHeartbeatLog() bool
	SetCustomData(uid int64, nickname string)
	GetUid() int64
}
3.tcp封装
// TcpSession is the TCP-backed session. ReadPump passes a *TcpSession to
// handlePacket, so it is used through the Session interface.
type TcpSession struct {
// Id identifies the session; ReadPump sends it on unregister when its read loop exits.
Id uint64
// Conn is the underlying TCP connection that ReadPump reads from.
Conn net.Conn
// NOTE(review): NetProtocol is not read in this excerpt; presumably a protocol label — confirm.
NetProtocol string
// decoder is the packet codec: ReadPump feeds it raw bytes, and one read may
// yield several packets when they arrive stuck together.
decoder *codec.Decoder
// unregister is the unregistration notification channel: ReadPump sends Id on it when its goroutine exits.
unregister chan uint64
// Send presumably backs GetSend(), to which handlePacket posts data packets as *PostMsg — confirm.
Send chan *PostMsg
// router presumably backs GetRouter(), whose routes are returned to the client during the handshake — confirm.
router *router.Router
// lastPackageTime is the time data was last received, in Unix seconds (set by
// ReadPump); it can be used for heartbeat / idle detection.
lastPackageTime int64
// IsCheckHeartbeat, when true, makes ReadPump set a read deadline of readLine before every Read.
IsCheckHeartbeat bool
// NOTE(review): presumably returned by GetRecordHeartbeatLog, which gates heartbeat debug logs in handlePacket — confirm.
IsRecordHeartbeatLog bool
// Uid and Nickname presumably hold the user data set through SetCustomData — confirm.
Uid int64
Nickname string
}
// ReadPump starts a goroutine that reads the TCP connection until it fails.
// Each chunk read is decoded into zero or more packets, and every packet is
// routed through handlePacket. When the goroutine exits, for any reason, it
// sends s.Id on s.unregister.
//
// Every read error ends the loop. io.EOF (the peer closed the connection) is
// not logged; other errors, including a read-deadline timeout when
// IsCheckHeartbeat is set, are logged first. Earlier code used "continue" for
// non-EOF errors. That kept timed-out sessions alive forever and busy-looped
// on closed connections.
func (s *TcpSession) ReadPump() {
	go func() {
		defer func() {
			s.unregister <- s.Id
		}()

		// dispatch decodes one chunk. A chunk may hold several packets when
		// they arrive stuck together. Decode and handler errors are logged and
		// do not end the session.
		dispatch := func(chunk []byte) {
			pkgList, err := s.decoder.Decode(chunk)
			if err != nil {
				log.Errorf(tag, "decoder.Decode, error: %+v", err)
				return
			}
			for _, pkg := range pkgList {
				// TCP and WebSocket share the same packet handling.
				if err := handlePacket(s, pkg); err != nil {
					log.Errorf(tag, "handlePacket, error: %+v", err)
				}
			}
		}

		// NOTE(review): buf is reused for every Read. This assumes s.decoder
		// copies any partial packet it keeps between calls rather than
		// retaining a slice of buf — confirm in codec.
		buf := make([]byte, 1024)
		for {
			if s.IsCheckHeartbeat {
				// Push the deadline forward before each read. If nothing
				// arrives within readLine, Read fails with a timeout and the
				// session ends.
				if err := s.SetReadDeadline(time.Now().Add(readLine)); err != nil {
					log.Errorf(tag, "session.SetReadDeadline, error: %+v", err)
					return
				}
			}

			n, err := s.Conn.Read(buf)
			// The io.Reader contract says to process n > 0 bytes before
			// looking at err.
			if n > 0 {
				s.lastPackageTime = time.Now().Unix()
				dispatch(buf[:n])
			}
			if err != nil {
				if !errors.Is(err, io.EOF) {
					log.Errorf(tag, "Conn.Read, error: %+v", errors.WithStack(err))
				}
				return
			}
		}
	}()
}
4.websocket封装
// WsSession is the WebSocket-backed session. ReadPump passes a *WsSession to
// handlePacket, so it is used through the Session interface.
type WsSession struct {
// Id identifies the session; ReadPump sends it on unregister when its read loop exits.
Id uint64
// Conn is the underlying WebSocket connection that ReadPump reads messages from.
Conn *websocket.Conn
// NOTE(review): NetProtocol is not read in this excerpt; presumably a protocol label — confirm.
NetProtocol string
// decoder is the packet codec; ReadPump decodes each binary message with it.
decoder *codec.Decoder
// unregister is the unregistration notification channel: ReadPump sends Id on it when its goroutine exits.
unregister chan uint64
// Send presumably backs GetSend(), to which handlePacket posts data packets as *PostMsg — confirm.
Send chan *PostMsg
// router presumably backs GetRouter(), whose routes are returned to the client during the handshake — confirm.
router *router.Router
// NOTE(review): WsSession.ReadPump sets a read deadline unconditionally and does
// not consult this flag (unlike the TCP session) — confirm whether that is intended.
IsCheckHeartbeat bool
// NOTE(review): presumably returned by GetRecordHeartbeatLog, which gates heartbeat debug logs in handlePacket — confirm.
IsRecordHeartbeatLog bool
// Uid and Nickname presumably hold the user data set through SetCustomData — confirm.
Uid int64
Nickname string
}
// ReadPump starts a goroutine that reads WebSocket messages until the
// connection fails. Each binary message is decoded into packets, and every
// packet is routed through handlePacket. Text messages are ignored. When the
// goroutine exits it sends s.Id on s.unregister.
//
// A read deadline of readLine is set before every ReadMessage, so an idle
// connection ends the session.
func (s *WsSession) ReadPump() {
	go func() {
		defer func() {
			s.unregister <- s.Id
		}()
		for {
			if err := s.Conn.SetReadDeadline(time.Now().Add(readLine)); err != nil {
				log.Errorf(tag, "Conn.SetReadDeadline failed, err: %s", err)
				return
			}

			t, message, err := s.Conn.ReadMessage()
			if err != nil {
				// IsUnexpectedCloseError type-asserts *websocket.CloseError.
				// It must see the raw error: after errors.WithStack it always
				// returned false, so real unexpected closes were only
				// debug-logged.
				if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
					log.Errorf(tag, "Conn.ReadMessage, error: %+v", errors.WithStack(err))
				} else {
					log.Debugf(tag, "websocket.IsUnexpectedCloseError, error: %+v", errors.WithStack(err))
				}
				// gorilla/websocket read errors are permanent. Every later
				// ReadMessage returns the same error, and repeated reads
				// eventually panic, so the loop must end here in both cases.
				return
			}

			log.Infof(tag, "ReadMessage,t: %d, message: %+v", t, message)
			if t != websocket.BinaryMessage {
				// Only binary frames carry protocol packets; text frames are ignored.
				continue
			}

			pkgList, err := s.decoder.Decode(message)
			if err != nil {
				log.Errorf(tag, "decoder.Decode, error: %+v", err)
				continue
			}
			for _, pkg := range pkgList {
				// TCP and WebSocket share the same packet handling.
				if err := handlePacket(s, pkg); err != nil {
					log.Errorf(tag, "handlePacket, error: %+v", err)
				}
			}
		}
	}()
}
5.统一处理函数
// handlePacket routes one decoded packet for session s. Both the TCP and the
// WebSocket read loops call it, so the two transports share one code path.
//
// Handling by packet type:
//   - Handshake: a packet with non-zero Length goes to handshake, and that
//     error is returned. A zero-length handshake is ignored.
//   - Heartbeat: the packet is only logged, and only when
//     s.GetRecordHeartbeatLog() is true.
//   - Data: a non-empty payload is posted to s.GetSend() as a PostMsg holding
//     that single payload. Per the design note, the GateServer drains that
//     channel and publishes to an MQ. An empty payload is logged and dropped.
//
// Packets of any other type are ignored, and nil is returned. The send on
// s.GetSend() blocks until the channel accepts the message.
func handlePacket(s Session, pkg *packet.Packet) (err error) {
	switch pkg.Type {
	case packet.Handshake:
		if pkg.Length != 0 {
			return handshake(s, pkg.Data)
		}
	case packet.Heartbeat:
		if s.GetRecordHeartbeatLog() {
			log.Debugf(tag, "receive session %d heartbeat", s.GetId())
		}
	case packet.Data:
		if len(pkg.Data) == 0 {
			log.Debugf(tag, "receive session %d is empty msg data", s.GetId())
			return nil
		}
		s.GetSend() <- &PostMsg{
			SessionId:    s.GetId(),
			MsgBytesList: [][]byte{pkg.Data},
		}
	}
	return nil
}
6.客户端和服务器握手逻辑
- 客户端先向服务器发送握手请求,服务器校验客户端版本后返回响应码和路由字典
// handshakeReq is the JSON body of a client's Handshake packet; handshake decodes it.
type handshakeReq struct {
// Sys carries protocol-level information; handshake accepts only Version "1.0.1".
Sys struct {
Version string `json:"version"`
// NOTE(review): ClientType is decoded but not used in this excerpt.
ClientType string `json:"client_type"`
} `json:"sys"`
// User is an empty placeholder object; any fields the client sends in it are ignored.
User struct{} `json:"user"`
}
// handshakeResp is the JSON body that handshake sends back inside a HandshakeAck packet.
// Field order determines the JSON key order on the wire.
type handshakeResp struct {
// Code is the response status; handshake sets it to 200 on success.
Code int `json:"code"`
Sys struct {
// Dict is the route dictionary, taken from the session router's GetRoutes().
Dict map[string]uint16 `json:"dict"`
} `json:"sys"`
// User is always sent as an empty object.
User struct{} `json:"user"`
}
// handshake answers a client handshake on session s.
//
// data is the JSON-encoded handshakeReq carried by the Handshake packet. Only
// protocol version "1.0.1" is accepted; for any other version an error is
// returned and nothing is written. On success, handshake writes a HandshakeAck
// packet whose body is a handshakeResp with Code 200 and the session router's
// route dictionary. The write runs under a write deadline of WriteLine.
//
// handshake returns errors from JSON decoding or encoding, packet encoding,
// setting the write deadline, and the write itself. Previously the
// SetWriteDeadline error was silently overwritten by the Write result.
func handshake(s Session, data []byte) (err error) {
	var r handshakeReq
	if err = json.Unmarshal(data, &r); err != nil {
		return errors.WithStack(err)
	}
	// NOTE(review): the supported client version is hard-coded here.
	if r.Sys.Version != "1.0.1" {
		return errors.New("version match failed")
	}

	ret := handshakeResp{Code: 200}
	ret.Sys.Dict = s.GetRouter().GetRoutes()

	retBytes, err := json.Marshal(ret)
	if err != nil {
		return errors.WithStack(err)
	}
	buf, err := s.GetDecoder().Encode(packet.HandshakeAck, retBytes)
	if err != nil {
		return err
	}

	// Without a deadline, a stalled client could block this write indefinitely.
	if err = s.SetWriteDeadline(time.Now().Add(WriteLine)); err != nil {
		return errors.WithStack(err)
	}
	if _, err = s.Write(buf); err != nil {
		return err
	}
	return nil
}
源码地址,未经允许禁止转载