【go语言之websocket】

写在前面

之前的文章都是介绍的是http的使用,这里主要介绍的是websocket,主要是解决长连接场景下的使用。这里概念不多说,网上很多,我们接下来看一下抓包的表现,已经用go语言如何去实现一个server端还有client

服务端

首先这里使用的websocket的包是第三方开源的。地址


func main() {
	ws.InitWsServer()
	select {}
}

这里是初始化websocket服务,然后使用select进行阻塞。然后看一下InitWsServer这个方法

func InitWsServer() {
    // 初始化msg
	msg.NewMsgProtocol(true)
	msg.GetMsgProtocol().Register(&pb.Ping{}, 1)
	msg.GetMsgProtocol().Register(&pb.Pong{}, 2)
	mux := http.NewServeMux()
	mux.HandleFunc("/connect", getConn)
	// HTTP服务
	server := http.Server{
		Addr:         "0.0.0.0:8888",
		ReadTimeout:  time.Duration(10) * time.Second,
		WriteTimeout: time.Duration(10) * time.Second,
		Handler:      mux,
	}
	fmt.Println("启动WS服务器成功 :", 8888)
	_ = server.ListenAndServe()
}

接下来先看看msg的

msg

msg主要做的是序列化和反序列。

var mgPrt *MsgProtocol

//协议
// id + proto.Message
type MsgProtocol struct {
	msgID        map[reflect.Type]uint16
	msgInfo      map[uint16]reflect.Type
	// 是否使用大端序
	useBigEndian bool
}

func NewMsgProtocol(useBigEndian bool) {
    // 初始全局化
	mgPrt = &MsgProtocol{
		msgID:        make(map[reflect.Type]uint16),
		msgInfo:      make(map[uint16]reflect.Type),
		useBigEndian: useBigEndian,
	}
}

func GetMsgProtocol() *MsgProtocol {
	if mgPrt == nil {
		panic("msg prt nil")
	}
	// 返回
	return mgPrt
}

func (m *MsgProtocol) Register(msg proto.Message, eventType uint16) {
    // 获取类型
	msgType := reflect.TypeOf(msg)
	if msgType == nil || msgType.Kind() != reflect.Ptr {
		panic(ErrMsgNotProto)
	}
	if len(m.msgInfo) >= math.MaxUint16 {
		panic(ErrProtocol)
	}
	m.msgInfo[eventType] = msgType
	m.msgID[msgType] = eventType
}

func (m *MsgProtocol) Marshal(msg interface{}) ([]byte, error) {
    // 判断是否存在
	msgType := reflect.TypeOf(msg)
	event, ok := m.msgID[msgType]
	if !ok {
		return nil, ErrNotRegister
	}
	// 使用proto进行序列化
	data, err := proto.Marshal(msg.(proto.Message))
	if err != nil {
		return nil, err
	}
	var (
		id      = make([]byte, 2)
		ptrData = make([]byte, 2+len(data))
	)
	// 判断使用大端序还是小端序
	if m.useBigEndian {
		binary.BigEndian.PutUint16(id, event)
	} else {
		binary.LittleEndian.PutUint16(id, event)
	}
	copy(ptrData[:2], id)
	copy(ptrData[2:], data)
	return ptrData, nil
}

func (m *MsgProtocol) Unmarshal(msg []byte) (interface{}, error) {
	if len(msg) < 2 {
		return nil, ErrMsgShort
	}
	// 先获取到id
	var id uint16
	if m.useBigEndian {
		id = binary.BigEndian.Uint16(msg[:2])
	} else {
		id = binary.LittleEndian.Uint16(msg[:2])
	}
	// 判断是使用什么类型 什么数据类型
	msgType, ok := m.msgInfo[id]
	if !ok {
		return nil, ErrNotRegister
	}
	var data = reflect.New(msgType.Elem()).Interface()

    // 进行反序列化
	err := proto.Unmarshal(msg[2:], data.(proto.Message))
	if err != nil {
		return nil, err
	}
	return data, nil
}

