反向代理
下游服务器：真实服务器（real server）的实现
package main
import (
"fmt"
"io"
"log"
"net/http"
"os"
"os/signal"
"syscall"
"time"
)
// main starts the two demo backends (127.0.0.1:2003 and 127.0.0.1:2004) and
// blocks until SIGINT or SIGTERM is received.
func main() {
	rs1 := &RealServer{Addr: "127.0.0.1:2003"}
	rs1.Run()
	rs2 := &RealServer{Addr: "127.0.0.1:2004"}
	rs2.Run()
	// Wait for a shutdown signal. signal.Notify never blocks when delivering,
	// so an unbuffered channel can miss a signal that arrives before the
	// receive below is reached; os/signal requires a buffered channel.
	quit := make(chan os.Signal, 1)
	signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
	<-quit
}
// RealServer is a demo backend that echoes request details, so that the
// proxies placed in front of it can be observed.
type RealServer struct {
	Addr string // listen address, host:port
}

// Run registers the demo routes and starts serving on r.Addr in a background
// goroutine. It returns immediately. If ListenAndServe fails, the whole
// process exits via log.Fatal.
func (r *RealServer) Run() {
	log.Println("Starting httpserver at " + r.Addr)

	routes := http.NewServeMux()
	routes.HandleFunc("/", r.HelloHandler)
	routes.HandleFunc("/base/error", r.ErrorHandler)
	routes.HandleFunc("/test_http_string/test_http_string/aaa", r.TimeoutHandler)

	// WriteTimeout (3s) is shorter than TimeoutHandler's 6s sleep, so requests
	// to that route outlive the server's write deadline.
	srv := &http.Server{
		Addr:         r.Addr,
		Handler:      routes,
		WriteTimeout: 3 * time.Second,
	}
	go func(s *http.Server) {
		log.Fatal(s.ListenAndServe())
	}(srv)
}

// HelloHandler writes three lines describing the request:
//   - the URL as seen by this backend (its own address plus the request path),
//   - the client address together with the X-Forwarded-For and X-Real-Ip headers,
//   - the full request header map.
func (r *RealServer) HelloHandler(w http.ResponseWriter, req *http.Request) {
	lines := []string{
		fmt.Sprintf("http://%s%s\n", r.Addr, req.URL.Path),
		fmt.Sprintf("RemoteAddr=%s,X-Forwarded-For=%v,X-Real-Ip=%v\n",
			req.RemoteAddr, req.Header.Get("X-Forwarded-For"), req.Header.Get("X-Real-Ip")),
		fmt.Sprintf("headers =%v\n", req.Header),
	}
	for _, line := range lines {
		io.WriteString(w, line)
	}
}

// ErrorHandler always answers 500 with the body "error handler".
func (r *RealServer) ErrorHandler(w http.ResponseWriter, req *http.Request) {
	w.WriteHeader(http.StatusInternalServerError)
	io.WriteString(w, "error handler")
}

// TimeoutHandler sleeps 6s before answering 200 "timeout handler". That is
// longer than the server's 3s WriteTimeout.
func (r *RealServer) TimeoutHandler(w http.ResponseWriter, req *http.Request) {
	time.Sleep(6 * time.Second)
	w.WriteHeader(http.StatusOK)
	io.WriteString(w, "timeout handler")
}
代理服务器的实现
package main
import (
"bufio"
"log"
"net/http"
"net/url"
)
var (
	// proxy_addr is the single upstream that every request is forwarded to.
	proxy_addr = "http://127.0.0.1:2003"
	// port is the local port this proxy listens on.
	port = "2002"
)

// handler is a minimal hand-rolled reverse proxy. It:
//   - points the incoming request at proxy_addr,
//   - sends it with http.DefaultTransport,
//   - relays the upstream status, headers and body back to the client.
//
// If proxy_addr cannot be parsed or the upstream cannot be reached, the client
// gets a 502 Bad Gateway instead of an empty 200.
func handler(w http.ResponseWriter, r *http.Request) {
	// step 1: parse the proxy address and retarget the request's scheme and host.
	proxy, err := url.Parse(proxy_addr)
	if err != nil {
		log.Print(err)
		http.Error(w, err.Error(), http.StatusBadGateway)
		return
	}
	r.URL.Scheme = proxy.Scheme
	r.URL.Host = proxy.Host

	// step 2: send the request downstream.
	resp, err := http.DefaultTransport.RoundTrip(r)
	if err != nil {
		log.Print(err)
		http.Error(w, err.Error(), http.StatusBadGateway)
		return
	}
	defer resp.Body.Close()

	// step 3: relay the downstream response to the client. The order matters:
	// headers first, then WriteHeader, then the body. Otherwise the upstream
	// status is lost and the client always sees 200.
	for k, vv := range resp.Header {
		for _, v := range vv {
			w.Header().Add(k, v)
		}
	}
	w.WriteHeader(resp.StatusCode)
	if _, err := bufio.NewReader(resp.Body).WriteTo(w); err != nil {
		log.Print(err)
	}
}
// main serves handler on every path at :port and exits if the listener fails.
func main() {
	http.HandleFunc("/", handler)
	log.Println("Start serving on port " + port)
	if err := http.ListenAndServe(":"+port, nil); err != nil {
		log.Fatal(err)
	}
}
package main
import (
"log"
"net/http"
"net/http/httputil"
"net/url"
)
var addr = "127.0.0.1:2002"

// main runs the stdlib single-host reverse proxy on addr. A request to
// 127.0.0.1:2002/xxx is forwarded to 127.0.0.1:2003/base/xxx.
func main() {
	rs1 := "http://127.0.0.1:2003/base"
	url1, err := url.Parse(rs1)
	if err != nil {
		// Continuing would hand a nil target URL to the proxy, which
		// dereferences it. Stop here instead.
		log.Fatal(err)
	}
	proxy := httputil.NewSingleHostReverseProxy(url1)
	log.Println("Starting httpserver at " + addr)
	log.Fatal(http.ListenAndServe(addr, proxy))
}
更改响应内容（ModifyResponse / ErrorHandler 示例）
package main
import (
"bytes"
"errors"
"fmt"
"io/ioutil"
"log"
"net/http"
"net/http/httputil"
"net/url"
"regexp"
"strings"
)
var addr = "127.0.0.1:2002"

