golang代理websocket请求

基础:

golang websocket github地址:

GitHub - gorilla/websocket: A fast, well-tested and widely used WebSocket implementation for Go.

github库官方文档:

websocket package - github.com/gorilla/websocket - Go Packages

golang websocket运行机制以及原理:

https://www.jianshu.com/p/65ef71ddb910

golang websocket在线测试地址:

websocket在线测试

golang websocket官方客户端以及服务端demo:

websocket/examples/echo at master · gorilla/websocket · GitHub

golang websocket http代理实现:

主要在http代理的基础上(见上一篇博客http代理:golang实现http(s)代理_GT19930910的博客-CSDN博客_golang http 代理),检查是否有http头中带有Upgrade字段的,如果有则说明是websocket请求需要作相应的协议升级,具体细节在代码中呈现遇到的坑也在代码中有详细说明。此处代理的https请求(http请求不会自动302跳转)以及wss请求(wss比ws更加复杂一些需要跳过证书)。

package main
import (
    //"fmt"
    "net/http"
    "io/ioutil"
    "strings"
    "io"
    "crypto/tls"
    "net/url"
    "github.com/gorilla/websocket"
)

type HttpProxyHandle struct {
}

//类似一个c++类,里面的属性初始化的时候是是可以赋值的
var upgrader = websocket.Upgrader {
    //此处给CheckOrigin默认一个返回true保证,否则会出现报错自动跳转
    CheckOrigin: func(r *http.Request) bool {
        return true
    },
} // use default options

func (h *HttpProxyHandle) ServeHTTP(w http.ResponseWriter, req *http.Request) {

    config := getConfig()
    body, err := ioutil.ReadAll(req.Body)
    if err != nil {
        //mainLog.Println("body = NULL:", err.Error())
        //return,没有数据也是可以的,不需要直接结束
    }
    //debugLog.Printf("new HttpProxy:")
    //fmt.Printf("\nmethod:%s Host:%s RemoteAddr:%s URL:%s, bodycount:%d\n", req.Method, req.Host, req.RemoteAddr, req.URL.String(), len(body))
    //debugLog.Printf("target server addr:%s\n", config.HttpProxy[0].Server[0])
    //fmt.Println("send to target server:", config.HttpProxy[0].Server[0])
    //for k, v := range req.Header {
        //fmt.Println(k, v[0])
    //}
        
    //此处检测到websocket请求
    if req.Header.Get("Upgrade") != "" {
        /************进行协议升级************/
        upgrader.Subprotocols = []string{req.Header.Get("Sec-WebSocket-Protocol")}
        //upgrader.Upgrade内部会返回握手信息,我们做代理需要将dialer.Dial客户端收到的下层返回的握手信息返回给上层,源码Upgrade函数中将
        //Sec-WebSocket-Protocol这个头去掉了,所以我们给加上,保证在upgrader.Upgrade调用的上方加上
        //func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) {}
        //upgrader.Subprotocols是属性,upgrader.Upgrade是方法,属性可以初始化,方法在函数定义前进行关联
        //具体实现在 https://github.com/gorilla/websocket/blob/master/server.go
        c_this, err := upgrader.Upgrade(w, req, nil)
        if err != nil {
            debugLog.Println("upgrade:", err)
            return
        }
        defer c_this.Close()
        /******************启动websocket转发客户端*********/
        //此处使用req.URL.Path+req.URL.RawQuer代替 req.URL.String()是因为使用后者之后下面的u.String()使用url加密,导致url错误造成404
        u := url.URL{Scheme: "wss", Host: config.HttpProxy[0].Server[0], Path: req.URL.Path, RawQuery: req.URL.RawQuery}
        debugLog.Printf("connecting to %s\n", u.String())
        //添加头部的时候不能照搬,需要去掉几个固有的,因为库里面已经给你加了,所以代理的时候把重复的去掉,详见源码219行
        //https://github.com/gorilla/websocket/blob/master/client.go
        headers := make(http.Header)
        for k, v := range req.Header {
            if k == "Upgrade" ||
                    k == "Connection" ||
                    k == "Sec-Websocket-Key" ||
                    k == "Sec-Websocket-Version" ||
                    k == "Sec-Websocket-Extensions"{
            } else {
                headers.Set(k, v[0])
                //fmt.Println("set ==>", k, v[0])
            }
        }
        dialer := websocket.Dialer{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
        c_to_next, resp, err := dialer.Dial(u.String(), headers)
        if err != nil {
            debugLog.Println("dial:", err)
            debugLog.Println("StatusCode:", resp.StatusCode)
        }
        //fmt.Println(resp.Header)
        defer c_to_next.Close()
        /*****************接收返回并转回给浏览器*******************/
        go func() {
            debugLog.Println("run read from next proc")
            for {
                mt, message, err := c_to_next.ReadMessage()
                if err != nil {
                    debugLog.Println("read from next:", err)
                    break
                }
                //fmt.Println("read from next:", message)
                err = c_this.WriteMessage(mt, message)
                if err != nil {
                    debugLog.Println("write to priv:", err)
                    break
                }
            }
        }()
        /*****************接收浏览器信息并转发*******************/
        //此处不能再协程了,否则会defer c_to_next.Close()
        debugLog.Println("run read from priv proc")
        for {
            mt1, message1, err1 := c_this.ReadMessage()
            if err1 != nil {
                debugLog.Println("read from priv:", err1)
                break
            }
            //fmt.Println("read from priv:", message1)
            err1 = c_to_next.WriteMessage(mt1, message1)
            if err1 != nil {
                debugLog.Println("write to next:", err1)
                break
            }
        }
        /********************************************/
        return
    }
    tr := &http.Transport{
        TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
    }
    cli := &http.Client{Transport:tr}
    reqUrl := "https://" + config.HttpProxy[0].Server[0] + req.URL.String()
    proxy_req, err := http.NewRequest(req.Method, reqUrl, strings.NewReader(string(body)))
    if err != nil {
        debugLog.Println("http.NewRequest(to target server addr):", err.Error())
        return
    }
    //type Header map[string][]string
    for k, v := range req.Header {
        proxy_req.Header.Set(k, v[0])
        //debugLog.Println(k, v[0])
    }
    countSplit := strings.Split(config.HttpProxy[0].HeadExtend, ",")
    for _, v1 := range countSplit {
        countSplit := strings.Split(v1, ":")
        proxy_req.Header.Add(countSplit[0], countSplit[1])
        //debugLog.Printf("HeadExtend %s:%s\n", countSplit[0], countSplit[1])
    }
    res, err := cli.Do(proxy_req)
    if err != nil {
        debugLog.Println("cli.Do(req):", err.Error())
        return
    }
    defer res.Body.Close()
    for k, v := range res.Header {
        w.Header().Set(k, v[0])
        //dbeugLog.Println(k, v[0])
    }
    io.Copy(w, res.Body)
}

func startHttpProxy() {
    c := getConfig()
    for _, v := range c.HttpProxy {
        debugLog.Printf("new HttpProxy. listening%s target:%s\n", v.ListenAddr, v.Server[0])
        s := &http.Server{
            Addr: v.ListenAddr,
            Handler: &HttpProxyHandle{},
        }
        go func() {
            mainLog.Fatal(s.ListenAndServeTLS(c.HttpApi.CrtFile, c.HttpApi.KeyFile))
        }()
    }
}

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值