然后看一下 监听的getConn这个方法。

getConn

这个是在监听的时候,触发的方法。在这里会从http升级到websocket。

func getConn(res http.ResponseWriter, req *http.Request) {
	var (
		err    error
		wsConn *websocket.Conn
	)
	// 判断是否升级到websocket
	if wsConn, err = wsUpgrader.Upgrade(res, req, nil); err != nil {
		return
	}
	// 处理websocket的逻辑 初始化变量

 
    // 初始化websocket连接 当然这个是封装过的
	ws := NewWsConnection(wsConn)
	ws.SetIp(ClientIP(req))
	ws.SetUid(uint32(time.Now().Unix()))
	// 处理websocket逻辑
	wsHandle(ws)
}

Upgrade

这个是对连接进行升级,判断是否符合连接升级的需要

func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) {
	const badHandshake = "websocket: the client is not using the websocket protocol: "
    // 判断请求的header头中是否有Connection
	if !tokenListContainsValue(r.Header, "Connection", "upgrade") {
		return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'upgrade' token not found in 'Connection' header")
	}
    // 判断请求头中是否有Upgrade
	if !tokenListContainsValue(r.Header, "Upgrade", "websocket") {
		return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'websocket' token not found in 'Upgrade' header")
	}
    // 判断是否是GET请求
	if r.Method != http.MethodGet {
		return u.returnError(w, r, http.StatusMethodNotAllowed, badHandshake+"request method is not GET")
	}

	if !tokenListContainsValue(r.Header, "Sec-Websocket-Version", "13") {
		return u.returnError(w, r, http.StatusBadRequest, "websocket: unsupported version: 13 not found in 'Sec-Websocket-Version' header")
	}

	if _, ok := responseHeader["Sec-Websocket-Extensions"]; ok {
		return u.returnError(w, r, http.StatusInternalServerError, "websocket: application specific 'Sec-WebSocket-Extensions' headers are unsupported")
	}
  
    // 判断跨域
	checkOrigin := u.CheckOrigin
	if checkOrigin == nil {
		checkOrigin = checkSameOrigin
	}
	if !checkOrigin(r) {
		return u.returnError(w, r, http.StatusForbidden, "websocket: request origin not allowed by Upgrader.CheckOrigin")
	}
    // 判断是否有Sec-Websocket-Key
	challengeKey := r.Header.Get("Sec-Websocket-Key")
	if !isValidChallengeKey(challengeKey) {
		return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'Sec-WebSocket-Key' header must be Base64 encoded value of 16-byte in length")
	}
    // 判断是否更换协议
	subprotocol := u.selectSubprotocol(r, responseHeader)

	// Negotiate PMCE
	var compress bool
	if u.EnableCompression {
		for _, ext := range parseExtensions(r.Header) {
			if ext[""] != "permessage-deflate" {
				continue
			}
			compress = true
			break
		}
	}
    // 判断是否实现 http.Hijacker
	h, ok := w.(http.Hijacker)
	if !ok {
		return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker")
	}
	// 获取原始的连接
	var brw *bufio.ReadWriter
	netConn, brw, err := h.Hijack()
	if err != nil {
		return u.returnError(w, r, http.StatusInternalServerError, err.Error())
	}
    // 判断
	if brw.Reader.Buffered() > 0 {
		netConn.Close()
		return nil, errors.New("websocket: client sent data before handshake is complete")
	}

	var br *bufio.Reader
	if u.ReadBufferSize == 0 && bufioReaderSize(netConn, brw.Reader) > 256 {
		// Reuse hijacked buffered reader as connection reader.
		br = brw.Reader
	}

	buf := bufioWriterBuffer(netConn, brw.Writer)

	var writeBuf []byte
	if u.WriteBufferPool == nil && u.WriteBufferSize == 0 && len(buf) >= maxFrameHeaderSize+256 {
		// Reuse hijacked write buffer as connection buffer.
		writeBuf = buf
	}
    // 实例化连接
	c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize, u.WriteBufferPool, br, writeBuf)
	c.subprotocol = subprotocol

	if compress {
		c.newCompressionWriter = compressNoContextTakeover
		c.newDecompressionReader = decompressNoContextTakeover
	}

	// Use larger of hijacked buffer and connection write buffer for header.
	p := buf
	if len(c.writeBuf) > len(p) {
		p = c.writeBuf
	}
	p = p[:0]
    // 写回响应
	p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...)
	p = append(p, computeAcceptKey(challengeKey)...)
	p = append(p, "\r\n"...)
	if c.subprotocol != "" {
		p = append(p, "Sec-WebSocket-Protocol: "...)
		p = append(p, c.subprotocol...)
		p = append(p, "\r\n"...)
	}
	if compress {
		p = append(p, "Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...)
	}
	for k, vs := range responseHeader {
		if k == "Sec-Websocket-Protocol" {
			continue
		}
		for _, v := range vs {
			p = append(p, k...)
			p = append(p, ": "...)
			for i := 0; i < len(v); i++ {
				b := v[i]
				if b <= 31 {
					// prevent response splitting.
					b = ' '
				}
				p = append(p, b)
			}
			p = append(p, "\r\n"...)
		}
	}
	p = append(p, "\r\n"...)

	// 重置超时时间
	netConn.SetDeadline(time.Time{})

	if u.HandshakeTimeout > 0 {
		netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout))
	}
	// 写回响应
	if _, err = netConn.Write(p); err != nil {
		netConn.Close()
		return nil, err
	}
	// 重新清空写超时时间
	if u.HandshakeTimeout > 0 {
		netConn.SetWriteDeadline(time.Time{})
	}

	return c, nil
}

