支持高并发的基于epoll的go服务器模型

参照 https://colobu.com/2019/02/23/1m-go-tcp-connection/ ,使用 Go 编写。服务端基于 epoll I/O 多路复用模型:多个协程通过端口复用(SO_REUSEPORT)同时监听同一端口,每个协程各自持有一个 epoll 实例;就绪连接的读写交给固定大小的工作池处理,从而避免协程数量无限制增长。

服务端:

serverepoll.go

package main

import (
   "flag"
   "github.com/libp2p/go-reuseport"
   "log"
   "net"
   "net/http"
   "github.com/rcrowley/go-metrics"
   "os"
   "time"
)

var (
	// c is the -c flag. main uses it both as the number of listener
	// goroutines (startEPoll, one reuseport listener + epoll each) and as
	// the number of worker goroutines in workerPool.
	c = flag.Int("c", 10, "concurrency")

	// opsRate is a meter registered under the name "ops" (nil registry
	// argument); handleConn marks it.
	opsRate = metrics.NewRegisteredMeter("ops", nil)

	// workerPool is the shared pool that every epoll loop submits ready
	// connections to; it is created and started in main.
	workerPool *pool
)

// main starts the echo server: a metrics logger, an HTTP server on :6060,
// a worker pool with *c workers, and *c listener goroutines, each of which
// owns its own reuseport listener and epoll instance on :8972.
func main() {
	// Without this the -c flag is never read and *c is always the default.
	flag.Parse()
	//setLimit()
	go metrics.Log(metrics.DefaultRegistry, 5*time.Second, log.New(os.Stderr, "metrics: ", log.Lmicroseconds))

	// NOTE(review): a nil handler serves http.DefaultServeMux; the pprof
	// endpoints are only registered there if net/http/pprof is imported
	// somewhere in the program — confirm that import exists.
	go func() {
		if err := http.ListenAndServe(":6060", nil); err != nil {
			log.Fatalf("pprof failed. %v", err)
		}
	}()

	workerPool = newPool(*c, 100000)
	workerPool.start()

	for i := 0; i < *c; i++ {
		go startEPoll()
	}

	select {} // block forever; all work happens in goroutines
}

// startEPoll opens a SO_REUSEPORT listener on :8972, creates a dedicated
// epoll instance, starts its event loop (start), and then accepts
// connections forever, registering each one with that epoll instance.
// It panics if the listener or epoll instance cannot be created and
// returns on a non-temporary accept error.
func startEPoll() {
	ln, err := reuseport.Listen("tcp", ":8972")
	if err != nil {
		panic(err)
	}
	// Release the listening socket if the accept loop gives up.
	defer ln.Close()

	epoller, err := MkEpoll()
	if err != nil {
		panic(err)
	}

	go start(epoller)

	for {
		conn, err := ln.Accept()
		if err != nil {
			// NOTE(review): net.Error.Temporary is deprecated; kept to
			// preserve the existing retry behavior.
			if ne, ok := err.(net.Error); ok && ne.Temporary() {
				log.Printf("accept temp err:%v", ne)
				continue
			}
			log.Printf("accept err :%v", err)
			return
		}
		log.Printf("accept a conn")
		if err := epoller.Add(conn); err != nil {
			log.Printf("failed to add conn %v", err)
			conn.Close()
		}
	}
}

func start(epoller *epoll) {
   for {
      connections, err := epoller.Wait()
      if err != nil {
         log.Printf("failed to epoll wait %v", err)
         continue
      }

      for _, conn := range connections {
         if conn == nil {
            break
         }

         //_, err = io.CopyN(conn, conn, 8)
         //if err != nil {
         // if err := epoller.Remove(conn); err != nil {
         //    log.Printf("failed to remove %v", err)
         // }
         // conn.Close()
         //}
         //opsRate.Mark(1)
         //log.Printf("recv conn buf %s", string(buf))
         var task NTask
         task.conn = conn
         task.epoller = epoller
         workerPool.addTask(task)
      }
   }
}