// main runs the customised single-host reverse proxy on addr. A request to
// 127.0.0.1:2002/xxx is forwarded to 127.0.0.1:2003/base/xxx.
func main() {
	rs1 := "http://127.0.0.1:2003/base"
	url1, err := url.Parse(rs1)
	if err != nil {
		// NewSingleHostReverseProxy reads target.RawQuery immediately, so a
		// nil URL would panic. Stop here instead.
		log.Fatal(err)
	}
	proxy := NewSingleHostReverseProxy(url1)
	log.Println("Starting httpserver at " + addr)
	log.Fatal(http.ListenAndServe(addr, proxy))
}
// NewSingleHostReverseProxy builds a reverse proxy that forwards every request
// to target. The incoming path is rewritten as follows:
//   - a leading "/dir" is stripped,
//   - target.Path is then prepended, e.g. 127.0.0.1:2002/dir/abc -> 127.0.0.1:2003/base/abc,
//   - target's query string is merged with the request's.
//
// Errors, from the transport or from ModifyResponse, are answered with
// 502 Bad Gateway.
func NewSingleHostReverseProxy(target *url.URL) *httputil.ReverseProxy {
	// e.g. http://127.0.0.1:2002/dir?name=123 -> RawQuery "name=123",
	// Scheme "http", Host "127.0.0.1:2002".
	targetQuery := target.RawQuery
	// Compiled once per proxy instead of on every request. The pattern is a
	// constant, so MustCompile cannot fail here.
	dirPrefix := regexp.MustCompile("^/dir(.*)")

	director := func(req *http.Request) {
		// URL rewrite: /dir/abc -> /abc, then prepend target.Path below.
		req.URL.Path = dirPrefix.ReplaceAllString(req.URL.Path, "$1")
		req.URL.Scheme = target.Scheme
		req.URL.Host = target.Host
		// target.Path "/base" + req.URL.Path "/abc" -> "/base/abc"
		req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
		if targetQuery == "" || req.URL.RawQuery == "" {
			req.URL.RawQuery = targetQuery + req.URL.RawQuery
		} else {
			req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
		}
		// An explicitly empty User-Agent keeps net/http from adding its default one.
		if _, ok := req.Header["User-Agent"]; !ok {
			req.Header.Set("User-Agent", "")
		}
	}

	// rejectNon200 selects between the two demos of this section.
	//   - true (current behaviour): non-200 responses become an error, which is
	//     routed to errorHandler.
	//   - false: the upstream body is kept and prefixed with "hello ".
	// Previously the rewrite code sat after an unconditional return and was
	// unreachable.
	const rejectNon200 = true
	modifyFunc := func(res *http.Response) error {
		if res.StatusCode == 200 {
			return nil
		}
		if rejectNon200 {
			return errors.New("error statusCode")
		}
		oldPayload, err := ioutil.ReadAll(res.Body)
		if err != nil {
			return err
		}
		res.Body.Close()
		newPayLoad := []byte("hello " + string(oldPayload))
		res.Body = ioutil.NopCloser(bytes.NewBuffer(newPayLoad))
		res.ContentLength = int64(len(newPayLoad))
		res.Header.Set("Content-Length", fmt.Sprint(len(newPayLoad)))
		return nil
	}

	// Report failures with an error status. Writing the body alone would
	// have sent a misleading 200.
	errorHandler := func(res http.ResponseWriter, req *http.Request, err error) {
		http.Error(res, err.Error(), http.StatusBadGateway)
	}
	return &httputil.ReverseProxy{Director: director, ModifyResponse: modifyFunc, ErrorHandler: errorHandler}
}

// singleJoiningSlash joins two URL path segments with exactly one "/" between them.
func singleJoiningSlash(a, b string) string {
	aslash := strings.HasSuffix(a, "/")
	bslash := strings.HasPrefix(b, "/")
	switch {
	case aslash && bslash:
		return a + b[1:]
	case !aslash && !bslash:
		return a + "/" + b
	}
	return a + b
}
第一重代理
package main
import (
"bytes"
"io/ioutil"
"log"
"math/rand"
"net"
"net/http"
"net/http/httputil"
"net/url"
"regexp"
"strconv"
"strings"
"time"
)
var addr = "127.0.0.1:2001"

// main runs the first-hop proxy on addr, forwarding to the second-hop proxy
// at 127.0.0.1:2002.
func main() {
	rs1 := "http://127.0.0.1:2002"
	url1, err := url.Parse(rs1)
	if err != nil {
		// A nil target would be dereferenced by the director on the first
		// request. Stop here instead.
		log.Fatal(err)
	}
	urls := []*url.URL{url1}
	proxy := NewMultipleHostsReverseProxy(urls)
	log.Println("Starting httpserver at " + addr)
	log.Fatal(http.ListenAndServe(addr, proxy))
}
// transport is shared by every proxied request so upstream connections are
// pooled and reused.
var transport = &http.Transport{
	DialContext: (&net.Dialer{
		Timeout:   30 * time.Second, // connect timeout
		KeepAlive: 30 * time.Second, // TCP keep-alive period
	}).DialContext,
	MaxIdleConns:          100,              // max idle connections
	IdleConnTimeout:       90 * time.Second, // idle connection timeout
	TLSHandshakeTimeout:   10 * time.Second, // TLS handshake timeout
	ExpectContinueTimeout: 1 * time.Second,  // timeout waiting for 100-continue
}

// NewMultipleHostsReverseProxy returns the first-hop proxy. For each request it:
//   - strips a leading "/dir" from the path,
//   - picks one of targets at random and prepends that target's path,
//   - merges the query strings,
//   - sets X-Real-Ip to the client's IP.
//
// Non-200 upstream bodies are prefixed with "StatusCode error:". Transport
// failures are answered with a 500.
func NewMultipleHostsReverseProxy(targets []*url.URL) *httputil.ReverseProxy {
	// Compiled once per proxy rather than on every request.
	dirPrefix := regexp.MustCompile("^/dir(.*)")

	director := func(req *http.Request) {
		// URL rewrite: 127.0.0.1:2002/dir/abc -> 127.0.0.1:2002/abc
		req.URL.Path = dirPrefix.ReplaceAllString(req.URL.Path, "$1")
		if len(targets) == 0 {
			// rand.Intn(0) would panic. Clearing the destination makes the
			// transport reject the request, so errFunc answers it instead.
			req.URL.Scheme, req.URL.Host = "", ""
			return
		}
		// Random load balancing.
		target := targets[rand.Intn(len(targets))]
		targetQuery := target.RawQuery
		req.URL.Scheme = target.Scheme
		req.URL.Host = target.Host
		// Path rewrite: /aa -> /base/aa
		req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
		if targetQuery == "" || req.URL.RawQuery == "" {
			req.URL.RawQuery = targetQuery + req.URL.RawQuery
		} else {
			req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
		}
		if _, ok := req.Header["User-Agent"]; !ok {
			req.Header.Set("User-Agent", "user-agent")
		}
		// Only the first proxy in the chain sets X-Real-Ip. RemoteAddr is
		// "host:port", so keep only the host; fall back to the raw value if
		// it cannot be split.
		clientIP := req.RemoteAddr
		if host, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
			clientIP = host
		}
		req.Header.Set("X-Real-Ip", clientIP)
	}

	// Rewrite the body of non-200 responses,
	// e.g. for curl 'http://127.0.0.1:2002/error'.
	modifyFunc := func(resp *http.Response) error {
		if resp.StatusCode != 200 {
			oldPayload, err := ioutil.ReadAll(resp.Body)
			if err != nil {
				return err
			}
			resp.Body.Close()
			newPayload := []byte("StatusCode error:" + string(oldPayload))
			resp.Body = ioutil.NopCloser(bytes.NewBuffer(newPayload))
			resp.ContentLength = int64(len(newPayload))
			resp.Header.Set("Content-Length", strconv.FormatInt(int64(len(newPayload)), 10))
		}
		return nil
	}

	// Error callback, e.g. when the real server is down.
	errFunc := func(w http.ResponseWriter, r *http.Request, err error) {
		http.Error(w, "ErrorHandler error:"+err.Error(), 500)
	}
	return &httputil.ReverseProxy{
		Director:       director,
		Transport:      transport,
		ModifyResponse: modifyFunc,
		ErrorHandler:   errFunc,
	}
}