这里的主要的逻辑就是通过Hijack拿到真是的连接,然后想websocket的握手的协议写回去。
然后这个Hijack看一下官方的实现.这个Hijack具体的实现是

// Hijack implements the Hijacker.Hijack method. Our response is both a ResponseWriter
// and a Hijacker.
func (w *response) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) {
	if w.handlerDone.isSet() {
		panic("net/http: Hijack called after ServeHTTP finished")
	}
	if w.wroteHeader {
		w.cw.flush()
	}

	c := w.conn
	c.mu.Lock()
	defer c.mu.Unlock()

	// Release the bufioWriter that writes to the chunk writer, it is not
	// used after a connection has been hijacked.
	rwc, buf, err = c.hijackLocked()
	if err == nil {
		putBufioWriter(w.w)
		w.w = nil
	}
	return rwc, buf, err
}
// c.mu must be held.
func (c *conn) hijackLocked() (rwc net.Conn, buf *bufio.ReadWriter, err error) {
	if c.hijackedv {
		return nil, nil, ErrHijacked
	}
	c.r.abortPendingRead()

	c.hijackedv = true
	// 底层的连接
	rwc = c.rwc
	rwc.SetDeadline(time.Time{})
    // 实例化reader和write
	buf = bufio.NewReadWriter(c.bufr, bufio.NewWriter(rwc))
	if c.r.hasByte {
		if _, err := c.bufr.Peek(c.bufr.Buffered() + 1); err != nil {
			return nil, nil, fmt.Errorf("unexpected Peek failure reading buffered byte: %v", err)
		}
	}
	c.setState(rwc, StateHijacked, runHooks)
	return
}

可以看出来Hijack返回的是底层的tcp连接,已经对应的writeBuffer和readBuffer。所以当需要自定义协议的时候,这个都是合适的方法。

 

NewWsConnection

这个是封装websocket连接。

func NewWsConnection(conn *websocket.Conn) *WsConnection {
	ws := &WsConnection{}
	ws.ws = conn
	// 读写chan
	ws.readChan = make(chan interface{}, 10)
	ws.writeChan = make(chan interface{}, 10)
	ws.closeChan = make(chan bool)
	ws.isOpen = true
	// 生成连接id
	ws.connId = uuid.NewV5(uuid.Must(uuid.NewV4(), nil), "ws").String()
	// 分别开启一个协程进行读和写
	go ws.read()
	go ws.send()
	return ws
}

