基础:
golang websocket github地址:
GitHub - gorilla/websocket: A fast, well-tested and widely used WebSocket implementation for Go.
github库官方文档:
websocket package - github.com/gorilla/websocket - Go Packages
golang websocket运行机制以及原理:
https://www.jianshu.com/p/65ef71ddb910
golang websocket在线测试地址:
golang websocket官方客户端以及服务端demo:
websocket/examples/echo at master · gorilla/websocket · GitHub
golang websocket http代理实现:
主要在http代理的基础上(见上一篇博客http代理:golang实现http(s)代理_GT19930910的博客-CSDN博客_golang http 代理),检查是否有http头中带有Upgrade字段的,如果有则说明是websocket请求需要作相应的协议升级,具体细节在代码中呈现遇到的坑也在代码中有详细说明。此处代理的https请求(http请求不会自动302跳转)以及wss请求(wss比ws更加复杂一些需要跳过证书)。
package main
import (
//"fmt"
"net/http"
"io/ioutil"
"strings"
"io"
"crypto/tls"
"net/url"
"github.com/gorilla/websocket"
)
type HttpProxyHandle struct {
}
//类似一个c++类,里面的属性初始化的时候是是可以赋值的
var upgrader = websocket.Upgrader {
//此处给CheckOrigin默认一个返回true保证,否则会出现报错自动跳转
CheckOrigin: func(r *http.Request) bool {
return true
},
} // use default options
func (h *HttpProxyHandle) ServeHTTP(w http.ResponseWriter, req *http.Request) {
config := getConfig()
body, err := ioutil.ReadAll(req.Body)
if err != nil {
//mainLog.Println("body = NULL:", err.Error())
//return,没有数据也是可以的,不需要直接结束
}
//debugLog.Printf("new HttpProxy:")
//fmt.Printf("\nmethod:%s Host:%s RemoteAddr:%s URL:%s, bodycount:%d\n", req.Method, req.Host, req.RemoteAddr, req.URL.String(), len(body))
//debugLog.Printf("target server addr:%s\n", config.HttpProxy[0].Server[0])
//fmt.Println("send to target server:", config.HttpProxy[0].Server[0])
//for k, v := range req.Header {
//fmt.Println(k, v[0])
//}
//此处检测到websocket请求
if req.Header.Get("Upgrade") != "" {
/************进行协议升级************/
upgrader.Subprotocols = []string{req.Header.Get("Sec-WebSocket-Protocol")}
//upgrader.Upgrade内部会返回握手信息,我们做代理需要将dialer.Dial客户端收到的下层返回的握手信息返回给上层,源码Upgrade函数中将
//Sec-WebSocket-Protocol这个头去掉了,所以我们给加上,保证在upgrader.Upgrade调用的上方加上
//func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) {}
//upgrader.Subprotocols是属性,upgrader.Upgrade是方法,属性可以初始化,方法在函数定义前进行关联
//具体实现在 https://github.com/gorilla/websocket/blob/master/server.go
c_this, err := upgrader.Upgrade(w, req, nil)
if err != nil {
debugLog.Println("upgrade:", err)
return
}
defer c_this.Close()
/******************启动websocket转发客户端*********/
//此处使用req.URL.Path+req.URL.RawQuer代替 req.URL.String()是因为使用后者之后下面的u.String()使用url加密,导致url错误造成404
u := url.URL{Scheme: "wss", Host: config.HttpProxy[0].Server[0], Path: req.URL.Path, RawQuery: req.URL.RawQuery}
debugLog.Printf("connecting to %s\n", u.String())
//添加头部的时候不能照搬,需要去掉几个固有的,因为库里面已经给你加了,所以代理的时候把重复的去掉,详见源码219行
//https://github.com/gorilla/websocket/blob/master/client.go
headers := make(http.Header)
for k, v := range req.Header {
if k == "Upgrade" ||
k == "Connection" ||
k == "Sec-Websocket-Key" ||
k == "Sec-Websocket-Version" ||
k == "Sec-Websocket-Extensions"{
} else {
headers.Set(k, v[0])
//fmt.Println("set ==>", k, v[0])
}
}
dialer := websocket.Dialer{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
c_to_next, resp, err := dialer.Dial(u.String(), headers)
if err != nil {
debugLog.Println("dial:", err)
debugLog.Println("StatusCode:", resp.StatusCode)
}
//fmt.Println(resp.Header)
defer c_to_next.Close()
/*****************接收返回并转回给浏览器*******************/
go func() {
debugLog.Println("run read from next proc")
for {
mt, message, err := c_to_next.ReadMessage()
if err != nil {
debugLog.Println("read from next:", err)
break
}
//fmt.Println("read from next:", message)
err = c_this.WriteMessage(mt, message)
if err != nil {
debugLog.Println("write to priv:", err)
break
}
}
}()
/*****************接收浏览器信息并转发*******************/
//此处不能再协程了,否则会defer c_to_next.Close()
debugLog.Println("run read from priv proc")
for {
mt1, message1, err1 := c_this.ReadMessage()
if err1 != nil {
debugLog.Println("read from priv:", err1)
break
}
//fmt.Println("read from priv:", message1)
err1 = c_to_next.WriteMessage(mt1, message1)
if err1 != nil {
debugLog.Println("write to next:", err1)
break
}
}
/********************************************/
return
}
tr := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
cli := &http.Client{Transport:tr}
reqUrl := "https://" + config.HttpProxy[0].Server[0] + req.URL.String()
proxy_req, err := http.NewRequest(req.Method, reqUrl, strings.NewReader(string(body)))
if err != nil {
debugLog.Println("http.NewRequest(to target server addr):", err.Error())
return
}
//type Header map[string][]string
for k, v := range req.Header {
proxy_req.Header.Set(k, v[0])
//debugLog.Println(k, v[0])
}
countSplit := strings.Split(config.HttpProxy[0].HeadExtend, ",")
for _, v1 := range countSplit {
countSplit := strings.Split(v1, ":")
proxy_req.Header.Add(countSplit[0], countSplit[1])
//debugLog.Printf("HeadExtend %s:%s\n", countSplit[0], countSplit[1])
}
res, err := cli.Do(proxy_req)
if err != nil {
debugLog.Println("cli.Do(req):", err.Error())
return
}
defer res.Body.Close()
for k, v := range res.Header {
w.Header().Set(k, v[0])
//dbeugLog.Println(k, v[0])
}
io.Copy(w, res.Body)
}
func startHttpProxy() {
c := getConfig()
for _, v := range c.HttpProxy {
debugLog.Printf("new HttpProxy. listening%s target:%s\n", v.ListenAddr, v.Server[0])
s := &http.Server{
Addr: v.ListenAddr,
Handler: &HttpProxyHandle{},
}
go func() {
mainLog.Fatal(s.ListenAndServeTLS(c.HttpApi.CrtFile, c.HttpApi.KeyFile))
}()
}
}