// singleJoiningSlash joins two URL path segments with exactly one "/" between them.
func singleJoiningSlash(a, b string) string {
	aslash := strings.HasSuffix(a, "/")
	bslash := strings.HasPrefix(b, "/")
	switch {
	case aslash && bslash:
		return a + b[1:]
	case !aslash && !bslash:
		return a + "/" + b
	}
	return a + b
}
第二重代理
package main
import (
"bytes"
"compress/gzip"
"io/ioutil"
"log"
"math/rand"
"net"
"net/http"
"net/http/httputil"
"net/url"
"regexp"
"strconv"
"strings"
"time"
)
var addr = "127.0.0.1:2002"

// main runs the second-hop proxy on addr, balancing randomly between the two
// real servers on 127.0.0.1:2003 and 127.0.0.1:2004.
func main() {
	backends := []string{
		"http://127.0.0.1:2003", // or e.g. "http://www.baidu.com" to proxy a public site
		"http://127.0.0.1:2004",
	}
	urls := make([]*url.URL, 0, len(backends))
	for _, raw := range backends {
		u, err := url.Parse(raw)
		if err != nil {
			// A nil entry would be dereferenced by the director on the first
			// request routed to it. Stop here instead.
			log.Fatal(err)
		}
		urls = append(urls, u)
	}
	proxy := NewMultipleHostsReverseProxy(urls)
	log.Println("Starting httpserver at " + addr)
	log.Fatal(http.ListenAndServe(addr, proxy))
}
// transport is shared by every proxied request so upstream connections are
// pooled and reused.
var transport = &http.Transport{
	DialContext: (&net.Dialer{
		Timeout:   30 * time.Second, // connect timeout
		KeepAlive: 30 * time.Second, // TCP keep-alive period
	}).DialContext,
	MaxIdleConns:          100,              // max idle connections
	IdleConnTimeout:       90 * time.Second, // idle connection timeout
	TLSHandshakeTimeout:   10 * time.Second, // TLS handshake timeout
	ExpectContinueTimeout: 1 * time.Second,  // timeout waiting for 100-continue
}

// NewMultipleHostsReverseProxy returns the second-hop proxy. For each request it:
//   - strips a leading "/dir" from the path,
//   - picks one of targets at random and sets Host to it (needed when proxying
//     a public domain),
//   - prepends the target path and merges query strings.
//
// Buffered responses have any gzip encoding removed and a recomputed
// Content-Length. Non-200 bodies are prefixed with "StatusCode error:".
// Upgrade responses, HEAD responses and 204/304 responses pass through
// untouched.
func NewMultipleHostsReverseProxy(targets []*url.URL) *httputil.ReverseProxy {
	// Compiled once per proxy rather than on every request.
	dirPrefix := regexp.MustCompile("^/dir(.*)")

	director := func(req *http.Request) {
		// URL rewrite: 127.0.0.1:2002/dir/abc -> 127.0.0.1:2002/abc
		req.URL.Path = dirPrefix.ReplaceAllString(req.URL.Path, "$1")
		if len(targets) == 0 {
			// rand.Intn(0) would panic. Clearing the destination makes the
			// transport reject the request, so errFunc answers it instead.
			req.URL.Scheme, req.URL.Host = "", ""
			return
		}
		// Random load balancing.
		target := targets[rand.Intn(len(targets))]
		targetQuery := target.RawQuery
		req.URL.Scheme = target.Scheme
		req.URL.Host = target.Host
		// Needed when reverse-proxying a public domain (virtual hosting);
		// harmless for internal backends.
		req.Host = target.Host
		// Path rewrite: /aa -> /base/aa
		req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
		if targetQuery == "" || req.URL.RawQuery == "" {
			req.URL.RawQuery = targetQuery + req.URL.RawQuery
		} else {
			req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
		}
		if _, ok := req.Header["User-Agent"]; !ok {
			req.Header.Set("User-Agent", "user-agent")
		}
		// X-Real-Ip is only set by the first proxy in the chain, not here.
	}

	modifyFunc := func(resp *http.Response) error {
		// Protocol upgrades (e.g. WebSocket) carry a raw bidirectional stream
		// that must not be buffered. Header tokens are case-insensitive.
		if resp.StatusCode == http.StatusSwitchingProtocols ||
			strings.Contains(strings.ToLower(resp.Header.Get("Connection")), "upgrade") {
			return nil
		}
		// HEAD, 204 and 304 responses have no content by definition.
		// Rewriting them would inject a body and overwrite the real
		// Content-Length.
		if (resp.Request != nil && resp.Request.Method == http.MethodHead) ||
			resp.StatusCode == http.StatusNoContent ||
			resp.StatusCode == http.StatusNotModified {
			return nil
		}

		var payload []byte
		var readErr error
		// Decode gzip so the body can be inspected and rewritten.
		if strings.Contains(resp.Header.Get("Content-Encoding"), "gzip") {
			gr, err := gzip.NewReader(resp.Body)
			if err != nil {
				return err
			}
			payload, readErr = ioutil.ReadAll(gr)
			gr.Close()
			resp.Header.Del("Content-Encoding")
		} else {
			payload, readErr = ioutil.ReadAll(resp.Body)
		}
		if readErr != nil {
			return readErr
		}
		// The upstream body is fully consumed; release it before swapping in
		// the buffered copy.
		resp.Body.Close()

		// Mark non-200 responses in the body.
		if resp.StatusCode != 200 {
			payload = []byte("StatusCode error:" + string(payload))
		}
		// The body was pre-read, so write it back and fix the length headers.
		resp.Body = ioutil.NopCloser(bytes.NewBuffer(payload))
		resp.ContentLength = int64(len(payload))
		resp.Header.Set("Content-Length", strconv.FormatInt(int64(len(payload)), 10))
		return nil
	}

	// Error callback, e.g. when the real server is down.
	errFunc := func(w http.ResponseWriter, r *http.Request, err error) {
		http.Error(w, "ErrorHandler error:"+err.Error(), 500)
	}
	return &httputil.ReverseProxy{
		Director:       director,
		Transport:      transport,
		ModifyResponse: modifyFunc,
		ErrorHandler:   errFunc,
	}
}