workerpool.go

package main

import (
   "io"
   "log"
   "net"
   "sync"
)

// NTask is one unit of work for the pool: a ready connection together with
// the epoll instance it is registered with (so the worker can deregister it).
type NTask struct {
	conn    net.Conn // connection reported ready by epoll
	epoller *epoll   // owning epoll instance, used by handleConn for Remove
}

// pool is a fixed-size set of worker goroutines fed through a buffered
// task channel (see newPool, start, addTask, Close).
type pool struct {
	workers   int        // number of goroutines launched by start
	maxTasks  int        // buffer capacity given to taskQueue by newPool
	taskQueue chan NTask // pending tasks consumed by startWorker

	mu     sync.Mutex    // guards closed
	closed bool          // set by Close; addTask drops tasks once true
	done   chan struct{} // closed by Close to make workers return
}

// newPool builds an unstarted pool with w workers and a task queue that can
// buffer up to t tasks. Call start to launch the workers.
func newPool(w int, t int) *pool {
	p := new(pool)
	p.workers = w
	p.maxTasks = t
	p.taskQueue = make(chan NTask, t)
	p.done = make(chan struct{})
	return p
}

// Close stops the pool: it marks it closed and closes done so that every
// worker returns. Calling Close more than once is a no-op.
//
// taskQueue is deliberately not closed. addTask may be blocked sending on
// it from another goroutine, and a send on a closed channel panics; workers
// already exit via done.
func (p *pool) Close() {
	p.mu.Lock()
	defer p.mu.Unlock()
	if p.closed {
		return
	}
	p.closed = true
	close(p.done)
}

// addTask enqueues task for a worker, blocking while the queue is full.
// Tasks submitted after Close are dropped. A send that is still blocked when
// Close is called gives up instead of hanging forever.
func (p *pool) addTask(task NTask) {
	p.mu.Lock()
	closed := p.closed
	p.mu.Unlock()
	if closed {
		return
	}

	select {
	case p.taskQueue <- task:
	case <-p.done:
	}
}

// start launches p.workers goroutines, each running startWorker.
func (p *pool) start() {
	for remaining := p.workers; remaining > 0; remaining-- {
		go p.startWorker()
	}
}

// startWorker is the body of one worker goroutine. It processes tasks from
// taskQueue until done is closed.
func (p *pool) startWorker() {
	for {
		var task NTask
		select {
		case <-p.done:
			return
		case task = <-p.taskQueue:
		}

		// A task without a connection (e.g. the zero value received from a
		// closed queue) has nothing to handle.
		if task.conn == nil {
			continue
		}
		handleConn(task)
	}
}

// handleConn echoes exactly 8 bytes from task.conn back to the same
// connection (io.CopyN with the conn as both writer and reader).
//
// On any error, including the peer closing the connection, the connection
// is removed from its epoll instance and closed. Only successful echoes are
// counted in opsRate; previously a failed, already-closed connection was
// counted too.
func handleConn(task NTask) {
	if _, err := io.CopyN(task.conn, task.conn, 8); err != nil {
		if err := task.epoller.Remove(task.conn); err != nil {
			log.Printf("failed to remove %v", err)
		}
		task.conn.Close()
		return
	}
	opsRate.Mark(1)
}

epoll.go

package main

import (
   "golang.org/x/sys/unix"
   "log"
   "net"
   "reflect"
   "sync"
   "syscall"
)

// epoll wraps one Linux epoll instance together with a registry that maps
// each registered socket fd back to its net.Conn.
type epoll struct {
	fd          int              // epoll instance descriptor from EpollCreate1
	connections map[int]net.Conn // socket fd -> conn; guarded by lock
	lock        *sync.RWMutex    // Add/Remove take it exclusively, Wait shared
}