然后看一下send方法

func (w *WsConnection) send() {
	var (
		message interface{}
	)
	for {
		select {
		case message = <-w.writeChan:
		    // 进行序列化
			data, err := msg.GetMsgProtocol().Marshal(message)
			if err != nil {
				w.close()
				return
			}
			// 实例化write
			writer, err := w.ws.NextWriter(websocket.BinaryMessage)
			if err != nil {
				w.close()
				return
			}
			// 写入数据
			_, _ = writer.Write(data)
			// 关闭write
			_ = writer.Close()
		case <-w.closeChan:
			return
		}
	}
}
NextWriter

NextWriter 还是调用beginMessage进行实例化

func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
	var mw messageWriter
	if err := c.beginMessage(&mw, messageType); err != nil {
		return nil, err
	}
	c.writer = &mw
	// 这里可以忽略
	if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) {
		w := c.newCompressionWriter(c.writer, c.compressionLevel)
		mw.compress = true
		c.writer = w
	}
	return c.writer, nil
}
// beginMessage prepares a connection and message writer for a new message.
func (c *Conn) beginMessage(mw *messageWriter, messageType int) error {
	// Close previous writer if not already closed by the application. It's
	// probably better to return an error in this situation, but we cannot
	// change this without breaking existing applications.
	if c.writer != nil {
		c.writer.Close()
		c.writer = nil
	}

	if !isControl(messageType) && !isData(messageType) {
		return errBadWriteOpCode
	}

	c.writeErrMu.Lock()
	err := c.writeErr
	c.writeErrMu.Unlock()
	if err != nil {
		return err
	}

	mw.c = c
	// 设置frameType
	mw.frameType = messageType
	// maxFrameHeaderSize         = 2 + 8 + 4 // Fixed header + length + mask
	mw.pos = maxFrameHeaderSize

	if c.writeBuf == nil {
		wpd, ok := c.writePool.Get().(writePoolData)
		if ok {
			c.writeBuf = wpd.buf
		} else {
			c.writeBuf = make([]byte, c.writeBufSize)
		}
	}
	return nil
}

可以看出来这个方法主要是设置frameType 这里的是BinaryMessage。然后需要注意的是这里的pos并不是0,而是maxFrameHeaderSize = 2 + 8 + 4 // Fixed header + length + mask。
也就是2个字节的固定头部,8个字节的报文长度,加上4个字节的掩码。
然后就是初始化writeBuf,从前面的可以看出这个初始化长度是defaultReadBufferSize = 4096字节。

write

然后就是write这个方法。其实这个也没有什么好理解的就是把p这个里面的数据拷贝到writeBuf中,去并且移动pos。可以看出来write只是把数据储存到了writeBuf。