// singleJoiningSlash joins two URL path segments with exactly one "/" between them.
func singleJoiningSlash(a, b string) string {
	aslash := strings.HasSuffix(a, "/")
	bslash := strings.HasPrefix(b, "/")
	switch {
	case aslash && bslash:
		return a + b[1:]
	case !aslash && !bslash:
		return a + "/" + b
	}
	return a + b
}
真实服务器
package main
import (
"fmt"
"io"
"log"
"net/http"
"os"
"os/signal"
"syscall"
"time"
)
// main starts the two demo backends (127.0.0.1:2003 and 127.0.0.1:2004) and
// blocks until SIGINT or SIGTERM is received.
func main() {
	rs1 := &RealServer{Addr: "127.0.0.1:2003"}
	rs1.Run()
	rs2 := &RealServer{Addr: "127.0.0.1:2004"}
	rs2.Run()
	// Wait for a shutdown signal. signal.Notify never blocks when delivering,
	// so the channel must be buffered or an early signal can be dropped.
	quit := make(chan os.Signal, 1)
	signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
	<-quit
}
// RealServer is a demo backend that echoes request details, so that the
// proxies placed in front of it can be observed.
type RealServer struct {
	Addr string // listen address, host:port
}

// Run registers the demo routes and starts serving on r.Addr in a background
// goroutine. It returns immediately. If ListenAndServe fails, the whole
// process exits via log.Fatal.
func (r *RealServer) Run() {
	log.Println("Starting httpserver at " + r.Addr)

	routes := http.NewServeMux()
	routes.HandleFunc("/", r.HelloHandler)
	routes.HandleFunc("/base/error", r.ErrorHandler)
	routes.HandleFunc("/test_http_string/test_http_string/aaa", r.TimeoutHandler)

	// WriteTimeout (3s) is shorter than TimeoutHandler's 6s sleep, so requests
	// to that route outlive the server's write deadline.
	srv := &http.Server{
		Addr:         r.Addr,
		Handler:      routes,
		WriteTimeout: 3 * time.Second,
	}
	go func(s *http.Server) {
		log.Fatal(s.ListenAndServe())
	}(srv)
}

// HelloHandler writes three lines describing the request:
//   - the URL as seen by this backend (its own address plus the request path),
//   - the client address together with the X-Forwarded-For and X-Real-Ip headers,
//   - the full request header map.
func (r *RealServer) HelloHandler(w http.ResponseWriter, req *http.Request) {
	lines := []string{
		fmt.Sprintf("http://%s%s\n", r.Addr, req.URL.Path),
		fmt.Sprintf("RemoteAddr=%s,X-Forwarded-For=%v,X-Real-Ip=%v\n",
			req.RemoteAddr, req.Header.Get("X-Forwarded-For"), req.Header.Get("X-Real-Ip")),
		fmt.Sprintf("headers =%v\n", req.Header),
	}
	for _, line := range lines {
		io.WriteString(w, line)
	}
}

// ErrorHandler always answers 500 with the body "error handler".
func (r *RealServer) ErrorHandler(w http.ResponseWriter, req *http.Request) {
	w.WriteHeader(http.StatusInternalServerError)
	io.WriteString(w, "error handler")
}

// TimeoutHandler sleeps 6s before answering 200 "timeout handler". That is
// longer than the server's 3s WriteTimeout.
func (r *RealServer) TimeoutHandler(w http.ResponseWriter, req *http.Request) {
	time.Sleep(6 * time.Second)
	w.WriteHeader(http.StatusOK)
	io.WriteString(w, "timeout handler")
}
// ServeHTTP is an annotated walk-through of how a reverse proxy handles one
// request. It appears to be an excerpt of Go's net/http/httputil
// ReverseProxy.ServeHTTP (TODO confirm which Go version).
//
// It relies on helpers that are not defined in this file, among them
// upgradeType, removeConnectionHeaders, hopHeaders, copyHeader,
// p.modifyResponse, p.copyResponse, p.flushInterval, p.handleUpgradeResponse
// and shouldPanicOnCopyError.
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
	transport := p.Transport
	if transport == nil {
		transport = http.DefaultTransport
	}
	// step 1: detect a client that has gone away. If rw supports
	// CloseNotifier, cancel ctx when its close notification fires.
	ctx := req.Context()
	if cn, ok := rw.(http.CloseNotifier); ok {
		var cancel context.CancelFunc
		ctx, cancel = context.WithCancel(ctx)
		defer cancel()
		notifyChan := cn.CloseNotify()
		go func() {
			select {
			case <-notifyChan:
				cancel()
			case <-ctx.Done():
			}
		}()
	}
	// step 2: clone the request onto ctx (Clone deep-copies the header map).
	outreq := req.Clone(ctx)
	if req.ContentLength == 0 {
		outreq.Body = nil // Issue 16036: nil Body for http.Transport retries
	}
	if outreq.Header == nil {
		outreq.Header = make(http.Header) // Issue 33142: historical behavior was to always allocate
	}
	// step 3: let the Director rewrite the outgoing request.
	p.Director(outreq)
	outreq.Close = false
	// step 4: remember any requested protocol upgrade, then strip hop-by-hop headers.
	reqUpType := upgradeType(outreq.Header)
	removeConnectionHeaders(outreq.Header)
	// Remove hop-by-hop headers to the backend. Especially
	// important is "Connection" because we want a persistent
	// connection, regardless of what the client sent to us.
	for _, h := range hopHeaders {
		outreq.Header.Del(h)
	}
	// Issue 21096: tell backend applications that care about trailer support
	// that we support trailers. (We do, but we don't go out of our way to
	// advertise that unless the incoming client request thought it was worth
	// mentioning.) Note that we look at req.Header, not outreq.Header, since
	// the latter has passed through removeConnectionHeaders.
	if httpguts.HeaderValuesContainsToken(req.Header["Te"], "trailers") {
		outreq.Header.Set("Te", "trailers")
	}
	// After stripping all the hop-by-hop connection headers above, add back any
	// necessary for protocol upgrades, such as for websockets.
	if reqUpType != "" {
		outreq.Header.Set("Connection", "Upgrade")
		outreq.Header.Set("Upgrade", reqUpType)
	}
	// step 5: append the client IP to X-Forwarded-For.
	if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
		// If we aren't the first proxy retain prior
		// X-Forwarded-For information as a comma+space
		// separated list and fold multiple headers into one.
		prior, ok := outreq.Header["X-Forwarded-For"]
		omit := ok && prior == nil // Issue 38079: nil now means don't populate the header
		if len(prior) > 0 {
			clientIP = strings.Join(prior, ", ") + ", " + clientIP
		}
		if !omit {
			outreq.Header.Set("X-Forwarded-For", clientIP)
		}
	}
	// step 6: send the request downstream. Transport errors go to the error handler.
	res, err := transport.RoundTrip(outreq)
	if err != nil {
		p.getErrorHandler()(rw, outreq, err)
		return
	}
	// step 7: handle protocol upgrades.
	// Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc)
	if res.StatusCode == http.StatusSwitchingProtocols {
		if !p.modifyResponse(rw, res, outreq) {
			return
		}
		p.handleUpgradeResponse(rw, outreq, res)
		return
	}
	// step 8: remove hop-by-hop headers from the response.
	removeConnectionHeaders(res.Header)
	for _, h := range hopHeaders {
		res.Header.Del(h)
	}
	// step 9: give ModifyResponse a chance to rewrite the response. A false
	// return means the response has already been handled, so stop here.
	if !p.modifyResponse(rw, res, outreq) {
		return
	}
	// step 10: copy the response headers to the client.
	copyHeader(rw.Header(), res.Header)
	// The "Trailer" header isn't included in the Transport's response,
	// at least for *http.Transport. Build it up from Trailer.
	announcedTrailers := len(res.Trailer)
	if announcedTrailers > 0 {
		trailerKeys := make([]string, 0, len(res.Trailer))
		for k := range res.Trailer {
			trailerKeys = append(trailerKeys, k)
		}
		rw.Header().Add("Trailer", strings.Join(trailerKeys, ", "))
	}
	// step 11: write the status code.
	rw.WriteHeader(res.StatusCode)
	// step 12: stream the body to the client, flushing at the interval chosen
	// by p.flushInterval.
	err = p.copyResponse(rw, res.Body, p.flushInterval(req, res))
	if err != nil {
		defer res.Body.Close()
		// Since we're streaming the response, if we run into an error all we can do
		// is abort the request. Issue 23643: ReverseProxy should use ErrAbortHandler
		// on read error while copying body.
		if !shouldPanicOnCopyError(req) {
			p.logf("suppressing panic for copyResponse error in test; copy error: %v", err)
			return
		}
		panic(http.ErrAbortHandler)
	}
	res.Body.Close() // close now, instead of defer, to populate res.Trailer
	if len(res.Trailer) > 0 {
		// Force chunking if we saw a response trailer.
		// This prevents net/http from calculating the length for short
		// bodies and adding a Content-Length.
		if fl, ok := rw.(http.Flusher); ok {
			fl.Flush()
		}
	}
	// Trailers announced up front are copied as-is. Trailers that appeared
	// later are sent using the http.TrailerPrefix convention.
	if len(res.Trailer) == announcedTrailers {
		copyHeader(rw.Header(), res.Trailer)
		return
	}
	for k, vv := range res.Trailer {
		k = http.TrailerPrefix + k
		for _, v := range vv {
			rw.Header().Add(k, v)
		}
	}
}
随机负载代码
package load_balance
import (
"errors"
"fmt"
"math/rand"
"strings"
)
// RandomBalance selects a backend uniformly at random on each call.
type RandomBalance struct {
	curIndex int      // index returned by the most recent Next
	rss      []string // backend addresses
	// conf is the observed subject that Update pulls the backend list from.
	conf LoadBalanceConf
}