// MkEpoll creates a new epoll instance (epoll_create1 with no flags) with an
// empty connection registry. The error from EpollCreate1 is returned as-is.
func MkEpoll() (*epoll, error) {
	epfd, err := unix.EpollCreate1(0)
	if err != nil {
		return nil, err
	}

	ep := &epoll{
		fd:          epfd,
		connections: make(map[int]net.Conn),
		lock:        new(sync.RWMutex),
	}
	return ep, nil
}

// Add registers conn's socket for readability and hang-up notifications and
// records it in the fd registry. It logs the registry size every 100
// connections.
//
// It returns the EpollCtl error if registration fails. Previously that error
// was swallowed (return nil), so callers never closed a conn that was not
// being watched.
func (e *epoll) Add(conn net.Conn) error {
	fd := socketFD(conn)
	// Use the epoll event constants rather than the poll(2) POLLIN/POLLHUP
	// ones; they share values on Linux, but these are the correct API.
	event := unix.EpollEvent{Events: unix.EPOLLIN | unix.EPOLLHUP, Fd: int32(fd)}
	if err := unix.EpollCtl(e.fd, syscall.EPOLL_CTL_ADD, fd, &event); err != nil {
		return err
	}

	e.lock.Lock()
	defer e.lock.Unlock()
	e.connections[fd] = conn

	if len(e.connections)%100 == 0 {
		log.Printf("total number of connections: %v", len(e.connections))
	}
	return nil
}

// Remove deregisters conn's socket from the epoll instance and drops it
// from the fd registry. If EpollCtl fails, that error is returned and the
// registry is left untouched. It logs the registry size whenever it is a
// multiple of 100.
func (e *epoll) Remove(conn net.Conn) error {
	fd := socketFD(conn)
	if err := unix.EpollCtl(e.fd, syscall.EPOLL_CTL_DEL, fd, nil); err != nil {
		return err
	}

	e.lock.Lock()
	defer e.lock.Unlock()

	delete(e.connections, fd)
	if remaining := len(e.connections); remaining%100 == 0 {
		log.Printf("total number of connections:%v", remaining)
	}
	return nil
}

func (e *epoll) Wait() ([]net.Conn, error) {
   events := make([]unix.EpollEvent, 100)
   n, err := unix.EpollWait(e.fd, events, 100)
   if err != nil {
      return nil, err
   }

   e.lock.RLock()
   defer e.lock.RUnlock()

   var connections []net.Conn
   for i:=0; i < n; i++ {
      conn := e.connections[int(events[i].Fd)]
      connections = append(connections, conn)
   }
   return connections, nil
}

// socketFD returns the OS socket descriptor behind conn by reading
// unexported standard-library fields via reflection:
//
//	*net.TCPConn -> conn -> fd (*netFD) -> pfd (poll.FD) -> Sysfd
//
// NOTE(review): this relies on Go runtime internals and may break across Go
// versions or for conn types that do not have this layout. The supported
// alternative is syscall.Conn.SyscallConn().Control.
func socketFD(conn net.Conn) int {
	// Dereference the pointer held in the interface to reach the struct.
	outer := reflect.Indirect(reflect.ValueOf(conn))
	netFD := outer.FieldByName("conn").FieldByName("fd")
	pollFD := reflect.Indirect(netFD).FieldByName("pfd")
	sysfd := pollFD.FieldByName("Sysfd")
	return int(sysfd.Int())
}
客户端:

clientepoll.go

package main

import (
   "encoding/binary"
   "flag"
   "fmt"
   "log"
   "net"
   "os"
   "time"
   "github.com/rcrowley/go-metrics"
)

// startTimeLayout is the layout for the -sm flag. Go layouts are written in
// terms of the reference time Mon Jan 2 15:04:05 MST 2006. The previous
// literal "2020-11-17T10:00:00 -0700" was not a valid layout, so user-supplied
// timestamps could not be parsed.
const startTimeLayout = "2006-01-02T15:04:05 -0700"

