- server
websocket/server/binder
/**
* @Author: Hhx06
* @Description:
* @File: binder
* @Version: 1.0.0
* @Date: 2020/10/29 14:55
*/
package server
import (
"errors"
"fmt"
"sync"
)
// eventConn pairs a bound Conn with the event name it was bound under.
//
// Field order is significant: binder builds values with positional
// composite literals such as eventConn{event, conn}.
type eventConn struct {
// Event is the event name that was passed to binder.Bind.
Event string
// Conn is the bound connection; binder.Bind rejects nil, so it is never nil.
Conn *Conn
}
// binder is an in-memory registry of which connections belong to which user.
//
// It keeps two indexes that Bind and Unbind update together:
//   - userID2EventConnMap: userID -> eventConn entries bound to that user
//   - connID2UserIDMap:    Conn.GetID() -> owning userID
//
// Bind and Unbind hold the write lock of mu and FilterConn the read lock
// while touching the maps. Both maps must be created with make before use
// (writing to a nil map panics); ListenAndServe does this when it builds
// the binder.
type binder struct {
mu sync.RWMutex
// userID2EventConnMap maps a userID to a pointer to the slice of eventConn
// bound to it. Unbind deletes the key once that slice becomes empty.
userID2EventConnMap map[string]*[]eventConn
// connID2UserIDMap maps a connection ID (Conn.GetID) to its userID.
connID2UserIDMap map[string]string
}
// Bind binds conn to userID under the given event.
//
// It returns an error if userID or event is empty or conn is nil. Binding a
// conn that is already bound to the same userID is a no-op; the event of the
// existing binding is kept. If conn is currently bound to a different
// userID, that old binding is removed first, so every connection has exactly
// one owner in both indexes.
func (b *binder) Bind(userID, event string, conn *Conn) error {
	if userID == "" {
		return errors.New("userID can't be empty")
	}
	if event == "" {
		return errors.New("event can't be empty")
	}
	if conn == nil {
		return errors.New("conn can't be nil")
	}

	b.mu.Lock()
	defer b.mu.Unlock()

	connID := conn.GetID()

	// If the conn is owned by another user, detach it from that user's list.
	// Otherwise it would linger there after connID2UserIDMap is overwritten
	// below, and Unbind (which only looks at the current owner) could never
	// remove it.
	if oldUserID, ok := b.connID2UserIDMap[connID]; ok && oldUserID != userID {
		if oldEConns, ok := b.userID2EventConnMap[oldUserID]; ok {
			kept := make([]eventConn, 0, len(*oldEConns))
			for _, ec := range *oldEConns {
				if ec.Conn != conn {
					kept = append(kept, ec)
				}
			}
			if len(kept) == 0 {
				delete(b.userID2EventConnMap, oldUserID)
			} else {
				b.userID2EventConnMap[oldUserID] = &kept
			}
		}
	}

	if eConns, ok := b.userID2EventConnMap[userID]; ok {
		for i := range *eConns {
			if (*eConns)[i].Conn == conn {
				// Already bound to this user; nothing to do.
				return nil
			}
		}
		newEConns := append(*eConns, eventConn{event, conn})
		b.userID2EventConnMap[userID] = &newEConns
	} else {
		b.userID2EventConnMap[userID] = &[]eventConn{{event, conn}}
	}
	b.connID2UserIDMap[connID] = userID
	return nil
}
// Unbind removes conn from the binder.
//
// It returns an error if conn is nil or is not currently bound. When the
// conn is known, its connID -> userID entry is always removed. Its eventConn
// entry is removed from the owner's list, and the userID key itself is
// deleted once the last connection of that user is gone.
func (b *binder) Unbind(conn *Conn) error {
	if conn == nil {
		return errors.New("conn can't be empty")
	}

	b.mu.Lock()
	defer b.mu.Unlock()

	connID := conn.GetID()
	userID, ok := b.connID2UserIDMap[connID]
	if !ok {
		return fmt.Errorf("can't find userID by connID: %s", connID)
	}
	// Drop the connID index unconditionally. If the per-user list below does
	// not contain conn, keeping this entry would leave a dangling mapping that
	// no later call could clean up.
	delete(b.connID2UserIDMap, connID)

	eConns, ok := b.userID2EventConnMap[userID]
	if !ok {
		return fmt.Errorf("can't find the eventConns by userID: %s", userID)
	}
	for i := range *eConns {
		if (*eConns)[i].Conn != conn {
			continue
		}
		last := len(*eConns) - 1
		newEConns := append((*eConns)[:i], (*eConns)[i+1:]...)
		// The in-place removal leaves a stale copy in the last slot of the
		// shared backing array; clear it so the removed *Conn can be collected.
		(*eConns)[last] = eventConn{}
		if len(newEConns) == 0 {
			delete(b.userID2EventConnMap, userID)
		} else {
			b.userID2EventConnMap[userID] = &newEConns
		}
		return nil
	}
	return fmt.Errorf("can't find the conn of ID: %s", connID)
}
// FindConn looks up a bound Conn by its ID.
//
// It first resolves the owning userID through connID2UserIDMap and searches
// only that user's connections. If the connID is not indexed, it falls back
// to scanning every bound connection. It reports false when connID is empty
// or nothing matches. The read lock is held for the whole lookup.
func (b *binder) FindConn(connID string) (*Conn, bool) {
	if connID == "" {
		return nil, false
	}

	// Bind and Unbind mutate both maps under b.mu; reading them without the
	// lock is a data race (and can crash with concurrent map access).
	b.mu.RLock()
	defer b.mu.RUnlock()

	if userID, ok := b.connID2UserIDMap[connID]; ok {
		if eConns, ok := b.userID2EventConnMap[userID]; ok {
			for _, ec := range *eConns {
				if ec.Conn.GetID() == connID {
					return ec.Conn, true
				}
			}
		}
		return nil, false
	}

	// userID not found: scan every bound connection.
	for _, eConns := range b.userID2EventConnMap {
		for _, ec := range *eConns {
			if ec.Conn.GetID() == connID {
				return ec.Conn, true
			}
		}
	}
	return nil, false
}
// FilterConn returns the connections bound to userID whose event equals
// event. An empty event matches every connection of the user.
//
// userID must not be empty. When the user has no bindings, an empty
// (non-nil) slice is returned together with a nil error.
func (b *binder) FilterConn(userID, event string) ([]*Conn, error) {
	if len(userID) == 0 {
		return nil, errors.New("userID can't be empty")
	}

	b.mu.RLock()
	defer b.mu.RUnlock()

	bound, found := b.userID2EventConnMap[userID]
	if !found {
		return []*Conn{}, nil
	}

	matchAll := event == ""
	result := make([]*Conn, 0, len(*bound))
	for _, ec := range *bound {
		if matchAll || ec.Event == event {
			result = append(result, ec.Conn)
		}
	}
	return result, nil
}
websocket/server/conn
/**
* @Author: Hhx06
* @Description:
* @File: conn
* @Version: 1.0.0
* @Date: 2020/10/29 14:56
*/
package server
import (
"errors"
"io"
"log"
"sync"
"time"
"github.com/google/uuid"
"github.com/gorilla/websocket"
)
// Conn wraps a gorilla websocket connection with a lazily generated ID,
// optional read/close hooks, and a stop channel that Close closes once.
//
// Create it with NewConn; the zero value (nil connection, nil stop channel)
// is not usable.
type Conn struct {
	// Conn is the underlying websocket connection.
	Conn *websocket.Conn
	// AfterReadFunc, if non-nil, is called by Listen with every message read.
	AfterReadFunc func(messageType int, r io.Reader)
	// BeforeCloseFunc, if non-nil, is called by Listen's close handler when
	// the peer sends a close frame, before the connection is closed.
	BeforeCloseFunc func()

	once   sync.Once     // guards lazy generation of id in GetID
	id     string        // UUID returned by GetID
	stopCh chan struct{} // closed by Close to signal the Conn is closed

	// writeMu serialises WriteMessage calls. gorilla/websocket supports at
	// most one concurrent writer, yet Write is reachable from several
	// concurrent push requests targeting the same connection.
	writeMu sync.Mutex
	// closeMu makes Close's check-then-close atomic, so two concurrent calls
	// cannot both close stopCh (closing a closed channel panics).
	closeMu sync.Mutex
}

// Write sends p as a single text message. On success it returns len(p) and
// a nil error. It returns 0 and an error if the Conn has already been closed
// or the underlying write fails. It is safe for concurrent use.
func (c *Conn) Write(p []byte) (n int, err error) {
	select {
	case <-c.stopCh:
		return 0, errors.New("Conn is closed, can't be written")
	default:
	}

	c.writeMu.Lock()
	defer c.writeMu.Unlock()
	if err = c.Conn.WriteMessage(websocket.TextMessage, p); err != nil {
		return 0, err
	}
	return len(p), nil
}

// GetID returns the connection's ID: a random UUID generated on the first
// call and reused afterwards. It is safe for concurrent use.
func (c *Conn) GetID() string {
	c.once.Do(func() {
		c.id = uuid.New().String()
	})
	return c.id
}

// Listen reads messages from the websocket connection and hands each one to
// AfterReadFunc. It blocks until a read fails (which includes the peer's
// close frame and a Close from another goroutine), or until the stop channel
// is seen closed between two messages.
func (c *Conn) Listen() {
	c.Conn.SetCloseHandler(func(code int, text string) error {
		if c.BeforeCloseFunc != nil {
			c.BeforeCloseFunc()
		}
		// Echo the close frame while the socket is still open; writing it
		// after Close would always fail. The peer started the close, so a
		// failed echo is not actionable and its error is dropped.
		message := websocket.FormatCloseMessage(code, "")
		_ = c.Conn.WriteControl(websocket.CloseMessage, message, time.Now().Add(time.Second))
		if err := c.Close(); err != nil {
			log.Println(err)
		}
		return nil
	})

	// Keep reading from Conn until an error occurs.
ReadLoop:
	for {
		select {
		case <-c.stopCh:
			break ReadLoop
		default:
			messageType, r, err := c.Conn.NextReader()
			if err != nil {
				// TODO: handle read error maybe
				break ReadLoop
			}
			if c.AfterReadFunc != nil {
				c.AfterReadFunc(messageType, r)
			}
		}
	}
}

// Close closes the underlying websocket connection and the stop channel.
// It returns an error if the Conn was already closed. It is safe to call
// from multiple goroutines.
func (c *Conn) Close() error {
	c.closeMu.Lock()
	defer c.closeMu.Unlock()

	select {
	case <-c.stopCh:
		return errors.New("Conn already been closed")
	default:
	}
	// The socket close error is ignored: stopCh is closed regardless, so the
	// Conn is reported as closed from here on.
	c.Conn.Close()
	close(c.stopCh)
	return nil
}
// NewConn wraps conn in a Conn whose stop channel is open. Closing the
// returned Conn also closes conn.
func NewConn(conn *websocket.Conn) *Conn {
	c := new(Conn)
	c.Conn = conn
	c.stopCh = make(chan struct{})
	return c
}
websocket/server/handler
/**
* @Author: Hhx06
* @Description:
* @File: handler
* @Version: 1.0.0
* @Date: 2020/10/29 14:56
*/
package server
import (
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"strings"
"github.com/gorilla/websocket"
)
// websocketHandler is the http.Handler for the websocket endpoint. It
// upgrades requests and records the resulting connections in binder.
type websocketHandler struct {
// upgrader performs the HTTP -> websocket upgrade.
upgrader *websocket.Upgrader
// binder records which userID each websocket connection belongs to. The
// same binder is shared with pushHandler so pushes can find connections.
binder *binder
// calcUserIDFunc maps the token a client registers with to a userID and
// reports whether the token is accepted. When nil, the token itself is
// used as the userID.
calcUserIDFunc func(token string) (userID string, ok bool)
}
// RegisterMessage is the JSON payload a client sends over the websocket to
// register itself after connecting, e.g. {"Token": "...", "Event": "..."}.
// The fields carry no json tags, so encoding/json matches keys to the field
// names case-insensitively.
//
// NOTE(review): Event is decoded, but websocketHandler.ServeHTTP binds every
// connection under the fixed event "onHhx" regardless of its value.
type RegisterMessage struct {
Token string
Event string
}
// ServeHTTP upgrades the request to a websocket connection and serves it
// until the client closes it or the connection breaks.
//
// Every message the client sends is decoded as a RegisterMessage. Its Token
// is turned into a userID, via calcUserIDFunc when set, otherwise used as-is.
// All connections currently bound to that userID under the "onHhx" event are
// then unbound, and this connection is bound in their place, so each user has
// at most one registered connection for that event. Messages that fail to
// decode, or whose token is rejected, are ignored.
func (wh *websocketHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// registerEvent is the only event this server binds connections under.
	const registerEvent = "onHhx"

	wsConn, err := wh.upgrader.Upgrade(w, r, nil)
	if err != nil {
		// Upgrade has already replied to the client with an HTTP error.
		return
	}
	conn := NewConn(wsConn)
	// However Listen ends (close frame, read error, Drop), remove the binding
	// and close the connection. Without this, a client that vanished without
	// sending a close frame would stay bound forever. Both calls may fail
	// because the close handler or Drop already did the work, so their
	// errors are ignored.
	defer func() {
		_ = wh.binder.Unbind(conn)
		_ = conn.Close()
	}()

	conn.AfterReadFunc = func(messageType int, msg io.Reader) {
		var rm RegisterMessage
		if err := json.NewDecoder(msg).Decode(&rm); err != nil {
			return
		}

		// Calculate the userID from the token.
		userID := rm.Token
		if wh.calcUserIDFunc != nil {
			uID, ok := wh.calcUserIDFunc(rm.Token)
			if !ok {
				return
			}
			userID = uID
		}

		// Unbind this user's previous registrations first, then bind this conn.
		conns, err := wh.binder.FilterConn(userID, registerEvent)
		if err != nil {
			return
		}
		for _, old := range conns {
			if err := wh.binder.Unbind(old); err != nil {
				log.Printf("conn unbind fail: %v", err)
			}
		}
		if err := wh.binder.Bind(userID, registerEvent, conn); err != nil {
			log.Printf("conn bind fail: %v", err)
		}
	}
	conn.BeforeCloseFunc = func() {
		_ = wh.binder.Unbind(conn)
	}
	conn.Listen()
}
// closeConns unbinds and then closes every connection of userID whose event
// matches event; an empty event matches all of that user's connections.
// userID must not be empty. It returns how many connections were both
// unbound and closed. Per-connection failures are logged and skipped.
func (wh *websocketHandler) closeConns(userID, event string) (int, error) {
	targets, err := wh.binder.FilterConn(userID, event)
	if err != nil {
		return 0, err
	}

	closed := 0
	for _, c := range targets {
		if err := wh.binder.Unbind(c); err != nil {
			log.Printf("conn unbind fail: %v", err)
		} else if err := c.Close(); err != nil {
			log.Printf("conn close fail: %v", err)
		} else {
			closed++
		}
	}
	fmt.Println("connClose")
	return closed, nil
}
// ErrRequestIllegal is written as the response body, with status 400, when
// a push request cannot be decoded or is missing UserID, Event or Message.
var ErrRequestIllegal = errors.New("request data illegal")

// pushHandler is the http.Handler for the push endpoint. It accepts push
// requests over HTTP and writes their message to the matching websocket
// connections.
type pushHandler struct {
// authFunc, when non-nil, authorizes each request. The request proceeds
// only if it returns true; otherwise it is answered with 401.
authFunc func(r *http.Request) bool
// binder is the registry shared with websocketHandler. It is used to look
// up the connections bound to a userID/event pair.
binder *binder
}
// ServeHTTP accepts a POST carrying a JSON PushMessage, optionally authorizes
// it with authFunc, and forwards Message to every connection bound to UserID
// under Event.
//
// It answers:
//   - 405 for non-POST methods
//   - 401 when authFunc rejects the request
//   - 400 (body: ErrRequestIllegal) for undecodable or incomplete bodies
//   - 500 when push returns an error
//   - otherwise a plain-text summary of how many clients got the message
func (s *pushHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		w.WriteHeader(http.StatusMethodNotAllowed)
		return
	}
	if s.authFunc != nil && !s.authFunc(r) {
		w.WriteHeader(http.StatusUnauthorized)
		return
	}

	var pm PushMessage
	decodeErr := json.NewDecoder(r.Body).Decode(&pm)
	if decodeErr != nil || pm.UserID == "" || pm.Event == "" || pm.Message == "" {
		w.WriteHeader(http.StatusBadRequest)
		w.Write([]byte(ErrRequestIllegal.Error()))
		return
	}

	cnt, err := s.push(pm.UserID, pm.Event, pm.Message)
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		w.Write([]byte(err.Error()))
		return
	}
	summary := fmt.Sprintf("message sent to %d clients", cnt)
	io.Copy(w, strings.NewReader(summary))
}
// push writes message to every connection bound to userID under event and
// returns how many writes succeeded. All three parameters must be non-empty.
// A connection whose write fails is unbound, so later pushes skip it.
func (s *pushHandler) push(userID, event, message string) (int, error) {
	if userID == "" || event == "" || message == "" {
		return 0, errors.New("parameters(userId, event, message) can't be empty")
	}

	// Filter connections by userID and event, then push the message.
	conns, err := s.binder.FilterConn(userID, event)
	if err != nil {
		return 0, fmt.Errorf("filter conn fail: %w", err)
	}

	// Convert once; Write only reads the buffer.
	payload := []byte(message)
	cnt := 0
	for _, conn := range conns {
		if _, err := conn.Write(payload); err != nil {
			// A failed write means the connection is unusable; drop its
			// binding so later pushes no longer target it.
			if uerr := s.binder.Unbind(conn); uerr != nil {
				log.Printf("conn unbind fail: %v", uerr)
			}
			continue
		}
		cnt++
	}
	return cnt, nil
}
// PushMessage is the JSON body of a push request. UserID and Event select the
// target connections; Message is sent to each of them as a text message.
// pushHandler rejects the request unless all three are non-empty.
// UserID is read from the "userId" key. Event and Message have no tags and
// match their field names case-insensitively.
type PushMessage struct {
UserID string `json:"userId"`
Event string
Message string
}
websocket/server/server
/**
* @Author: Hhx06
* @Description:
* @File: server
* @Version: 1.0.0
* @Date: 2020/10/29 14:54
*/
package server
import (
"errors"
"fmt"
"net/http"
"strings"
"github.com/gorilla/websocket"
)
// Default handler paths; NewServer assigns them to Server.WSPath and
// Server.PushPath.
const (
serverDefaultWSPath = "/ws"
serverDefaultPushPath = "/push"
)
// defaultUpgrader is used by ListenAndServe when Server.Upgrader is nil.
//
// NOTE(review): CheckOrigin accepts every origin, so any web page can open a
// websocket to this server on a visitor's behalf (cross-site websocket
// hijacking). Supply a Server.Upgrader with a real origin check if the
// endpoint is reachable from browsers.
var defaultUpgrader = &websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(*http.Request) bool {
return true
},
}
// Server defines parameters for running websocket server.
type Server struct {
// Addr is the TCP address to listen on, e.g. ":2345".
Addr string
// WSPath is the URL path that accepts websocket connections; NewServer
// sets it to "/ws".
WSPath string
// PushPath is the URL path that accepts push requests; NewServer sets it
// to "/push".
PushPath string
// Upgrader is for upgrade connection to websocket connection using
// "github.com/gorilla/websocket".
//
// If Upgrader is nil, default upgrader will be used. Default upgrader is
// set ReadBufferSize and WriteBufferSize to 1024, and CheckOrigin always
// returns true.
Upgrader *websocket.Upgrader
// AuthToken checks the token a client registers with and returns the
// userID it belongs to. For a valid token it must return the userID and
// ok == true; otherwise ok must be false and the registration is ignored.
// When nil, the token itself is used as the userID.
AuthToken func(token string) (userID string, ok bool)
// Authorize push request. Message will be sent if it returns true,
// otherwise the request will be discarded. Default nil and push request
// will always be accepted.
PushAuth func(r *http.Request) bool
// wh and ph are the handlers built by ListenAndServe; Drop and Push
// delegate to them, so both are nil until ListenAndServe has run.
wh *websocketHandler
ph *pushHandler
}
// ListenAndServe validates the configuration, registers the websocket and
// push handlers on http.DefaultServeMux, and serves HTTP on s.Addr.
//
// An empty WSPath or PushPath falls back to "/ws" or "/push". It returns an
// error if a path does not start with "/" or the two paths are equal;
// otherwise it returns whatever http.ListenAndServe returns.
//
// Because the handlers go on the process-wide DefaultServeMux, calling it
// twice with the same paths panics (duplicate registration).
func (s *Server) ListenAndServe() error {
	wsPath, pushPath := s.WSPath, s.PushPath
	if wsPath == "" {
		wsPath = serverDefaultWSPath
	}
	if pushPath == "" {
		pushPath = serverDefaultPushPath
	}
	// Validate the effective paths. Without this, an empty pattern would make
	// http.Handle panic, and a path like "ws" would silently never match.
	cfg := Server{WSPath: wsPath, PushPath: pushPath}
	if err := cfg.check(); err != nil {
		return err
	}

	b := &binder{
		userID2EventConnMap: make(map[string]*[]eventConn),
		connID2UserIDMap:    make(map[string]string),
	}

	// websocket request handler
	wh := &websocketHandler{
		upgrader:       defaultUpgrader,
		binder:         b,
		calcUserIDFunc: s.AuthToken,
	}
	if s.Upgrader != nil {
		wh.upgrader = s.Upgrader
	}
	s.wh = wh
	http.Handle(wsPath, wh)

	// push request handler
	ph := &pushHandler{
		authFunc: s.PushAuth,
		binder:   b,
	}
	s.ph = ph
	http.Handle(pushPath, ph)

	return http.ListenAndServe(s.Addr, nil)
}
// Push writes message to every connection bound to userID under event and
// returns how many writes succeeded. It returns an error if the server has
// not been set up by ListenAndServe yet, or if any argument is empty.
//
// NOTE(review): s.ph is assigned by ListenAndServe without synchronization;
// calling Push from another goroutine while ListenAndServe is starting is
// still racy.
func (s *Server) Push(userID, event, message string) (int, error) {
	// Without this guard, Push before ListenAndServe dereferences a nil handler.
	if s.ph == nil {
		return 0, errors.New("server not started: call ListenAndServe first")
	}
	return s.ph.push(userID, event, message)
}
// Drop finds connections by userID and event, unbinds them, and closes them.
// It returns how many were closed. The userID can't be empty; an empty event
// matches every connection of the user. It returns an error if the server
// has not been set up by ListenAndServe yet.
//
// NOTE(review): s.wh is assigned by ListenAndServe without synchronization;
// calling Drop from another goroutine while ListenAndServe is starting is
// still racy.
func (s *Server) Drop(userID, event string) (int, error) {
	// Without this guard, Drop before ListenAndServe dereferences a nil handler.
	if s.wh == nil {
		return 0, errors.New("server not started: call ListenAndServe first")
	}
	return s.wh.closeConns(userID, event)
}
// check validates the Server's paths. WSPath and PushPath must each be empty
// or start with "/", and they must differ. It returns an error describing the
// first violation, or nil.
func (s Server) check() error {
	// The previous messages said "not illegal", the opposite of what they meant.
	if !checkPath(s.WSPath) {
		return fmt.Errorf("WSPath: %q is illegal, it must start with \"/\"", s.WSPath)
	}
	if !checkPath(s.PushPath) {
		return fmt.Errorf("PushPath: %q is illegal, it must start with \"/\"", s.PushPath)
	}
	if s.WSPath == s.PushPath {
		return errors.New("WSPath is equal to PushPath")
	}
	return nil
}
// NewServer returns a Server that listens on addr, with WSPath and PushPath
// preset to "/ws" and "/push". The remaining fields stay at their zero values
// and can be set before calling ListenAndServe.
func NewServer(addr string) *Server {
	s := new(Server)
	s.Addr = addr
	s.WSPath = serverDefaultWSPath
	s.PushPath = serverDefaultPushPath
	return s
}
// checkPath reports whether path is acceptable as a handler path: it must be
// either empty or begin with "/".
func checkPath(path string) bool {
	return path == "" || strings.HasPrefix(path, "/")
}
websocket/main.go
/**
* @Author: Hhx06
* @Description:
* @File: main
* @Version: 1.0.0
* @Date: 2020/10/29 14:57
*/
package main
import (
"fmt"
"net/http"
server2 "websocket/server"
)
func main() {
	srv := server2.NewServer(":2345")

	// Websocket connect URL (default "/ws") and push URL (default "/push").
	srv.WSPath = "/ws"
	srv.PushPath = "/push"

	// AuthToken authorizes a websocket registration using the token the
	// client sends. Every token is currently accepted and used directly as
	// the userID, e.g. map a known token to its user here instead.
	// NOTE(review): this performs no authentication at all.
	srv.AuthToken = func(token string) (string, bool) {
		fmt.Println("token")
		return token, true
	}

	// PushAuth decides whether a push request may proceed; a false result
	// makes the server ignore the request.
	srv.PushAuth = func(*http.Request) bool {
		// TODO: check if request is valid
		return true
	}

	fmt.Println("websocket")
	// Run the server; ListenAndServe only returns on failure.
	if err := srv.ListenAndServe(); err != nil {
		panic(err)
	}
}
- push
push/handlers/push
/**
* @Author: Hhx06
* @Description:
* @File: Push
* @Version: 1.0.0
* @Date: 2020/10/29 15:02
*/
package Handlers
import (
"bytes"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"net/http"
"ws/Websocket"
)
// Push asks the websocket push server to send an "onHhx" success notification
// to the user named by the "uid" query parameter.
//
// It replies:
//   - 303 with code 203 when uid is missing or empty
//   - 500 when the push request cannot be built
//   - 502 when the push request fails or the push server does not answer 200
//   - 200 with code 200 otherwise
func Push(c *gin.Context) {
	uid, ok := c.GetQuery("uid")
	// The push endpoint rejects an empty userId, so treat "" as missing
	// rather than reporting a success that never happened.
	if !ok || uid == "" {
		c.JSON(http.StatusSeeOther, gin.H{
			"code": 203,
			"msg":  "字段缺少",
		})
		return
	}

	// NOTE(review): pushURL is not declared in this function (the local
	// definition "http://127.0.0.1:2345/push" is commented out); it presumably
	// comes from elsewhere in package Handlers — confirm.
	const contentType = "application/json"
	pm := Websocket.PushMessage{
		UserID:  uid,
		Event:   "onHhx",
		Message: "{\"type\": \"onHhx\",\"message\": \"success\"}",
	}
	body, err := json.Marshal(pm)
	if err != nil {
		fmt.Println("marshal push message fail:", err)
		c.JSON(http.StatusInternalServerError, gin.H{"code": 500, "msg": "push failed"})
		return
	}

	// Tie the outgoing call to the incoming request, so it is abandoned when
	// the caller goes away instead of hanging: http.DefaultClient has no timeout.
	req, err := http.NewRequestWithContext(c.Request.Context(), http.MethodPost, pushURL, bytes.NewReader(body))
	if err != nil {
		fmt.Println("build push request fail:", err)
		c.JSON(http.StatusInternalServerError, gin.H{"code": 500, "msg": "push failed"})
		return
	}
	req.Header.Set("Content-Type", contentType)

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		fmt.Println("push request fail:", err)
		c.JSON(http.StatusBadGateway, gin.H{"code": 502, "msg": "push failed"})
		return
	}
	// The body must be closed, or the underlying connection is leaked.
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		fmt.Println("push server responded:", resp.Status)
		c.JSON(http.StatusBadGateway, gin.H{"code": 502, "msg": "push failed"})
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"code": 200,
		"msg":  "success",
	})
	fmt.Println("success")
}
push/Middleware/CORS
package Middlewares
import (
"fmt"
"github.com/gin-gonic/gin"
)
// CORSMiddleware returns a gin middleware that adds permissive CORS headers
// (any origin) to every response. It answers preflight OPTIONS requests
// itself with status 200 instead of passing them down the handler chain.
//
// NOTE(review): "UPDATE" in the allowed-methods list is not a standard HTTP
// method; it is harmless but probably unintended.
func CORSMiddleware() gin.HandlerFunc {
	corsHeaders := [...][2]string{
		{"Access-Control-Allow-Origin", "*"},
		{"Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, UPDATE"},
		{"Access-Control-Allow-Headers", "Origin, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization"},
		{"Access-Control-Expose-Headers", "Content-Length"},
	}

	return func(c *gin.Context) {
		header := c.Writer.Header()
		for _, kv := range corsHeaders {
			header.Set(kv[0], kv[1])
		}

		if c.Request.Method != "OPTIONS" {
			c.Next()
			return
		}
		fmt.Println("OPTIONS")
		c.AbortWithStatus(200)
	}
}
push/main
/**
* @Author: Hhx06
* @Description:
* @File: main
* @Version: 1.0.0
* @Date: 2020/10/29 14:36
*/
package main
import (
"github.com/gin-gonic/gin"
. "ws/Handlers"
Middlewares "ws/Middleware"
)
// main starts the push API on :9009, with CORS enabled and
// GET /v1/push routed to Push.
func main() {
	router := gin.Default()
	//gin.SetMode(gin.DebugMode)
	router.Use(Middlewares.CORSMiddleware())

	v1 := router.Group("/v1")
	{
		v1.GET("/push", Push)
	}

	// Run blocks while serving and only returns on failure (e.g. the port is
	// already in use). Ignoring that error made the process exit silently
	// with status 0.
	if err := router.Run(":9009"); err != nil {
		panic(err)
	}
}
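改编自 https://github.com/alfred-zhong/wserver（Adapted from alfred-zhong/wserver.）
在原项目基础上做了部分修改，用于在小程序支付异步回调后通过 WebSocket 向客户端推送消息。（Modified to push a WebSocket message to the client after the mini-program payment's asynchronous callback.）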