// Add appends params[0] as a backend address. Any further params are ignored.
func (r *RandomBalance) Add(params ...string) error {
	if len(params) == 0 {
		return errors.New("param len 1 at least")
	}
	r.rss = append(r.rss, params[0])
	return nil
}

// Next returns a randomly chosen backend, or "" if none have been added.
func (r *RandomBalance) Next() string {
	if n := len(r.rss); n > 0 {
		r.curIndex = rand.Intn(n)
		return r.rss[r.curIndex]
	}
	return ""
}

// Get ignores key, delegates to Next and never returns an error.
func (r *RandomBalance) Get(key string) (string, error) {
	return r.Next(), nil
}

// SetConf attaches the config that Update reads from.
func (r *RandomBalance) SetConf(conf LoadBalanceConf) {
	r.conf = conf
}

// Update replaces the backend list with the entries of the attached config.
// Each entry is split on ",", and Add keeps only the first field. Any other
// config type leaves the list untouched.
func (r *RandomBalance) Update() {
	switch c := r.conf.(type) {
	case *LoadBalanceZkConf:
		fmt.Println("Update get conf:", c.GetConf())
		r.rss = []string{}
		for _, entry := range c.GetConf() {
			r.Add(strings.Split(entry, ",")...)
		}
	case *LoadBalanceCheckConf:
		fmt.Println("Update get conf:", c.GetConf())
		r.rss = nil
		for _, entry := range c.GetConf() {
			r.Add(strings.Split(entry, ",")...)
		}
	}
}
package load_balance
import (
"fmt"
"testing"
)
// TestRandomBalance checks the RandomBalance contract:
//   - an empty balancer yields "",
//   - Add rejects a call with no params,
//   - every pick is one of the registered backends.
//
// The old version only printed picks and could never fail.
func TestRandomBalance(t *testing.T) {
	rb := &RandomBalance{}
	if got := rb.Next(); got != "" {
		t.Fatalf("empty balancer: Next() = %q, want \"\"", got)
	}
	if err := rb.Add(); err == nil {
		t.Fatal("Add() with no params should fail")
	}
	valid := make(map[string]bool)
	for port := 2003; port <= 2007; port++ {
		addr := fmt.Sprintf("127.0.0.1:%d", port)
		if err := rb.Add(addr); err != nil {
			t.Fatalf("Add(%q): %v", addr, err)
		}
		valid[addr] = true
	}
	for i := 0; i < 100; i++ {
		got, err := rb.Get("")
		if err != nil {
			t.Fatalf("Get: %v", err)
		}
		if !valid[got] {
			t.Fatalf("pick %d: %q is not a registered backend", i, got)
		}
	}
}
轮询负载均衡代码
package load_balance
import (
"errors"
"fmt"
"strings"
)
// RoundRobinBalance hands out backends in insertion order, wrapping around.
type RoundRobinBalance struct {
	curIndex int      // index of the backend the next call returns
	rss      []string // backend addresses
	// conf is the observed subject that Update pulls the backend list from.
	conf LoadBalanceConf
}

// Add appends params[0] as a backend address. Any further params are ignored.
func (r *RoundRobinBalance) Add(params ...string) error {
	if len(params) == 0 {
		return errors.New("param len 1 at least")
	}
	r.rss = append(r.rss, params[0])
	return nil
}

// Next returns the current backend and advances the cursor, or returns "" if
// none have been added. A cursor left past the end (e.g. after the list
// shrank) restarts at 0.
func (r *RoundRobinBalance) Next() string {
	total := len(r.rss)
	if total == 0 {
		return ""
	}
	idx := r.curIndex
	if idx >= total {
		idx = 0
	}
	r.curIndex = (idx + 1) % total
	return r.rss[idx]
}

// Get ignores key, delegates to Next and never returns an error.
func (r *RoundRobinBalance) Get(key string) (string, error) {
	return r.Next(), nil
}