func (w *messageWriter) Write(p []byte) (int, error) {
	if w.err != nil {
		return 0, w.err
	}
  
	if len(p) > 2*len(w.c.writeBuf) && w.c.isServer {
		// Don't buffer large messages.
		err := w.flushFrame(false, p)
		if err != nil {
			return 0, err
		}
		return len(p), nil
	}
    // 对移动p中的数据到writeBuf
	nn := len(p)
	for len(p) > 0 {
		n, err := w.ncopy(len(p))
		if err != nil {
			return 0, err
		}
		copy(w.c.writeBuf[w.pos:], p[:n])
		w.pos += n
		p = p[n:]
	}
	return nn, nil
}
func (w *messageWriter) Close() error {
	if w.err != nil {
		return w.err
	}
	return w.flushFrame(true, nil)
}
// flushFrame writes buffered data and extra as a frame to the network. The
// final argument indicates that this is the last frame in the message.
func (w *messageWriter) flushFrame(final bool, extra []byte) error {
	c := w.c
	// 消息的长度 减掉maxFrameHeaderSize
	length := w.pos - maxFrameHeaderSize + len(extra)

	// Check for invalid control frames.
	if isControl(w.frameType) &&
		(!final || length > maxControlFramePayloadSize) {
		return w.endMessage(errInvalidControlFrame)
	}
  
	b0 := byte(w.frameType)
	if final {
		b0 |= finalBit
	}
	if w.compress {
		b0 |= rsv1Bit
	}
	w.compress = false
    // 如果是客户端发送那么需要加上掩码
	b1 := byte(0)
	if !c.isServer {
		b1 |= maskBit
	}

	// Assume that the frame starts at beginning of c.writeBuf.
	// 如果是服务端设置为4 客户端设置为0
	framePos := 0
	if c.isServer {
		// Adjust up if mask not included in the header.
		framePos = 4
	}
    // 判断长度 根据长度不同写入不同的数据
	switch {
	case length >= 65536:
		c.writeBuf[framePos] = b0
		c.writeBuf[framePos+1] = b1 | 127
		binary.BigEndian.PutUint64(c.writeBuf[framePos+2:], uint64(length))
	case length > 125:
		framePos += 6
		c.writeBuf[framePos] = b0
		c.writeBuf[framePos+1] = b1 | 126
		binary.BigEndian.PutUint16(c.writeBuf[framePos+2:], uint16(length))
	default:
		framePos += 8
		c.writeBuf[framePos] = b0
		c.writeBuf[framePos+1] = b1 | byte(length)
	}
    // 如果是客户端 那么生成掩码的key 并且写入writeBuf中
	if !c.isServer {
		key := newMaskKey()
		// 因为客户端进行掩码,所以
		copy(c.writeBuf[maxFrameHeaderSize-4:], key[:])
		// 对数据进行掩码操作
		maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:w.pos])
		if len(extra) > 0 {
			return w.endMessage(c.writeFatal(errors.New("websocket: internal error, extra used in client mode")))
		}
	}

	// Write the buffers to the connection with best-effort detection of
	// concurrent writes. See the concurrency section in the package
	// documentation for more info.

	if c.isWriting {
		panic("concurrent write to websocket connection")
	}
	c.isWriting = true
    // 写入底层连接
	err := c.write(w.frameType, c.writeDeadline, c.writeBuf[framePos:w.pos], extra)

	if !c.isWriting {
		panic("concurrent write to websocket connection")
	}
	c.isWriting = false

	if err != nil {
		return w.endMessage(err)
	}
    // 进行结束 写入writePool
	if final {
		w.endMessage(errWriteClosed)
		return nil
	}

	// Setup for next frame.
	w.pos = maxFrameHeaderSize
	w.frameType = continuationFrame
	return nil
}

总结一下write方法,这个可以看出来是服务端和客户端,都会进行调用的方法,然后需要注意的是客户端是有一个掩码的操作,占用了4个字节,然后发送的报文最大是支持8个字节,头部是2个字节。

read
func (w *WsConnection) read() {
	var (
		Data []byte
		err  error
	)
	// 设置读取的最大的字节数量
	w.ws.SetReadLimit(1024)
	// 设置读取的超时时间
	_ = w.ws.SetReadDeadline(time.Now().Add(time.Second * 20))
	for {
	    // 进行读取信息
		if _, Data, err = w.ws.ReadMessage(); err != nil {
			w.close()
			return
		}
		var message interface{}
		if message, err = msg.GetMsgProtocol().Unmarshal(Data); err != nil {
			w.close()
			return
		}
		select {
		case w.readChan <- message:
		case <-w.closeChan:
			return
		}
	}
}

可以看出来是调用ReadMessage。