var (
	ip          = flag.String("ip", "127.0.0.1", "server IP")
	connections = flag.Int("conn", 1, "number of tcp connections")
	startMetric = flag.String("sm", time.Now().Format(startTimeLayout), "start time point of all clients")
)

var (
	// opsRate records round-trip latency (now minus the echoed timestamp).
	opsRate = metrics.NewRegisteredTimer("ops", nil)
)

// main parses flags, schedules metrics logging to begin at -sm, and spreads
// -conn connections across 4 independent epoll-driven clients.
func main() {
	flag.Parse()
	//setLimit()

	go func() {
		startPoint, err := time.Parse(startTimeLayout, *startMetric)
		if err != nil {
			panic(err)
		}

		// A start point in the past yields a non-positive duration and
		// Sleep returns immediately.
		time.Sleep(time.Until(startPoint))
		metrics.Log(metrics.DefaultRegistry, 5*time.Second, log.New(os.Stderr, "metrics: ", log.Lmicroseconds))
	}()

	addr := *ip + ":8972"
	log.Printf("connect to %s", addr)

	// Use several epoll instances. Distribute the remainder so that every
	// requested connection is opened (plain division dropped it, e.g.
	// -conn 3 opened none).
	const clients = 4
	for i := 0; i < clients; i++ {
		n := *connections / clients
		if i < *connections%clients {
			n++
		}
		if n == 0 {
			continue
		}
		go mkClient(addr, n)
	}
	select {} // block forever
}

func mkClient(addr string, connections int) {
   epoller, err := MkEpoll()
   if err != nil {
      panic(err)
   }
   var conns []net.Conn
   for i := 0; i < connections; i++ {
      c, err := net.DialTimeout("tcp", addr, 10*time.Second)
      if err != nil {
         fmt.Println("failed to connect", i, err)
         i--
         continue
      }
      if err := epoller.Add(c); err != nil {
         log.Printf("failed to add connection %v", err)
         c.Close()
      }
      conns = append(conns, c)
   }
   log.Printf("init connections %d count", len(conns))
   go start(epoller)
   tts := time.Second
   if connections > 100 {
      tts = time.Millisecond * 5
   }
   for i := 0; i < len(conns); i++ {
      time.Sleep(tts)
      conn := conns[i]
      err = binary.Write(conn, binary.BigEndian, time.Now().UnixNano())
      if err != nil {
         log.Printf("failed to write timestamp %v", err)
         if err := epoller.Remove(conn); err != nil {
            if err := epoller.Remove(conn); err != nil {
               log.Printf("failed to remove %v", err)
            }
         }
      }
   }
   select {}
}

// start is the client event loop. For every ready connection it reads the
// 8-byte big-endian timestamp echoed by the server, records the round-trip
// latency in opsRate, and writes a fresh timestamp to keep the ping-pong
// going. A connection is deregistered and closed on any read or write error.
func start(epoller *epoll) {
	var nano int64
	for {
		connections, err := epoller.Wait()
		if err != nil {
			log.Printf("failed to epoll wait %v", err)
			continue
		}

		for _, conn := range connections {
			// A nil entry only affects this slot. `break` here used to drop
			// every remaining ready connection in the batch.
			if conn == nil {
				continue
			}

			if err := binary.Read(conn, binary.BigEndian, &nano); err != nil {
				log.Printf("failed to read %v", err)
				if err := epoller.Remove(conn); err != nil {
					log.Printf("failed to remove %v", err)
				}
				conn.Close()
				continue
			}
			opsRate.Update(time.Duration(time.Now().UnixNano() - nano))

			if err := binary.Write(conn, binary.BigEndian, time.Now().UnixNano()); err != nil {
				log.Printf("failed to write %v", err)
				if err := epoller.Remove(conn); err != nil {
					log.Printf("failed to remove %v", err)
				}
				conn.Close()
			}
		}
	}
}

 

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值