// SetConf attaches the config that Update reads from.
func (r *RoundRobinBalance) SetConf(conf LoadBalanceConf) {
	r.conf = conf
}

// Update replaces the backend list with the entries of the attached config.
// Each entry is split on ",", and Add keeps only the first field. Any other
// config type leaves the list untouched.
func (r *RoundRobinBalance) Update() {
	switch c := r.conf.(type) {
	case *LoadBalanceZkConf:
		fmt.Println("Update get conf:", c.GetConf())
		r.rss = []string{}
		for _, entry := range c.GetConf() {
			r.Add(strings.Split(entry, ",")...)
		}
	case *LoadBalanceCheckConf:
		fmt.Println("Update get conf:", c.GetConf())
		r.rss = nil
		for _, entry := range c.GetConf() {
			r.Add(strings.Split(entry, ",")...)
		}
	}
}
package load_balance
import (
"fmt"
"testing"
)
// Test_main checks that RoundRobinBalance visits backends in insertion order
// and wraps around. The old version only printed picks and could never fail.
func Test_main(t *testing.T) {
	rb := &RoundRobinBalance{}
	if got := rb.Next(); got != "" {
		t.Fatalf("empty balancer: Next() = %q, want \"\"", got)
	}
	var addrs []string
	for port := 2003; port <= 2007; port++ {
		addr := fmt.Sprintf("127.0.0.1:%d", port)
		if err := rb.Add(addr); err != nil {
			t.Fatalf("Add(%q): %v", addr, err)
		}
		addrs = append(addrs, addr)
	}
	// Two full cycles prove the wrap-around.
	for i := 0; i < 2*len(addrs); i++ {
		want := addrs[i%len(addrs)]
		if got := rb.Next(); got != want {
			t.Fatalf("pick %d: Next() = %q, want %q", i, got, want)
		}
	}
}
加权负载均衡代码实现
package load_balance
import (
"errors"
"fmt"
"strconv"
"strings"
)
// WeightRoundRobinBalance implements smooth weighted round robin. Over one
// cycle each node is picked in proportion to its weight, and picks of heavier
// nodes are interleaved with the others rather than bunched together.
type WeightRoundRobinBalance struct {
	curIndex int
	rss      []*WeightNode
	rsw      []int
	// conf is the observed subject that Update reads the node list from.
	conf LoadBalanceConf
}

// WeightNode is one backend plus its round-robin bookkeeping.
type WeightNode struct {
	addr            string
	weight          int // configured weight
	currentWeight   int // running (temporary) weight, adjusted on every Next
	effectiveWeight int // starts at weight; Next raises it by 1 per call while it is below weight
}

// Add registers a backend from exactly two params: address and weight.
// The weight must be a non-negative base-10 integer that fits in an int.
func (r *WeightRoundRobinBalance) Add(params ...string) error {
	if len(params) != 2 {
		return errors.New("param len need 2")
	}
	// bitSize 0 parses into the platform int size, so the int() conversion
	// below cannot silently truncate on 32-bit builds.
	parInt, err := strconv.ParseInt(params[1], 10, 0)
	if err != nil {
		return err
	}
	if parInt < 0 {
		return errors.New("weight must not be negative")
	}
	node := &WeightNode{addr: params[0], weight: int(parInt)}
	node.effectiveWeight = node.weight
	r.rss = append(r.rss, node)
	return nil
}

// Next returns the address of the next node, or "" when no nodes exist.
func (r *WeightRoundRobinBalance) Next() string {
	total := 0
	var best *WeightNode
	for i := 0; i < len(r.rss); i++ {
		w := r.rss[i]
		// step 1: sum all effective weights.
		total += w.effectiveWeight
		// step 2: add the node's effective weight to its temporary weight.
		w.currentWeight += w.effectiveWeight
		// step 3: let a reduced effective weight recover by 1 per call until
		// it reaches weight again.
		if w.effectiveWeight < w.weight {
			w.effectiveWeight++
		}
		// step 4: choose the node with the largest temporary weight (the
		// first one wins ties).
		if best == nil || w.currentWeight > best.currentWeight {
			best = w
		}
	}
	if best == nil {
		return ""
	}
	// step 5: subtract the sum of effective weights from the chosen node.
	best.currentWeight -= total
	return best.addr
}

// Get ignores key, delegates to Next and never returns an error.
func (r *WeightRoundRobinBalance) Get(key string) (string, error) {
	return r.Next(), nil
}

// SetConf attaches the config that Update reads from.
func (r *WeightRoundRobinBalance) SetConf(conf LoadBalanceConf) {
	r.conf = conf
}

// Update rebuilds the node list from the attached config. Each entry is
// expected to be "addr,weight". Entries that Add rejects are reported
// instead of being silently dropped.
func (r *WeightRoundRobinBalance) Update() {
	if conf, ok := r.conf.(*LoadBalanceZkConf); ok {
		fmt.Println("WeightRoundRobinBalance get conf:", conf.GetConf())
		r.rss = nil
		for _, ip := range conf.GetConf() {
			r.addFromConf(ip)
		}
	}
	if conf, ok := r.conf.(*LoadBalanceCheckConf); ok {
		fmt.Println("WeightRoundRobinBalance get conf:", conf.GetConf())
		r.rss = nil
		for _, ip := range conf.GetConf() {
			r.addFromConf(ip)
		}
	}
}

// addFromConf adds one "addr,weight" config entry and reports it when Add
// rejects it.
func (r *WeightRoundRobinBalance) addFromConf(entry string) {
	if err := r.Add(strings.Split(entry, ",")...); err != nil {
		fmt.Println("WeightRoundRobinBalance skip conf entry:", entry, err)
	}
}
package load_balance
import (
"fmt"
"testing"
)
// TestLB checks the exact smooth weighted round-robin schedule for weights
// 4/3/2. Every 9 picks the schedule returns to its initial state, so it is
// checked over two cycles. The old version only printed picks.
func TestLB(t *testing.T) {
	rb := &WeightRoundRobinBalance{}
	nodes := []struct {
		addr   string
		weight int
	}{
		{"127.0.0.1:2003", 4},
		{"127.0.0.1:2004", 3},
		{"127.0.0.1:2005", 2},
	}
	for _, n := range nodes {
		if err := rb.Add(n.addr, fmt.Sprint(n.weight)); err != nil {
			t.Fatalf("Add(%q): %v", n.addr, err)
		}
	}
	a, b, c := nodes[0].addr, nodes[1].addr, nodes[2].addr
	want := []string{a, b, c, a, b, a, c, b, a}
	for cycle := 0; cycle < 2; cycle++ {
		for i, exp := range want {
			if got := rb.Next(); got != exp {
				t.Fatalf("cycle %d pick %d: Next() = %q, want %q", cycle, i, got, exp)
			}
		}
	}
	if err := rb.Add("127.0.0.1:2006"); err == nil {
		t.Fatal("Add without a weight should fail")
	}
}
package load_balance
import (
"errors"
"fmt"
"hash/crc32"
"sort"
"strconv"
"strings"
"sync"
)
// Hash maps a byte slice to a position on the 32-bit hash ring.
type Hash func(data []byte) uint32