// ReadMessage is a helper method for getting a reader using NextReader and
// reading from that reader to a buffer.
func (c *Conn) ReadMessage() (messageType int, p []byte, err error) {
	var r io.Reader
	messageType, r, err = c.NextReader()
	if err != nil {
		return messageType, nil, err
	}
	p, err = ioutil.ReadAll(r)
	return messageType, p, err
}
func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
	// Close previous reader, only relevant for decompression.
	if c.reader != nil {
		c.reader.Close()
		c.reader = nil
	}

	c.messageReader = nil
	c.readLength = 0

	for c.readErr == nil {
	    // 读取并且判断类型 这里我们传的类型是BinaryMessage
		frameType, err := c.advanceFrame()
		if err != nil {
			c.readErr = hideTempErr(err)
			break
		}

		if frameType == TextMessage || frameType == BinaryMessage {
			c.messageReader = &messageReader{c}
			c.reader = c.messageReader
			if c.readDecompress {
				c.reader = c.newDecompressionReader(c.reader)
			}
			return frameType, c.reader, nil
		}
	}

	// Applications that do handle the error returned from this method spin in
	// tight loop on connection failure. To help application developers detect
	// this error, panic on repeated reads to the failed connection.
	c.readErrCount++
	if c.readErrCount >= 1000 {
		panic("repeated read on failed websocket connection")
	}

	return noFrame, nil, c.readErr
}

可以看出来这里主要的方法就是advanceFrame。

func (c *Conn) advanceFrame() (int, error) {
	// 1. Skip remainder of previous frame.
    // 这里可以忽略 
	if c.readRemaining > 0 {
		if _, err := io.CopyN(ioutil.Discard, c.br, c.readRemaining); err != nil {
			return noFrame, err
		}
	}

	// 2. Read and parse first two bytes of frame header.
	// To aid debugging, collect and report all errors in the first two bytes
	// of the header.

	var errors []string
    // 读取frameType 这里final为true 然后对于客户端mask 是为true
	p, err := c.read(2)
	if err != nil {
		return noFrame, err
	}
    // 判断frameType 注意发送的时候会根据长度不同设置不同的remaining
	frameType := int(p[0] & 0xf)
	final := p[0]&finalBit != 0
	rsv1 := p[0]&rsv1Bit != 0
	rsv2 := p[0]&rsv2Bit != 0
	rsv3 := p[0]&rsv3Bit != 0
	mask := p[1]&maskBit != 0
	c.setReadRemaining(int64(p[1] & 0x7f))

	c.readDecompress = false
	if rsv1 {
		if c.newDecompressionReader != nil {
			c.readDecompress = true
		} else {
			errors = append(errors, "RSV1 set")
		}
	}

	if rsv2 {
		errors = append(errors, "RSV2 set")
	}

	if rsv3 {
		errors = append(errors, "RSV3 set")
	}
    // 判断frameType  这里我们是BinaryMessage
	switch frameType {
	case CloseMessage, PingMessage, PongMessage:
		if c.readRemaining > maxControlFramePayloadSize {
			errors = append(errors, "len > 125 for control")
		}
		if !final {
			errors = append(errors, "FIN not set on control")
		}
	case TextMessage, BinaryMessage:
		if !c.readFinal {
			errors = append(errors, "data before FIN")
		}
		c.readFinal = final
	case continuationFrame:
		if c.readFinal {
			errors = append(errors, "continuation after FIN")
		}
		c.readFinal = final
	default:
		errors = append(errors, "bad opcode "+strconv.Itoa(frameType))
	}

	if mask != c.isServer {
		errors = append(errors, "bad MASK")
	}

	if len(errors) > 0 {
		return noFrame, c.handleProtocolError(strings.Join(errors, ", "))
	}

	// 3. Read and parse frame length as per
	// https://tools.ietf.org/html/rfc6455#section-5.2
	//
	// The length of the "Payload data", in bytes: if 0-125, that is the payload
	// length.
	// - If 126, the following 2 bytes interpreted as a 16-bit unsigned
	// integer are the payload length.
	// - If 127, the following 8 bytes interpreted as
	// a 64-bit unsigned integer (the most significant bit MUST be 0) are the
	// payload length. Multibyte length quantities are expressed in network byte
	// order.
    // 127是长度超过65536 126是超过126
	switch c.readRemaining {
	case 126:
		p, err := c.read(2)
		if err != nil {
			return noFrame, err
		}
       
		if err := c.setReadRemaining(int64(binary.BigEndian.Uint16(p))); err != nil {
			return noFrame, err
		}
	case 127:
		p, err := c.read(8)
		if err != nil {
			return noFrame, err
		}
        
		if err := c.setReadRemaining(int64(binary.BigEndian.Uint64(p))); err != nil {
			return noFrame, err
		}
	}

	// 4. Handle frame masking.

	if mask {
		c.readMaskPos = 0
		p, err := c.read(len(c.readMaskKey))
		if err != nil {
			return noFrame, err
		}
		copy(c.readMaskKey[:], p)
	}

	// 5. For text and binary messages, enforce read limit and return.

	if frameType == continuationFrame || frameType == TextMessage || frameType == BinaryMessage {

		c.readLength += c.readRemaining
		// Don't allow readLength to overflow in the presence of a large readRemaining
		// counter.
		if c.readLength < 0 {
			return noFrame, ErrReadLimit
		}

		if c.readLimit > 0 && c.readLength > c.readLimit {
			c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait))
			return noFrame, ErrReadLimit
		}
        // 这里返回了
		return frameType, nil
	}

	// 6. Read control frame payload.
	// 省略

	// 7. Process control frame payload.
    // 省略

	return frameType, nil
}

