参照 https://colobu.com/2019/02/23/1m-go-tcp-connection/ ,使用 Go 编写,基于 epoll I/O 复用模型;多个协程借助 SO_REUSEPORT 同时监听同一端口,各自持有独立的 epoll 实例;任务交给固定大小的工作池处理,防止协程数量无限制增长。
服务端:
serverepoll.go
package main
import (
	"flag"
	"log"
	"net"
	"net/http"
	_ "net/http/pprof" // registers pprof handlers on the default mux for the :6060 server
	"os"
	"time"

	"github.com/libp2p/go-reuseport"
	"github.com/rcrowley/go-metrics"
)
var (
c = flag.Int("c", 10, "concurrency")
)
var (
opsRate = metrics.NewRegisteredMeter("ops", nil)
)
var workerPool *pool
func main() {
	// BUG FIX: the original never called flag.Parse(), so -c was ignored
	// and the default concurrency of 10 was always used.
	flag.Parse()

	//setLimit()

	// Periodically dump the metrics registry (the ops meter) to stderr.
	go metrics.Log(metrics.DefaultRegistry, 5*time.Second, log.New(os.Stderr, "metrics: ", log.Lmicroseconds))

	// pprof endpoint; requires the blank import of net/http/pprof so its
	// handlers are registered on the default mux.
	go func() {
		if err := http.ListenAndServe(":6060", nil); err != nil {
			log.Fatalf("pprof failed. %v", err)
		}
	}()

	// Shared worker pool: *c workers with a queue of 100000 tasks, bounding
	// the number of goroutines that handle connection I/O.
	workerPool = newPool(*c, 100000)
	workerPool.start()

	// *c acceptor goroutines, each with its own SO_REUSEPORT listener and
	// its own epoll instance.
	for i := 0; i < *c; i++ {
		go startEPoll()
	}
	select {} // block forever
}
// startEPoll opens a SO_REUSEPORT listener on :8972, creates a dedicated
// epoll instance serviced by its own goroutine, and registers every
// accepted connection with that instance.
func startEPoll() {
	listener, err := reuseport.Listen("tcp", ":8972")
	if err != nil {
		panic(err)
	}

	epoller, err := MkEpoll()
	if err != nil {
		panic(err)
	}
	go start(epoller)

	for {
		conn, acceptErr := listener.Accept()
		if acceptErr == nil {
			log.Printf("accept a conn")
			if addErr := epoller.Add(conn); addErr != nil {
				log.Printf("failed to add conn %v", addErr)
				conn.Close()
			}
			continue
		}
		// Temporary errors (e.g. fd exhaustion) are logged and retried;
		// anything else ends this acceptor goroutine.
		if ne, ok := acceptErr.(net.Error); ok && ne.Temporary() {
			log.Printf("accept temp err:%v", ne)
			continue
		}
		log.Printf("accept err :%v", acceptErr)
		return
	}
}
// start services one epoll instance: it waits for readable connections and
// dispatches each as a task to the shared worker pool.
func start(epoller *epoll) {
	for {
		connections, err := epoller.Wait()
		if err != nil {
			log.Printf("failed to epoll wait %v", err)
			continue
		}
		for _, conn := range connections {
			// A nil entry means the fd was removed between epoll_wait and
			// the map lookup. BUG FIX: the original used `break`, which
			// silently dropped every remaining ready connection for this
			// wake-up; skip just the nil one instead.
			if conn == nil {
				continue
			}
			workerPool.addTask(NTask{conn: conn, epoller: epoller})
		}
	}
}
workerpool.go
package main
import (
"io"
"log"
"net"
"sync"
)
// NTask is one unit of work for the pool: a readable connection together
// with the epoll instance it is registered on (needed so the handler can
// deregister it on error).
type NTask struct {
	conn net.Conn
	epoller *epoll
}
// pool is a fixed-size worker pool with a bounded task queue.
type pool struct {
	workers int // number of worker goroutines started by start
	maxTasks int // capacity of taskQueue
	taskQueue chan NTask // pending tasks, buffered to maxTasks
	mu sync.Mutex // guards closed
	closed bool // set by Close; addTask becomes a no-op afterwards
	done chan struct{} // closed by Close to signal workers to exit
}
// newPool builds a pool with w workers and a task queue of capacity t.
// Call start to launch the workers.
func newPool(w int, t int) *pool {
	p := &pool{
		workers:  w,
		maxTasks: t,
	}
	p.taskQueue = make(chan NTask, t)
	p.done = make(chan struct{})
	return p
}
// Close marks the pool closed and signals workers to exit via done.
//
// BUG FIX: the original also closed taskQueue. addTask releases the mutex
// before sending, so a concurrent sender could hit the just-closed channel
// and panic ("send on closed channel"). The queue is deliberately left open
// (it will be garbage collected); workers exit through done instead.
// Close is also made idempotent — closing done twice would panic.
func (p *pool) Close() {
	p.mu.Lock()
	defer p.mu.Unlock()
	if p.closed {
		return
	}
	p.closed = true
	close(p.done)
}
// addTask enqueues a task for the workers. It is a no-op once the pool is
// closed. BUG FIX: the original performed an unconditional channel send,
// which could block forever if the queue was full when the pool was closed;
// the select aborts the send as soon as done is closed.
func (p *pool) addTask(task NTask) {
	p.mu.Lock()
	closed := p.closed
	p.mu.Unlock()
	if closed {
		return
	}
	select {
	case p.taskQueue <- task:
	case <-p.done:
		// Pool closed while waiting for queue space; drop the task.
	}
}
// start launches the configured number of worker goroutines.
func (p *pool) start() {
	for worker := 0; worker < p.workers; worker++ {
		go p.startWorker()
	}
}
// startWorker loops, handling queued tasks until done is closed.
func (p *pool) startWorker() {
	for {
		select {
		case <-p.done:
			return
		case task := <-p.taskQueue:
			// A zero-value task (nil conn) carries no work; skip it.
			if task.conn == nil {
				continue
			}
			handleConn(task)
		}
	}
}
// handleConn echoes 8 bytes back to the client. On any read/write error the
// connection is deregistered from its epoll instance and closed.
func handleConn(task NTask) {
	_, err := io.CopyN(task.conn, task.conn, 8)
	if err != nil {
		if err := task.epoller.Remove(task.conn); err != nil {
			log.Printf("failed to remove %v", err)
		}
		task.conn.Close()
		// BUG FIX: the original fell through and still called
		// opsRate.Mark(1), counting failed exchanges as completed ops.
		return
	}
	opsRate.Mark(1)
}
epoll.go
package main
import (
"golang.org/x/sys/unix"
"log"
"net"
"reflect"
"sync"
"syscall"
)
// epoll wraps an epoll file descriptor together with a registry mapping
// socket fds back to their net.Conn values.
type epoll struct {
	fd int // epoll instance fd from EpollCreate1
	connections map[int]net.Conn // registered conns, keyed by socket fd
	lock *sync.RWMutex // guards connections
}
// MkEpoll creates a new epoll instance with an empty connection registry.
func MkEpoll() (*epoll, error) {
	epfd, err := unix.EpollCreate1(0)
	if err != nil {
		return nil, err
	}
	e := &epoll{
		fd:          epfd,
		lock:        &sync.RWMutex{},
		connections: make(map[int]net.Conn),
	}
	return e, nil
}
// Add registers conn with the epoll instance (read + hangup events,
// level-triggered) and records it in the fd→conn map.
func (e *epoll) Add(conn net.Conn) error {
	fd := socketFD(conn)
	err := unix.EpollCtl(e.fd, syscall.EPOLL_CTL_ADD, fd, &unix.EpollEvent{Events: unix.POLLIN | unix.POLLHUP, Fd: int32(fd)})
	if err != nil {
		// BUG FIX: the original did `return nil` here, silently swallowing
		// the EpollCtl failure — callers would believe the conn was being
		// watched when it was not.
		return err
	}
	e.lock.Lock()
	defer e.lock.Unlock()
	e.connections[fd] = conn
	if len(e.connections)%100 == 0 {
		// typo "tatal" fixed
		log.Printf("total number of connections: %v ", len(e.connections))
	}
	return nil
}
// Remove deregisters conn from the epoll instance and drops it from the
// fd→conn map.
func (e *epoll) Remove(conn net.Conn) error {
	fd := socketFD(conn)
	if err := unix.EpollCtl(e.fd, syscall.EPOLL_CTL_DEL, fd, nil); err != nil {
		return err
	}
	e.lock.Lock()
	delete(e.connections, fd)
	remaining := len(e.connections)
	e.lock.Unlock()
	if remaining%100 == 0 {
		log.Printf("total number of connections:%v", remaining)
	}
	return nil
}
// Wait blocks (for at most 100 ms) for readiness events on the epoll fd and
// returns the corresponding connections. Entries may be nil when an fd was
// removed from the registry between epoll_wait and the map lookup.
func (e *epoll) Wait() ([]net.Conn, error) {
	events := make([]unix.EpollEvent, 100)
	var n int
	for {
		var err error
		n, err = unix.EpollWait(e.fd, events, 100)
		if err == unix.EINTR {
			// epoll_wait is routinely interrupted by signals; retry instead
			// of surfacing a spurious error to the caller (the original
			// returned EINTR, causing needless "failed to epoll wait" logs).
			continue
		}
		if err != nil {
			return nil, err
		}
		break
	}
	e.lock.RLock()
	defer e.lock.RUnlock()
	connections := make([]net.Conn, 0, n)
	for i := 0; i < n; i++ {
		connections = append(connections, e.connections[int(events[i].Fd)])
	}
	return connections, nil
}
// socketFD extracts the underlying OS socket fd from a net.Conn by
// reflecting on unexported standard-library internals
// (net.TCPConn → net.conn → netFD → poll.FD → Sysfd).
// NOTE(review): fragile — it depends on the private layout of the net
// package and can break across Go releases.
func socketFD(conn net.Conn) int {
	tcpConn := reflect.Indirect(reflect.ValueOf(conn)).FieldByName("conn")//the embedded net.conn struct; Indirect dereferences the *net.TCPConn pointer
	fdVal := tcpConn.FieldByName("fd")//the *netFD field inside net.conn
	pfdVal := reflect.Indirect(fdVal).FieldByName("pfd")//the internal/poll.FD struct
	return int(pfdVal.FieldByName("Sysfd").Int())//Sysfd holds the actual socket fd
}
客户端:
clientepoll.go
package main
import (
"encoding/binary"
"flag"
"fmt"
"log"
"net"
"os"
"time"
"github.com/rcrowley/go-metrics"
)
// timeLayout is the reference layout for the -sm flag.
// BUG FIX: the original used "2020-11-17T10:00:00 -0700" both as Format and
// Parse layout, but Go layouts must be written in terms of the reference
// time "Mon Jan 2 15:04:05 MST 2006" — the old string was not a valid
// layout, so user-supplied start times could not be parsed correctly.
const timeLayout = "2006-01-02T15:04:05 -0700"