// UInt32Slice is a list of ring positions that implements sort.Interface
// (ascending order).
type UInt32Slice []uint32

// Len returns the number of positions.
func (s UInt32Slice) Len() int { return len(s) }

// Less orders positions ascending.
func (s UInt32Slice) Less(i, j int) bool { return s[i] < s[j] }

// Swap exchanges positions i and j.
func (s UInt32Slice) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
// ConsistentHashBanlance maps keys to backends on a consistent-hash ring.
// Each backend is placed on the ring `replicas` times as virtual nodes. A key
// goes to the first virtual node whose hash is >= the key's hash, wrapping
// around to the start of the ring. All ring access is guarded by mux.
type ConsistentHashBanlance struct {
	mux      sync.RWMutex
	hash     Hash
	replicas int               // virtual nodes per backend
	keys     UInt32Slice       // sorted virtual-node hashes
	hashMap  map[uint32]string // virtual-node hash -> backend address
	// conf is the observed subject that Update rebuilds the ring from.
	conf LoadBalanceConf
}

// NewConsistentHashBanlance creates an empty ring with `replicas` virtual
// nodes per backend. If fn is nil, crc32.ChecksumIEEE is used.
func NewConsistentHashBanlance(replicas int, fn Hash) *ConsistentHashBanlance {
	m := &ConsistentHashBanlance{
		replicas: replicas,
		hash:     fn,
		hashMap:  make(map[uint32]string),
	}
	if m.hash == nil {
		// CRC-32 yields 32-bit values, so the ring spans 0..2^32-1.
		m.hash = crc32.ChecksumIEEE
	}
	return m
}

// IsEmpty reports whether the ring has no nodes.
func (c *ConsistentHashBanlance) IsEmpty() bool {
	c.mux.RLock()
	defer c.mux.RUnlock()
	return len(c.keys) == 0
}

// Add places backend params[0] (e.g. an IP:port) on the ring as `replicas`
// virtual nodes. Further params are ignored.
func (c *ConsistentHashBanlance) Add(params ...string) error {
	if len(params) == 0 {
		return errors.New("param len 1 at least")
	}
	addr := params[0]
	c.mux.Lock()
	defer c.mux.Unlock()
	if c.hashMap == nil {
		c.hashMap = make(map[uint32]string)
	}
	// Hash "<i><addr>" for each replica and record hash -> addr.
	for i := 0; i < c.replicas; i++ {
		hash := c.hash([]byte(strconv.Itoa(i) + addr))
		c.keys = append(c.keys, hash)
		c.hashMap[hash] = addr
	}
	// Keep keys sorted so Get can binary-search.
	sort.Sort(c.keys)
	return nil
}

// Get returns the backend responsible for key. It returns an error if the
// ring is empty.
func (c *ConsistentHashBanlance) Get(key string) (string, error) {
	// Hold the read lock for the whole lookup. keys and hashMap are mutated by
	// Add/Update and must not be read while a change is in progress.
	c.mux.RLock()
	defer c.mux.RUnlock()
	if len(c.keys) == 0 {
		return "", errors.New("node is empty")
	}
	hash := c.hash([]byte(key))
	// Binary search for the first virtual node with hash >= the key's hash.
	idx := sort.Search(len(c.keys), func(i int) bool { return c.keys[i] >= hash })
	// Past the last node: wrap around to the first one.
	if idx == len(c.keys) {
		idx = 0
	}
	return c.hashMap[c.keys[idx]], nil
}

// SetConf attaches the config that Update reads from.
func (c *ConsistentHashBanlance) SetConf(conf LoadBalanceConf) {
	c.conf = conf
}

// Update rebuilds the ring from the attached config. Each entry is split on
// ",", and Add uses only the first field. Any other config type leaves the
// ring untouched.
func (c *ConsistentHashBanlance) Update() {
	if conf, ok := c.conf.(*LoadBalanceZkConf); ok {
		fmt.Println("Update get conf:", conf.GetConf())
		// Previously hashMap was set to nil here, which made the Add below
		// panic on a nil-map write.
		c.reset()
		for _, ip := range conf.GetConf() {
			c.Add(strings.Split(ip, ",")...)
		}
	}
	if conf, ok := c.conf.(*LoadBalanceCheckConf); ok {
		fmt.Println("Update get conf:", conf.GetConf())
		c.reset()
		for _, ip := range conf.GetConf() {
			c.Add(strings.Split(ip, ",")...)
		}
	}
}

// reset empties the ring under the write lock so that Update can rebuild it.
func (c *ConsistentHashBanlance) reset() {
	c.mux.Lock()
	defer c.mux.Unlock()
	c.keys = nil
	c.hashMap = make(map[uint32]string)
}
package load_balance
import (
"fmt"
"testing"
)
// TestNewConsistentHashBanlance checks that:
//   - an empty ring reports an error,
//   - URL and IP keys map to registered backends,
//   - the same key always maps to the same backend.
//
// The old version only printed results and could never fail.
func TestNewConsistentHashBanlance(t *testing.T) {
	rb := NewConsistentHashBanlance(10, nil)
	if _, err := rb.Get("127.0.0.1"); err == nil {
		t.Fatal("Get on an empty ring should fail")
	}
	valid := make(map[string]bool)
	for port := 2003; port <= 2007; port++ {
		addr := fmt.Sprintf("127.0.0.1:%d", port)
		if err := rb.Add(addr); err != nil {
			t.Fatalf("Add(%q): %v", addr, err)
		}
		valid[addr] = true
	}
	keys := []string{
		// url hash
		"http://127.0.0.1:2002/base/getinfo",
		"http://127.0.0.1:2002/base/error",
		"http://127.0.0.1:2002/base/changepwd",
		// ip hash
		"127.0.0.1",
		"192.168.0.1",
	}
	for _, key := range keys {
		first, err := rb.Get(key)
		if err != nil {
			t.Fatalf("Get(%q): %v", key, err)
		}
		if !valid[first] {
			t.Fatalf("Get(%q) = %q, not a registered backend", key, first)
		}
		if again, _ := rb.Get(key); again != first {
			t.Fatalf("Get(%q) not stable: %q then %q", key, first, again)
		}
	}
}
package load_balance
// LbType selects a load-balancing strategy.
type LbType int

const (
	LbRandom           LbType = iota // uniform random pick
	LbRoundRobin                     // strict rotation
	LbWeightRoundRobin               // smooth weighted round robin
	LbConsistentHash                 // consistent hashing, 10 virtual nodes per backend
)

