Go语言直接使用Windows的IOCP API写一个echo服务器

6 篇文章 0 订阅

Go 标准库在 Windows 下的网络实现是基于 IOCP 的(参见 Go 源码 src/runtime/netpoll_windows.go);为了与 epoll、kqueue 等其他平台的 IO 模型共用一套统一的 API,标准库对其进行了封装。

如果想直接使用 Windows 的 IOCP API 编程,比如按照《Windows下的高效网络模型IOCP完整示例》中的流程来写,就需要自行封装 IOCP 相关的 API。标准库 syscall 虽然封装了很多系统调用,但并不全,而且部分 API 的函数签名也有问题:completion key 被声明成了 uint32,在 64 位 Windows 上装不下指针大小的值,比如:

// Excerpt from Go's syscall package (Windows). The completion key is taken
// as uint32 and only widened to uintptr when forwarded to the unexported
// createIoCompletionPort, so a pointer-sized key cannot be passed through
// this exported wrapper on 64-bit Windows.
//
// Deprecated: CreateIoCompletionPort has the wrong function signature. Use x/sys/windows.CreateIoCompletionPort.
func CreateIoCompletionPort(filehandle Handle, cphandle Handle, key uint32, threadcnt uint32) (Handle, error) {
	return createIoCompletionPort(filehandle, cphandle, uintptr(key), threadcnt)
}

// Excerpt from Go's syscall package (Windows). The unexported
// getQueuedCompletionStatus reports the completion key as a uintptr; this
// wrapper narrows it to the caller's uint32 and, when that narrowing loses
// bits and the call otherwise succeeded, returns a "key overflow" error.
//
// Deprecated: GetQueuedCompletionStatus has the wrong function signature. Use x/sys/windows.GetQueuedCompletionStatus.
func GetQueuedCompletionStatus(cphandle Handle, qty *uint32, key *uint32, overlapped **Overlapped, timeout uint32) error {
	var ukey uintptr
	var pukey *uintptr
	// Only hand a key pointer to the inner call when the caller asked for one.
	if key != nil {
		ukey = uintptr(*key)
		pukey = &ukey
	}
	err := getQueuedCompletionStatus(cphandle, qty, pukey, overlapped, timeout)
	if key != nil {
		*key = uint32(ukey)
		// Round-trip check: detects keys that do not fit in 32 bits.
		if uintptr(*key) != ukey && err == nil {
			err = errorspkg.New("GetQueuedCompletionStatus returned key overflow")
		}
	}
	return err
}

// Excerpt from Go's syscall package (Windows): widens the uint32 key to
// uintptr and forwards to the unexported postQueuedCompletionStatus.
//
// Deprecated: PostQueuedCompletionStatus has the wrong function signature. Use x/sys/windows.PostQueuedCompletionStatus.
func PostQueuedCompletionStatus(cphandle Handle, qty uint32, key uint32, overlapped *Overlapped) error {
	return postQueuedCompletionStatus(cphandle, qty, uintptr(key), overlapped)
}

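看了一下,这些导出函数内部调用的未导出函数(如 createIoCompletionPort)签名其实是没问题的,completion key 使用的是 uintptr,因此可以用 Go 的编译指令 go:linkname 直接链接到它们(注意:从 Go 1.23 起,链接器默认禁止通过 go:linkname 引用标准库中未显式标记的内部符号,届时可能需要加上 -ldflags=-checklinkname=0 才能编译,更稳妥的做法是改用 golang.org/x/sys/windows 中签名正确的同名函数):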

// The declarations below have no Go body; //go:linkname binds each one to the
// unexported function in package syscall that takes the completion key as a
// full uintptr.
//
// NOTE(review): since Go 1.23 the linker rejects //go:linkname references to
// standard-library internals that are not explicitly marked for it; building
// this may need -ldflags=-checklinkname=0, or switch to golang.org/x/sys/windows.

// CreateIoCompletionPort is bound to syscall.createIoCompletionPort.
//go:linkname CreateIoCompletionPort syscall.createIoCompletionPort
func CreateIoCompletionPort(fileHandle syscall.Handle, cpHandle syscall.Handle, key uintptr, threadCnt uint32) (handle syscall.Handle, err error)

// GetQueuedCompletionStatus is bound to syscall.getQueuedCompletionStatus.
//go:linkname GetQueuedCompletionStatus syscall.getQueuedCompletionStatus
func GetQueuedCompletionStatus(cpHandle syscall.Handle, qty *uint32, key *uintptr, overlapped **syscall.Overlapped, timeout uint32) (err error)

// PostQueuedCompletionStatus is bound to syscall.postQueuedCompletionStatus.
//go:linkname PostQueuedCompletionStatus syscall.postQueuedCompletionStatus
func PostQueuedCompletionStatus(cphandle syscall.Handle, qty uint32, key uintptr, overlapped *syscall.Overlapped) (err error)

另外还需要用到一些标准库没有封装的 API,比如 WSACreateEvent、WSAWaitForMultipleEvents、WSAResetEvent、WSAGetOverlappedResult,这些就需要自行从 Ws2_32.dll 中加载了:

// Winsock 2 entry points that package syscall does not export, loaded by name
// from Ws2_32.dll through syscall.LazyDLL / LazyProc.
var (
	// modws2_32 is the Winsock 2 DLL the procedures below are resolved from.
	modws2_32 = syscall.NewLazyDLL("Ws2_32.dll")

	procWSACreateEvent           = modws2_32.NewProc("WSACreateEvent")
	procWSAWaitForMultipleEvents = modws2_32.NewProc("WSAWaitForMultipleEvents")
	procWSAResetEvent            = modws2_32.NewProc("WSAResetEvent")
	procWSAGetOverlappedResult   = modws2_32.NewProc("WSAGetOverlappedResult")
)

// WSACreateEvent calls ws2_32!WSACreateEvent and returns the new event handle.
// A zero handle from the API is reported as an error built from the
// thread's last error value; in that case the returned handle is 0.
func WSACreateEvent() (Handle syscall.Handle, err error) {
	r1, _, e1 := syscall.SyscallN(procWSACreateEvent.Addr())
	if r1 == 0 {
		err = errnoErr(e1)
	}
	// Return err (not nil) so callers actually see a creation failure.
	return syscall.Handle(r1), err
}

// WSAWaitForMultipleEvents calls ws2_32!WSAWaitForMultipleEvents on the
// cEvents event handles stored contiguously starting at lpEvent, and returns
// the API's DWORD result. An error is returned only when that result is
// WAIT_FAILED.
//
// The API returns a DWORD, and on x64 the upper 32 bits of the return
// register are undefined for such values, so r1 is truncated to uint32
// before it is compared with WAIT_FAILED or handed to the caller.
func WSAWaitForMultipleEvents(cEvents uint32, lpEvent *syscall.Handle, fWaitAll bool, dwTimeout uint32, fAlertable bool) (uint32, error) {
	var waitAll, alertable uintptr
	if fWaitAll {
		waitAll = 1
	}
	if fAlertable {
		alertable = 1
	}
	r1, _, e1 := syscall.SyscallN(procWSAWaitForMultipleEvents.Addr(), uintptr(cEvents), uintptr(unsafe.Pointer(lpEvent)), waitAll, uintptr(dwTimeout), alertable)
	ret := uint32(r1)
	if ret == syscall.WAIT_FAILED {
		return 0, errnoErr(e1)
	}
	return ret, nil
}

// WSAResetEvent calls ws2_32!WSAResetEvent on handle and turns a FALSE
// result into an error derived from the thread's last error value.
func WSAResetEvent(handle syscall.Handle) (err error) {
	ok, _, lastErr := syscall.SyscallN(procWSAResetEvent.Addr(), uintptr(handle))
	if ok != 0 {
		return nil
	}
	return errnoErr(lastErr)
}

// WSAGetOverlappedResult calls ws2_32!WSAGetOverlappedResult for the
// overlapped operation identified by socket/overlapped. The transferred byte
// count and result flags are written through transferBytes and flag; bWait
// is forwarded as the API's fWait BOOL. A FALSE result becomes an error.
func WSAGetOverlappedResult(socket syscall.Handle, overlapped *syscall.Overlapped, transferBytes *uint32, bWait bool, flag *uint32) (err error) {
	wait := uintptr(0)
	if bWait {
		wait = 1
	}
	ok, _, lastErr := syscall.SyscallN(
		procWSAGetOverlappedResult.Addr(),
		uintptr(socket),
		uintptr(unsafe.Pointer(overlapped)),
		uintptr(unsafe.Pointer(transferBytes)),
		wait,
		uintptr(unsafe.Pointer(flag)),
	)
	if ok != 0 {
		return nil
	}
	return errnoErr(lastErr)
}

笔者尝试了一下,这种方式完全可行:

(运行效果截图)

直接附上源码:

package main

import (
	"errors"
	"fmt"
	"os"
	"runtime"
	"sync"
	"syscall"
	"unsafe"
	_ "unsafe"
)

// The declarations below have no Go body; //go:linkname binds each one to an
// unexported function in package syscall. The IOCP ones take the completion
// key as a full uintptr, unlike syscall's exported (deprecated) wrappers,
// which use uint32.
//
// NOTE(review): since Go 1.23 the linker rejects //go:linkname references to
// standard-library internals that are not explicitly marked for it; building
// this may need -ldflags=-checklinkname=0, or switch to golang.org/x/sys/windows.

// CreateIoCompletionPort is bound to syscall.createIoCompletionPort.
//go:linkname CreateIoCompletionPort syscall.createIoCompletionPort
func CreateIoCompletionPort(fileHandle syscall.Handle, cpHandle syscall.Handle, key uintptr, threadCnt uint32) (handle syscall.Handle, err error)

// GetQueuedCompletionStatus is bound to syscall.getQueuedCompletionStatus.
//go:linkname GetQueuedCompletionStatus syscall.getQueuedCompletionStatus
func GetQueuedCompletionStatus(cpHandle syscall.Handle, qty *uint32, key *uintptr, overlapped **syscall.Overlapped, timeout uint32) (err error)

// PostQueuedCompletionStatus is bound to syscall.postQueuedCompletionStatus.
//go:linkname PostQueuedCompletionStatus syscall.postQueuedCompletionStatus
func PostQueuedCompletionStatus(cphandle syscall.Handle, qty uint32, key uintptr, overlapped *syscall.Overlapped) (err error)

// errnoErr is bound to syscall's unexported errnoErr; the Winsock wrappers in
// this file use it to turn the Errno from SyscallN into an error value.
//go:linkname errnoErr syscall.errnoErr
func errnoErr(e syscall.Errno) error

// Winsock 2 entry points that package syscall does not export, loaded by name
// from Ws2_32.dll through syscall.LazyDLL / LazyProc.
var (
	// modws2_32 is the Winsock 2 DLL the procedures below are resolved from.
	modws2_32 = syscall.NewLazyDLL("Ws2_32.dll")

	procWSACreateEvent           = modws2_32.NewProc("WSACreateEvent")
	procWSACloseEvent            = modws2_32.NewProc("WSACloseEvent")
	procWSAWaitForMultipleEvents = modws2_32.NewProc("WSAWaitForMultipleEvents")
	procWSAResetEvent            = modws2_32.NewProc("WSAResetEvent")
	procWSAGetOverlappedResult   = modws2_32.NewProc("WSAGetOverlappedResult")
)

// WSACreateEvent calls ws2_32!WSACreateEvent. It always returns the raw
// handle value; when that value is 0 the call failed and err carries the
// thread's last error.
func WSACreateEvent() (handle syscall.Handle, err error) {
	ret, _, lastErr := syscall.SyscallN(procWSACreateEvent.Addr())
	handle = syscall.Handle(ret)
	if handle == 0 {
		err = errnoErr(lastErr)
	}
	return handle, err
}

// WSACloseEvent calls ws2_32!WSACloseEvent on handle and turns a FALSE
// result into an error derived from the thread's last error value.
func WSACloseEvent(handle syscall.Handle) (err error) {
	if ok, _, lastErr := syscall.SyscallN(procWSACloseEvent.Addr(), uintptr(handle)); ok == 0 {
		return errnoErr(lastErr)
	}
	return nil
}

// WSAResetEvent calls ws2_32!WSAResetEvent on handle and turns a FALSE
// result into an error derived from the thread's last error value.
func WSAResetEvent(handle syscall.Handle) (err error) {
	ok, _, lastErr := syscall.SyscallN(procWSAResetEvent.Addr(), uintptr(handle))
	if ok != 0 {
		return nil
	}
	return errnoErr(lastErr)
}

// WSAWaitForMultipleEvents calls ws2_32!WSAWaitForMultipleEvents on the
// cEvents event handles stored contiguously starting at lpEvent, and returns
// the API's DWORD result. An error is returned only when that result is
// WAIT_FAILED.
//
// The API returns a DWORD, and on x64 the upper 32 bits of the return
// register are undefined for such values, so r1 is truncated to uint32
// before it is compared with WAIT_FAILED or handed to the caller.
func WSAWaitForMultipleEvents(cEvents uint32, lpEvent *syscall.Handle, fWaitAll bool, dwTimeout uint32, fAlertable bool) (uint32, error) {
	var waitAll, alertable uintptr
	if fWaitAll {
		waitAll = 1
	}
	if fAlertable {
		alertable = 1
	}
	r1, _, e1 := syscall.SyscallN(procWSAWaitForMultipleEvents.Addr(), uintptr(cEvents), uintptr(unsafe.Pointer(lpEvent)), waitAll, uintptr(dwTimeout), alertable)
	ret := uint32(r1)
	if ret == syscall.WAIT_FAILED {
		return 0, errnoErr(e1)
	}
	return ret, nil
}

// WSAGetOverlappedResult calls ws2_32!WSAGetOverlappedResult for the
// overlapped operation identified by socket/overlapped. The transferred byte
// count and result flags are written through transferBytes and flag; bWait
// is forwarded as the API's fWait BOOL. A FALSE result becomes an error.
func WSAGetOverlappedResult(socket syscall.Handle, overlapped *syscall.Overlapped, transferBytes *uint32, bWait bool, flag *uint32) (err error) {
	wait := uintptr(0)
	if bWait {
		wait = 1
	}
	ok, _, lastErr := syscall.SyscallN(
		procWSAGetOverlappedResult.Addr(),
		uintptr(socket),
		uintptr(unsafe.Pointer(overlapped)),
		uintptr(unsafe.Pointer(transferBytes)),
		wait,
		uintptr(unsafe.Pointer(flag)),
	)
	if ok != 0 {
		return nil
	}
	return errnoErr(lastErr)
}

// SetNonBlock puts socket fd into non-blocking mode with the FIONBIO ioctl
// and returns any error reported by WSAIoctl.
//
// No OVERLAPPED structure is passed, so the ioctl is issued synchronously.
// Passing an OVERLAPPED here is both unnecessary (FIONBIO has no pending
// phase) and hazardous: for a socket already associated with an I/O
// completion port an overlapped request may also queue a completion packet,
// and workThread would reinterpret that foreign OVERLAPPED as an *IOData.
func SetNonBlock(fd syscall.Handle) error {
	const fionbio = 0x8004667e // FIONBIO = _IOW('f', 126, u_long)
	enable := uint32(1)
	var returned uint32
	return syscall.WSAIoctl(fd, fionbio, (*byte)(unsafe.Pointer(&enable)), uint32(unsafe.Sizeof(enable)), nil, 0, &returned, nil, 0)
}

// IOData is the per-connection state shared by main and the IOCP workers.
// workThread converts the *syscall.Overlapped returned by
// GetQueuedCompletionStatus straight back into an *IOData, so Overlapped
// MUST remain the first field.
type IOData struct {
	// Overlapped is handed to every overlapped call for this connection;
	// its address doubles as the address of the IOData (see above).
	Overlapped syscall.Overlapped
	// WsaBuf describes the connection's I/O buffer; main sets Len to the
	// full buffer capacity and postRead reuses that value.
	WsaBuf syscall.WSABuf
	// NBytes holds the byte count of the current/last operation: bytes
	// received (AcceptEx result or read completion), then bytes to send.
	NBytes uint32
	// isRead is true while the outstanding operation is a receive.
	isRead bool
	// cliSock is the accepted client socket.
	cliSock syscall.Handle
}

// liveIO keeps every IOData that the kernel may still write to reachable from
// Go. Once an overlapped operation is posted and main moves on, the only
// reference to the IOData is a raw pointer held by the kernel, which the
// garbage collector does not trace; without this registry the IOData (and
// the buffer its WsaBuf points at) could be collected and reused while a
// WSASend/WSARecv is still pending. main adds entries with trackIO before the
// first I/O is posted; closeIO removes them.
var liveIO = struct {
	sync.Mutex
	m map[*IOData]struct{}
}{m: make(map[*IOData]struct{})}

// trackIO registers data as owned by in-flight overlapped I/O.
func trackIO(data *IOData) {
	liveIO.Lock()
	defer liveIO.Unlock()
	liveIO.m[data] = struct{}{}
}

// untrackIO removes data from liveIO; it is a no-op if data is not registered.
func untrackIO(data *IOData) {
	liveIO.Lock()
	defer liveIO.Unlock()
	delete(liveIO.m, data)
}

// closeIO tears down one connection: it closes the WSA event attached to the
// overlapped structure (if any), closes the client socket, and drops the
// liveIO registration so the IOData can be garbage collected. It must only be
// called when no overlapped operation on data is still pending. A nil data is
// ignored, because completion packets posted without an OVERLAPPED (such as
// main's shutdown packets) convert to a nil *IOData.
func closeIO(data *IOData) {
	if data == nil {
		return
	}
	if data.Overlapped.HEvent != syscall.Handle(0) {
		WSACloseEvent(data.Overlapped.HEvent)
		data.Overlapped.HEvent = syscall.Handle(0)
	}
	syscall.Closesocket(data.cliSock)
	untrackIO(data)
}

// main listens on 0.0.0.0:6000, accepts clients one at a time with AcceptEx,
// and hands each accepted connection to runtime.NumCPU() workThread
// goroutines that echo data back through an I/O completion port.
//
// AcceptEx is also asked to receive the first chunk of client data, so a
// connection is only handed off once the client has sent something; that
// first chunk is echoed straight away with postWrite.
func main() {
	listenFd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_STREAM, syscall.IPPROTO_TCP)
	if err != nil {
		fmt.Printf("socket failed: %s\n", err)
		return
	}
	defer func() {
		syscall.Closesocket(listenFd)
		syscall.WSACleanup()
	}()
	v4 := &syscall.SockaddrInet4{
		Port: 6000,
		Addr: [4]byte{}, // 0.0.0.0: all interfaces
	}
	if err = syscall.Bind(listenFd, v4); err != nil {
		fmt.Printf("bind failed: %s\n", err)
		return
	}
	// SOMAXCONN lets Winsock choose a reasonable backlog instead of 0.
	if err = syscall.Listen(listenFd, syscall.SOMAXCONN); err != nil {
		fmt.Printf("listen failed: %s\n", err)
		return
	}

	// A port created from InvalidHandle is not yet tied to any socket.
	hIOCP, err := CreateIoCompletionPort(syscall.InvalidHandle, 0, 0, 0)
	if err != nil {
		fmt.Printf("CreateIoCompletionPort failed: %s\n", err)
		return
	}
	count := runtime.NumCPU()
	for i := 0; i < count; i++ {
		go workThread(hIOCP)
	}

	// On exit, post one packet without an OVERLAPPED per worker to wake it.
	defer func() {
		for i := 0; i < count; i++ {
			PostQueuedCompletionStatus(hIOCP, 0, 0, nil)
		}
	}()

	// AcceptEx requires each address slot to be at least 16 bytes larger
	// than the transport's largest address (sockaddr_in for IPv4). Note this
	// must be the size of the struct itself, not of a pointer to it.
	addrLen := uint32(unsafe.Sizeof(syscall.RawSockaddrInet4{})) + 16

	for {
		acceptFd, er := syscall.Socket(syscall.AF_INET, syscall.SOCK_STREAM, syscall.IPPROTO_TCP)
		if er != nil {
			fmt.Printf("socket failed: %s\n", er)
			break
		}
		b := make([]byte, 1024)
		recvD := uint32(0)
		data := &IOData{
			WsaBuf: syscall.WSABuf{
				Len: uint32(len(b)),
				Buf: &b[0],
			},
			NBytes:  uint32(len(b)),
			isRead:  true,
			cliSock: acceptFd,
		}
		data.Overlapped.HEvent, er = WSACreateEvent()
		if er != nil {
			fmt.Printf("WSACreateEvent failed:%s\n", er)
			closeIO(data)
			break
		}

		// The first len(b)-2*addrLen bytes receive client data; the tail of
		// the buffer holds the local and remote addresses.
		er = syscall.AcceptEx(listenFd, acceptFd, data.WsaBuf.Buf, data.WsaBuf.Len-addrLen*2, addrLen, addrLen, &recvD, &data.Overlapped)
		if er != nil && !errors.Is(er, syscall.ERROR_IO_PENDING) {
			er = os.NewSyscallError("AcceptEx", er)
			fmt.Printf("AcceptEx Error:%s\n", er)
			closeIO(data)
			break
		}

		// listenFd is not associated with hIOCP, so the accept's completion
		// is observed here through the event rather than by a worker.
		if _, er = WSAWaitForMultipleEvents(1, &data.Overlapped.HEvent, true, syscall.INFINITE, false); er != nil {
			fmt.Printf("WSAWaitForMultipleEvents Error:%s\n", er)
			closeIO(data)
			break
		}
		WSAResetEvent(data.Overlapped.HEvent)
		flag := uint32(0)
		// The overlapped operation was started on listenFd, so that is the
		// socket WSAGetOverlappedResult must be given.
		if er = WSAGetOverlappedResult(listenFd, &data.Overlapped, &data.NBytes, true, &flag); er != nil {
			fmt.Printf("WSAGetOverlappedResult Error:%s\n", er)
			closeIO(data)
			break
		}
		if data.NBytes == 0 {
			// The client connected and closed without sending anything.
			closeIO(data)
			continue
		}
		fmt.Printf("client %d connected\n", acceptFd)
		if _, er = CreateIoCompletionPort(acceptFd, hIOCP, 0, 0); er != nil {
			fmt.Printf("CreateIoCompletionPort Error:%s\n", er)
			closeIO(data)
			break
		}
		// From here on only the kernel references data; keep it reachable
		// until closeIO releases it.
		trackIO(data)
		postWrite(data)
	}
}

// postWrite echoes the data.NBytes bytes at the start of data's buffer back
// to the client with an overlapped WSASend. workThread picks up the
// completion and then calls postRead.
//
// Only data.NBytes bytes are sent. data.WsaBuf.Len is the full buffer
// capacity, which postRead relies on, so the send uses its own WSABUF
// describing just the received bytes. Winsock captures the WSABUF array
// before WSASend returns, so a local one is fine; the bytes it points to
// must stay valid until completion.
//
// data must remain reachable from Go until that completion arrives, because
// the kernel only holds a raw pointer to data.Overlapped, which the garbage
// collector does not trace. Printing data (as below) does NOT keep it alive.
//
// It returns nil when the send completed or is pending. On any other error
// the connection is closed with closeIO and the error is returned.
func postWrite(data *IOData) (err error) {
	data.isRead = false
	fmt.Printf("%p cli:%d send %s\n", data, data.cliSock, unsafe.String(data.WsaBuf.Buf, data.NBytes))
	sendBuf := syscall.WSABuf{Len: data.NBytes, Buf: data.WsaBuf.Buf}
	err = syscall.WSASend(data.cliSock, &sendBuf, 1, &data.NBytes, 0, &data.Overlapped, nil)
	if err != nil && !errors.Is(err, syscall.ERROR_IO_PENDING) {
		fmt.Printf("cli:%d send failed: %s\n", data.cliSock, err)
		closeIO(data)
		return err
	}
	return nil
}

// postRead re-arms the connection for reading: it posts an overlapped
// WSARecv into the whole buffer (data.WsaBuf.Len bytes) and marks data as
// having a read outstanding. The returned error is nil or ERROR_IO_PENDING
// when the receive was queued; any other error closes the connection via
// closeIO and is returned.
//
// NOTE(review): flags is a local whose address is passed to WSARecv;
// confirm Winsock never writes lpFlags after returning for a pending read.
func postRead(data *IOData) (err error) {
	var flags uint32
	data.isRead = true
	data.NBytes = data.WsaBuf.Len
	err = syscall.WSARecv(data.cliSock, &data.WsaBuf, 1, &data.NBytes, &flags, &data.Overlapped, nil)
	if err == nil || errors.Is(err, syscall.ERROR_IO_PENDING) {
		return err
	}
	fmt.Printf("cli:%d receive failed: %s\n", data.cliSock, err)
	closeIO(data)
	return err
}

// workThread is the body of one IOCP worker goroutine. It repeatedly dequeues
// completion packets from hIOCP and drives each connection's echo cycle: a
// finished read is answered with postWrite, and a finished write re-arms the
// connection with postRead.
//
// It returns when it dequeues a packet without an OVERLAPPED (the shutdown
// packets main posts) or when GetQueuedCompletionStatus itself fails. A
// failed or zero-byte operation only closes that connection; the worker
// keeps serving the others.
func workThread(hIOCP syscall.Handle) {
	var (
		ioSize uint32
		key    uintptr
	)
	for {
		var pOverlapped *syscall.Overlapped
		err := GetQueuedCompletionStatus(hIOCP, &ioSize, &key, &pOverlapped, syscall.INFINITE)
		if pOverlapped == nil {
			// No I/O packet was dequeued: either the call failed outright or
			// this is one of main's shutdown packets.
			if err != nil {
				fmt.Printf("GetQueuedCompletionStatus failed: %s\n", err)
			}
			return
		}
		// Overlapped is IOData's first field, so the OVERLAPPED pointer the
		// kernel hands back is also the address of the owning IOData.
		ioData := (*IOData)(unsafe.Pointer(pOverlapped))
		if err != nil {
			// The dequeued operation itself failed (e.g. connection reset).
			fmt.Printf("cli:%d io failed: %s\n", ioData.cliSock, err)
			closeIO(ioData)
			continue
		}
		if ioSize == 0 {
			// Nothing transferred; for a read this means the peer closed.
			closeIO(ioData)
			continue
		}
		ioData.NBytes = ioSize
		if ioData.isRead {
			postWrite(ioData)
		} else {
			postRead(ioData)
		}
	}
}

源码只是一个示例,可能有不完善的地方,感兴趣的读者可以自行完善。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值