可以看出来这里主要就是就是把需要读取的长度读取出来,然后通过setReadRemaining设置进去。因为连接以后会默认有ping和pong的请求,当然这里不是我们考虑的重点。

然后看一下read方法。

func (r *messageReader) Read(b []byte) (int, error) {
	c := r.c
	if c.messageReader != r {
		return 0, io.EOF
	}

	for c.readErr == nil {
	     // 判断是否有读取的数据
		if c.readRemaining > 0 {
			if int64(len(b)) > c.readRemaining {
				b = b[:c.readRemaining]
			}
			// 进行读取
			n, err := c.br.Read(b)
			c.readErr = hideTempErr(err)
			// 因为客户端会对数据进行掩码操作 因此这里解析反解析
			if c.isServer {
				c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n])
			}
			rem := c.readRemaining
			rem -= int64(n)
			// 设置还需要读取到的数据
			c.setReadRemaining(rem)
			if c.readRemaining > 0 && c.readErr == io.EOF {
				c.readErr = errUnexpectedEOF
			}
			return n, c.readErr
		}

		if c.readFinal {
			c.messageReader = nil
			return 0, io.EOF
		}
        // 继续读取剩余的数据
		frameType, err := c.advanceFrame()
		switch {
		case err != nil:
			c.readErr = hideTempErr(err)
		case frameType == TextMessage || frameType == BinaryMessage:
			c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader")
		}
	}

	err := c.readErr
	if err == io.EOF && c.messageReader == r {
		err = errUnexpectedEOF
	}
	return 0, err
}

这里可以看出来读取也总体不复杂,因为websocket头部很少,需要注意的就是mask的操作。

客户端

因为websocket是双管工的,所以逻辑和服务端是类似的,所以这里就不赘述了。

抓包表现

先看一下抓包的表现,然后分析一下websocket的行为。
首先开启了一个监听8888端口的服务,然后开启一个客户端进行请求。

在这里插入图片描述
这个是第一部分,可以看出来也是三次握手的行为。
握手结束之后。客户端发起请求。
在这里插入图片描述

这里GET请求是通过http请求。

然后就是服务端返回了http code为101.说明是进行协议转换
在这里插入图片描述

接下来就是内容了,是由客户端进行发起。然后内容是五个字节,websocket的头部由websocket标识出来,
在这里插入图片描述

在这里插入图片描述
然后就是tcp层次的ACK。回复收到自己已经收到消息
在这里插入图片描述

总结

websocket是通过客户端发起协议转换,服务返回http code为101后进行协议转换。对于服务端而言头部就是2字节的头部和8字节的数据长度。客户度则是另外加上4字节的掩码,对数据进行转换,然后也是建立在了tcp的基础之上。

  • 2
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值