var (
	ip          = flag.String("ip", "127.0.0.1", "server IP")
	connections = flag.Int("conn", 1, "number of tcp connections")
	startMetric = flag.String("sm", time.Now().Format(timeLayout), "start time point of all clients")
)

var (
	opsRate = metrics.NewRegisteredTimer("ops", nil)
)

func main() {
	flag.Parse()
	//setLimit()

	// Sleep until the shared start point, then begin logging metrics so that
	// all client processes report over the same window.
	go func() {
		startPoint, err := time.Parse(timeLayout, *startMetric)
		if err != nil {
			panic(err)
		}
		time.Sleep(time.Until(startPoint))
		metrics.Log(metrics.DefaultRegistry, 5*time.Second, log.New(os.Stderr, "metrics: ", log.Lmicroseconds))
	}()

	addr := *ip + ":8972"
	log.Printf("connect to %s", addr)

	// Spread the connections over multiple epoll instances.
	for i := 0; i < 4; i++ {
		go mkClient(addr, *connections/4)
	}
	select {} // block forever
}
// mkClient dials `connections` TCP connections to addr, registers them with
// a fresh epoll instance, starts the reader goroutine, and then sends the
// first timestamp on every connection (paced to avoid a thundering herd).
func mkClient(addr string, connections int) {
	epoller, err := MkEpoll()
	if err != nil {
		panic(err)
	}

	var conns []net.Conn
	for i := 0; i < connections; i++ {
		c, err := net.DialTimeout("tcp", addr, 10*time.Second)
		if err != nil {
			fmt.Println("failed to connect", i, err)
			i-- // retry this slot; NOTE: loops indefinitely if the server stays down
			continue
		}
		if err := epoller.Add(c); err != nil {
			log.Printf("failed to add connection %v", err)
			c.Close()
			// BUG FIX: the original still appended the closed connection to
			// conns, so the write loop below would operate on a dead conn.
			continue
		}
		conns = append(conns, c)
	}
	log.Printf("init connections %d count", len(conns))

	go start(epoller)

	// Pacing between initial writes: 1 s for small runs, 5 ms for large ones.
	tts := time.Second
	if connections > 100 {
		tts = time.Millisecond * 5
	}
	for _, conn := range conns {
		time.Sleep(tts)
		// Send the current time; the echo server reflects it back and the
		// reader goroutine measures the round trip.
		if err := binary.Write(conn, binary.BigEndian, time.Now().UnixNano()); err != nil {
			log.Printf("failed to write timestamp %v", err)
			// BUG FIX: the original contained a nested copy-paste that called
			// epoller.Remove twice and never closed the failed connection.
			if err := epoller.Remove(conn); err != nil {
				log.Printf("failed to remove %v", err)
			}
			conn.Close()
		}
	}
	select {} // keep the connections alive
}
// start services the client-side epoll loop: for every readable connection
// it reads the echoed timestamp, records the round-trip latency, and writes
// a fresh timestamp to keep the ping-pong going. Failed connections are
// deregistered from epoll and closed.
func start(epoller *epoll) {
	var nano int64
	for {
		connections, err := epoller.Wait()
		if err != nil {
			log.Printf("failed to epoll wait %v", err)
			continue
		}
		for _, conn := range connections {
			// A nil entry means the fd was removed concurrently. BUG FIX:
			// the original used `break`, dropping every remaining ready
			// connection for this wake-up.
			if conn == nil {
				continue
			}
			if err := binary.Read(conn, binary.BigEndian, &nano); err != nil {
				log.Printf("failed to read %v", err)
				if err := epoller.Remove(conn); err != nil {
					log.Printf("failed to remove %v", err)
				}
				conn.Close()
				continue
			}
			// Round-trip latency = now − the timestamp we sent earlier.
			opsRate.Update(time.Duration(time.Now().UnixNano() - nano))
			if err := binary.Write(conn, binary.BigEndian, time.Now().UnixNano()); err != nil {
				log.Printf("failed to write %v", err)
				if err := epoller.Remove(conn); err != nil {
					log.Printf("failed to remove %v", err)
				}
				conn.Close()
			}
		}
	}
}