主要依赖 (main dependencies):
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
websocket.go
package server
import (
"errors"
"github.com/gin-gonic/gin"
"github.com/go-emix/utils"
"github.com/gorilla/websocket"
"log"
"net/http"
"sync"
)
// SocketSession wraps one upgraded WebSocket connection together with the id
// taken from the request path.
//
// gorilla/websocket permits at most one concurrent writer per connection.
// Sessions are handed out to arbitrary callers (e.g. via SocketSessionSet.Get),
// so every write is serialized through wmu to make the Send* methods safe to
// call from multiple goroutines.
type SocketSession struct {
	c   *websocket.Conn
	wmu sync.Mutex // serializes writes on c
	Id  string
}

// SendString writes data to the peer as a single text message.
func (r *SocketSession) SendString(data string) error {
	r.wmu.Lock()
	defer r.wmu.Unlock()
	return r.c.WriteMessage(websocket.TextMessage, []byte(data))
}

// SendJson JSON-encodes data and writes it to the peer as one message.
func (r *SocketSession) SendJson(data interface{}) error {
	r.wmu.Lock()
	defer r.wmu.Unlock()
	return r.c.WriteJSON(data)
}
// SocketSessionSet is a registry of live sessions keyed by session id.
// Every access to set is guarded by mux.
type SocketSessionSet struct {
	set map[string]*SocketSession // session id -> session
	mux sync.RWMutex
}
var SocketSessionMiss = errors.New("socket session miss")
// add stores s under s.Id, overwriting any session already registered under
// that id.
func (on *SocketSessionSet) add(s *SocketSession) {
	on.mux.Lock()
	defer on.mux.Unlock()
	on.set[s.Id] = s
}
// Get looks up the session registered under id. It returns SocketSessionMiss
// (and a nil session) when there is none.
func (on *SocketSessionSet) Get(id string) (s *SocketSession, err error) {
	on.mux.RLock()
	s, found := on.set[id]
	on.mux.RUnlock()
	if !found {
		return nil, SocketSessionMiss
	}
	return s, nil
}
// delete removes the session registered under id and returns it, or returns
// nil when no session is registered under that id.
//
// The lookup and the removal happen under one write lock. A separate
// read-then-write sequence would let a concurrent delete+add for the same id
// slip in between, removing the newer session while returning the stale one.
func (on *SocketSessionSet) delete(id string) *SocketSession {
	on.mux.Lock()
	defer on.mux.Unlock()
	s, ok := on.set[id]
	if !ok {
		return nil
	}
	delete(on.set, id)
	return s
}
// Counts reports how many sessions are currently registered.
func (on *SocketSessionSet) Counts() int {
	on.mux.RLock()
	n := len(on.set)
	on.mux.RUnlock()
	return n
}
// Callback signatures accepted by SocketEndpoint. Each one receives the
// session id taken from the route's ":id" parameter.
type (
	// OnMsgReceive is called with each message SocketEndpoint reads from a connection.
	OnMsgReceive func(id string, message []byte)
	// OnConnectOpen is called after a connection has been upgraded and registered.
	OnConnectOpen func(id string)
	// OnConnectClose is the close-notification callback passed to SocketEndpoint.
	OnConnectClose func(id string)
)
// SocketEndpoint builds a gin handler that upgrades requests to WebSocket
// connections, and returns the set in which live connections are tracked.
//
// The handler must be mounted on a route that exposes the session id as an
// ":id" path parameter, e.g. r.GET("socket/:id", hand). A bind failure is
// treated as a routing mistake and handed to utils.PanicError.
//
// The client must send a Sec-WebSocket-Protocol header, which can carry an
// authentication parameter. Requests without it get HTTP 200 with the JSON
// body "failed" and are not upgraded. The header value is offered as the only
// acceptable subprotocol for that request.
//
// Per connection, the callbacks run on the handler goroutine:
//   - open(id) once, after the session is registered;
//   - receive(id, msg) for every data message read;
//   - clos(id) once, when the connection ends for any reason (close frame,
//     read error, or a panic in a callback).
//
// Panics raised by the callbacks are logged and end only that connection.
// If a client reconnects with an id that is still in use, the newer session
// replaces the older one in the set. The older connection's cleanup does not
// remove it.
//
// NOTE(review): origin checking is disabled (CheckOrigin always returns true).
func SocketEndpoint(receive OnMsgReceive, open OnConnectOpen, clos OnConnectClose) (hand gin.HandlerFunc, set *SocketSessionSet) {
	// Template upgrader. Handlers run concurrently, so each request works on
	// its own copy instead of mutating this shared value.
	up := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool {
		return true
	}}
	set = &SocketSessionSet{set: map[string]*SocketSession{}}
	hand = func(c *gin.Context) {
		var mid = struct {
			Id string `uri:"id" binding:"required"`
		}{}
		utils.PanicError(c.ShouldBindUri(&mid))

		// The subprotocol header doubles as an authentication parameter.
		header := c.GetHeader("Sec-WebSocket-Protocol")
		if header == "" {
			c.JSON(http.StatusOK, "failed")
			return
		}
		reqUp := up
		reqUp.Subprotocols = []string{header}

		conn, err := reqUp.Upgrade(c.Writer, c.Request, nil)
		if err != nil {
			// Upgrade has already replied with an HTTP error; nothing more to send.
			log.Println("socket upgrade err", err)
			return
		}
		session := &SocketSession{c: conn, Id: mid.Id}
		set.add(session)

		// Deferred calls run last-registered first: clos, then recover, then
		// cleanup. Cleanup therefore always runs, even after a panic in a callback.
		defer func() {
			unregisterSocketSession(set, session)
			if err := conn.Close(); err != nil {
				log.Println("socket close err", err)
			}
		}()
		defer func() {
			if p := recover(); p != nil {
				log.Println(p)
			}
		}()
		defer clos(mid.Id)

		open(mid.Id)
		// ReadMessage returns an error once the peer closes (the default close
		// handler echoes the close frame) or the connection breaks.
		for {
			_, message, err := conn.ReadMessage()
			if err != nil {
				return
			}
			receive(mid.Id, message)
		}
	}
	return
}

// unregisterSocketSession removes s from set only if s is still the session
// registered under its id, so a newer connection that reused the id stays put.
func unregisterSocketSession(set *SocketSessionSet, s *SocketSession) {
	set.mux.Lock()
	defer set.mux.Unlock()
	if set.set[s.Id] == s {
		delete(set.set, s.Id)
	}
}
使用 (usage example; this snippet lives in the same package and also needs `import "fmt"`):
// soks holds the session set created by ApiFunc, for sending to connected clients.
var soks *SocketSessionSet

// ApiFunc mounts the WebSocket endpoint at "socket/:id" on r and keeps the
// resulting session set in soks.
func ApiFunc(r gin.IRouter) {
	handler, sessions := SocketEndpoint(rece, open, clos)
	soks = sessions
	r.GET("socket/:id", handler)
}
// rece prints every received message after the id of the session that sent it.
func rece(id string, mes []byte) {
	fmt.Printf("%s   %s\n", id, mes)
}
// open prints a notice that the session with the given id has connected.
func open(id string) {
	fmt.Printf("%s  open\n", id)
}
// clos prints a notice that the session with the given id has disconnected.
func clos(id string) {
	fmt.Printf("%s  close\n", id)
}