// LoadBanlanceFactory returns a new, empty balancer for lbType. Unknown types
// fall back to random selection.
func LoadBanlanceFactory(lbType LbType) LoadBalance {
	var lb LoadBalance
	switch lbType {
	case LbConsistentHash:
		lb = NewConsistentHashBanlance(10, nil)
	case LbRoundRobin:
		lb = &RoundRobinBalance{}
	case LbWeightRoundRobin:
		lb = &WeightRoundRobinBalance{}
	default: // LbRandom and any unknown value
		lb = &RandomBalance{}
	}
	return lb
}
// LoadBanlanceFactorWithConf builds a balancer for lbType and wires it to
// mConf (observer pattern), in this order:
//  1. bind the config to the balancer,
//  2. subscribe the balancer to config changes,
//  3. load the current backend list once.
//
// Unknown types fall back to random selection.
func LoadBanlanceFactorWithConf(lbType LbType, mConf LoadBalanceConf) LoadBalance {
	// Every concrete balancer is both a LoadBalance and config-bindable; this
	// local interface lets one code path wire any of them up.
	type observingBalance interface {
		LoadBalance
		SetConf(LoadBalanceConf)
	}
	var lb observingBalance
	switch lbType {
	case LbConsistentHash:
		lb = NewConsistentHashBanlance(10, nil)
	case LbRoundRobin:
		lb = &RoundRobinBalance{}
	case LbWeightRoundRobin:
		lb = &WeightRoundRobinBalance{}
	default: // LbRandom and any unknown value
		lb = &RandomBalance{}
	}
	lb.SetConf(mConf)
	mConf.Attach(lb)
	lb.Update()
	return lb
}
package load_balance
// LoadBalance is the contract shared by every balancing strategy in this package.
type LoadBalance interface {
	// Add registers a backend. The meaning of params beyond the address is
	// strategy-specific; for example, weighted round robin expects a weight.
	Add(params ...string) error
	// Get picks a backend for key. Some strategies ignore key.
	Get(key string) (string, error)

	// Update re-reads backends from an attached config (hook for service discovery).
	Update()
}
package main
import (
"bytes"
"io/ioutil"
"log"
"net"
"net/http"
"net/http/httputil"
"net/url"
"strconv"
"strings"
"time"
)
var (
	addr = "127.0.0.1:2002"
	// transport is shared by every proxied request so upstream connections
	// are pooled and reused.
	transport = &http.Transport{
		DialContext: (&net.Dialer{
			Timeout:   30 * time.Second, // connect timeout
			KeepAlive: 30 * time.Second, // TCP keep-alive period
		}).DialContext,
		MaxIdleConns:          100,              // max idle connections
		IdleConnTimeout:       90 * time.Second, // idle connection timeout
		TLSHandshakeTimeout:   10 * time.Second, // TLS handshake timeout
		ExpectContinueTimeout: 1 * time.Second,  // timeout waiting for a 100-continue status
	}
)

// NewMultipleHostsReverseProxy returns a reverse proxy that asks lb for a
// backend URL on every request, keyed by the client IP, and forwards the
// request there. Non-200 upstream bodies are prefixed with
// "StatusCode error:". Failures are answered with a 500 by the error handler.
func NewMultipleHostsReverseProxy(lb load_balance.LoadBalance) *httputil.ReverseProxy {
	director := func(req *http.Request) {
		// untarget logs the failure and clears the destination. The transport
		// then rejects the request and errFunc answers it. Calling log.Fatal
		// here would have killed the whole proxy over one request.
		untarget := func(format string, args ...interface{}) {
			log.Printf(format, args...)
			req.URL.Scheme = ""
			req.URL.Host = ""
		}
		// Key on the client IP only. RemoteAddr is "ip:port", and the
		// ephemeral port changes per connection, which would defeat IP-based
		// strategies such as consistent hashing.
		key := req.RemoteAddr
		if host, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
			key = host
		}
		nextAddr, err := lb.Get(key)
		if err != nil {
			untarget("get next addr fail: %v", err)
			return
		}
		target, err := url.Parse(nextAddr)
		if err != nil {
			untarget("parse next addr %q: %v", nextAddr, err)
			return
		}
		targetQuery := target.RawQuery
		req.URL.Scheme = target.Scheme
		req.URL.Host = target.Host
		req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
		if targetQuery == "" || req.URL.RawQuery == "" {
			req.URL.RawQuery = targetQuery + req.URL.RawQuery
		} else {
			req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
		}
		if _, ok := req.Header["User-Agent"]; !ok {
			req.Header.Set("User-Agent", "user-agent")
		}
	}

	// Rewrite the body of non-200 responses,
	// e.g. for curl 'http://127.0.0.1:2002/error'.
	modifyFunc := func(resp *http.Response) error {
		if resp.StatusCode != 200 {
			oldPayload, err := ioutil.ReadAll(resp.Body)
			if err != nil {
				return err
			}
			resp.Body.Close()
			newPayload := []byte("StatusCode error:" + string(oldPayload))
			resp.Body = ioutil.NopCloser(bytes.NewBuffer(newPayload))
			resp.ContentLength = int64(len(newPayload))
			resp.Header.Set("Content-Length", strconv.FormatInt(int64(len(newPayload)), 10))
		}
		return nil
	}

	// Error callback, e.g. when a real server is down. It covers errors from
	// transport.RoundTrip and errors returned by ModifyResponse.
	errFunc := func(w http.ResponseWriter, r *http.Request, err error) {
		// TODO: for weighted balancing, lower the node's effective weight here.
		http.Error(w, "ErrorHandler error:"+err.Error(), 500)
	}
	return &httputil.ReverseProxy{Director: director, Transport: transport, ModifyResponse: modifyFunc, ErrorHandler: errFunc}
}
// singleJoiningSlash concatenates two URL path segments so that exactly one
// "/" separates them, e.g. ("/base", "abc") and ("/base/", "/abc") both give
// "/base/abc".
func singleJoiningSlash(a, b string) string {
	aSlash, bSlash := strings.HasSuffix(a, "/"), strings.HasPrefix(b, "/")
	if aSlash && bSlash {
		return a + strings.TrimPrefix(b, "/")
	}
	if !aSlash && !bSlash {
		return a + "/" + b
	}
	return a + b
}
// main builds a weighted round-robin balancer over two backends (weights
// 10 and 20), then serves the balancing reverse proxy on addr.
func main() {
	rb := load_balance.LoadBanlanceFactory(load_balance.LbWeightRoundRobin)
	backends := [][2]string{
		{"http://127.0.0.1:2003/base", "10"},
		{"http://127.0.0.1:2004/base", "20"},
	}
	for _, be := range backends {
		if err := rb.Add(be[0], be[1]); err != nil {
			log.Println(err)
		}
	}
	proxy := NewMultipleHostsReverseProxy(rb)
	log.Println("Starting httpserver at " + addr)
	log.Fatal(http.ListenAndServe(addr, proxy))
}
- 负载均衡的 key 可以取请求 URL、客户端 IP 等多种维度（例如一致性哈希按 key 选择后端节点，实现 URL 哈希或 IP 哈希）