本文的代码源自《游戏服务端IOCP模型,自己封装的一个类,3行代码搞定服务端》,我改进过了,希望作者不要说我侵权,我声明这段代码是作者的劳动结晶,我只不过是在此基础上进行了些修改和调试。
windows里有如同Linux中的epoll一般强大的套接字管理功能,即socket编程模型。
我们面对服务器端编程时,往往希望一台主机能同时承接成千上万个客户端连接,只要我们的CPU和内存足够处理业务即可。但对于socket,如果使用select管理,在windows里有最多管理64个套接字的上限,毕竟都是依靠轮询来反馈事件的。如果要管理上百个套接字,我们就需要考虑使用IOCP(完成端口)模型了,见《Windows网络编程》5.2.6 完成端口模型一节的内容。
在经历了2天各种百度学习的情况下,我发现网上对于这个完成端口描述大多都是照本宣科,而且逻辑不完整,同样,书中也有不完整的地方,所以我总结此文,并附带可用的代码供大家参考学习,其中如果有不对的地方,望留言指正!
直接上代码,再说明用法,希望理解完成端口逻辑的同学可以看书或百度:
#pragma once
#include <WinSock2.h>
#include <afxmt.h>
#include <afxtempl.h>
#define ULONG_PTR ULONG
#define PULONG_PTR ULONG*
#define BUFFER_SIZE 1024
#define SOCK_TIMEOUT_SECONDS 60
class Iocp;
typedef enum {
OP_READ = 1,
OP_WRITE = 2,
OP_ACCEPT = 3,
OP_CLOSE = 100,
OP_DO_WORK = 101
} SOCKET_STATE;
typedef struct
{
OVERLAPPED oOverlapped;
WSABUF wsBuffer;
CHAR szBuffer[BUFFER_SIZE];
DWORD dSend;
DWORD dRecv;
SOCKET_STATE sState;
} PER_IO_DATA, *LPPER_IO_DATA;
/*传送给处理函数的参数*/
typedef struct
{
SOCKET sSocket; // 客户端socket描述符
int index; // 序号,用于索引
CHAR key[32]; // ip:port
CHAR szClientIP[24]; // 客户端IP字符串
UINT uiClientPort; // 客户端端口
time_t lastReceiveTime; // 最后接收时间
time_t connectedTime; // 创建链接的时间(如果超过这个时间还没有收到有效的ID,那么关闭)
LPPER_IO_DATA lpIOData; // 释放内存用
Iocp *pIocp; // ServerScanThread要用
CMutex *lpMutex;
} IOCPClient, *LPIOCPClient;
typedef struct
{
int index; // 同IOCPClient的index
CMap<CString, LPCTSTR, IOCPClient*, IOCPClient*> sockets;
} STRU_MAP_ClientSockets;
typedef void (*ReadProc)(LPIOCPClient lpData, LPPER_IO_DATA lpPER_IO_DATA);
typedef VOID (*ScanProc)(LPIOCPClient lpClientSocket);
class Iocp
{
public:
Iocp(const CHAR *host, UINT port);
~Iocp(void);
VOID SetThreadNums();
UINT GetThreadNums();
VOID SetPort(UINT port);
UINT GetPort();
BOOL ListenEx(UINT backlog);
VOID Close();
VOID Iocp::CreateScanThreads();
static VOID ServerWorkThread(VOID *_this);
static VOID ServerScanThread(VOID *s);
static VOID FreeClientSocket(Iocp *lpIocp, LPIOCPClient lpClientSocket);
static int Send(SOCKET sockfd, const char *buff, const unsigned int size);
static VOID SetClientSocketCountText(unsigned int count);
static VOID OutPutLog(const char *szFormat, ...);
VOID SetReadFunc(VOID *lprFun);
VOID SetScanFunc(VOID *lprFun);
int m_ThreadNums; // 线程数量,用于将socket分割到多个区域,扫描时每次只扫描一个区域
int m_AcceptClientIndex; // 接受连接的socket的序号,跟m_ThreadNums取余
STRU_MAP_ClientSockets *m_Sockets; // 因为需要根据线程数动态分配内存,所以不能是静态变量
unsigned int m_SocketCount; // 已连接客户端的数量
ReadProc m_ReadFun; // 读数据回调函数
ScanProc m_ScanFun; // 扫描socket回调函数
HANDLE m_cpHandle; // IO完成端口句柄
// 扩展的接受连接,放在线程里了
static VOID AcceptEx(VOID *_this);
// 监听套接字,即服务端套接字
SOCKET m_ListenSocketID;
};
#include "stdafx.h"
#include "Iocp.h"
#include <stdlib.h>
#include <process.h>
#include "resource.h"
#pragma comment(lib, "ws2_32.lib")
extern void DoRxTxWork(LPIOCPClient lpClientSocket);
Iocp::Iocp(const CHAR *host, UINT port):
m_ListenSocketID(INVALID_SOCKET),
m_AcceptClientIndex(0)
{
SetClientSocketCountText((m_SocketCount = 0));
WSADATA wsaData;
DWORD dwRet = WSAStartup( 0x0202, &wsaData );
if (0 != dwRet )
{
WSACleanup();
throw 1;
}
SOCKADDR_IN sockAddr;
memset( &sockAddr, 0, sizeof(SOCKADDR_IN) ) ;
sockAddr.sin_family = AF_INET;
sockAddr.sin_addr.s_addr = inet_addr(host);
sockAddr.sin_port = htons(port);
/*创建监听套接字*/
m_ListenSocketID = WSASocket( AF_INET, SOCK_STREAM, 0, NULL, 0, WSA_FLAG_OVERLAPPED );
if ( m_ListenSocketID == INVALID_SOCKET )
{
throw 1;
}
/*设置套接字选项*/
CHAR opt = 1;
BOOL ret = setsockopt( m_ListenSocketID , SOL_SOCKET , SO_REUSEADDR , (const CHAR * )&opt , sizeof(opt) );
if ( ret != 0 )
{
throw 1 ;
}
/*绑定套接字*/
if (SOCKET_ERROR == bind(m_ListenSocketID, (const struct sockaddr *)&sockAddr, sizeof(struct sockaddr)))
{
throw 1 ;
}
/*创建完成端口*/
m_cpHandle = CreateIoCompletionPort( INVALID_HANDLE_VALUE, NULL, 0, 0 );
if ( m_cpHandle == NULL )
{
throw 1 ;
}
SYSTEM_INFO mySysInfo;
GetSystemInfo( &mySysInfo );
m_ThreadNums = (int)mySysInfo.dwNumberOfProcessors * 2;
//m_ThreadNums = 1;
m_Sockets = new STRU_MAP_ClientSockets[m_ThreadNums];
for ( int i = 0; i < m_ThreadNums; i++ )
{
m_Sockets[i].index = i;
_beginthread(Iocp::ServerWorkThread, 0, (VOID *)this);
}
TRACE("工作线程准备完成(%d个)\n", m_ThreadNums);
OutPutLog("工作线程准备完成(%d个)\n", m_ThreadNums);
}
Iocp::~Iocp(void)
{
WSACleanup();
}
VOID Iocp::AcceptEx(VOID *_this)
{
SOCKET acSocket;
DWORD dwRecvBytes;
Iocp * pIocp = (Iocp *)_this;
SOCKADDR_IN sAddr;
INT uiClientSize = sizeof(sAddr);
TRACE("服务器已就绪, 套接字=%u ...\n", pIocp->m_ListenSocketID);
OutPutLog("服务器已就绪, 套接字=%u ...\n", pIocp->m_ListenSocketID);
while (TRUE)
{
acSocket = WSAAccept( pIocp->m_ListenSocketID, (SOCKADDR *)&sAddr, &uiClientSize, NULL, 0 );
if ( acSocket == SOCKET_ERROR )
{
TRACE("接受连接发生错误: %d\n", WSAGetLastError());
return;
}
LPIOCPClient lpClientSocket = (LPIOCPClient)malloc(sizeof(IOCPClient));
if ( NULL == lpClientSocket )
{
TRACE("Error while malloc lpClientSocket\n");
return;
}
memset(lpClientSocket, 0, sizeof(IOCPClient));
/*这里停止监听会有问题*/
LPPER_IO_DATA lpIOData = (LPPER_IO_DATA )malloc(sizeof(PER_IO_DATA));
if ( lpIOData == NULL )
{
TRACE("Error while malloc lpIOData\n");
return;
}
memset(lpIOData, 0, sizeof(PER_IO_DATA));
lpClientSocket->connectedTime = lpClientSocket->lastReceiveTime = time(NULL);
lpClientSocket->lpIOData = lpIOData; // 释放内存用
lpClientSocket->sSocket = acSocket;
lpClientSocket->pIocp = pIocp;
strcpy(lpClientSocket->szClientIP, inet_ntoa(sAddr.sin_addr));
lpClientSocket->uiClientPort = sAddr.sin_port;
_snprintf(lpClientSocket->key, sizeof lpClientSocket->key, "%s:%d", lpClientSocket->szClientIP, lpClientSocket->uiClientPort);
lpClientSocket->lpMutex = new CMutex(FALSE, lpClientSocket->key);
if (CreateIoCompletionPort( (HANDLE)acSocket, pIocp->m_cpHandle, (ULONG_PTR)lpClientSocket, 0 ) == NULL)
{
TRACE("Error while CreateIoCompletionPort\n");
return;
}
TRACE("客户端已连接:%s:%u\n", lpClientSocket->szClientIP, lpClientSocket->uiClientPort);
OutPutLog("客户端已连接:%s:%u\n", lpClientSocket->szClientIP, lpClientSocket->uiClientPort);
// 投递线程事件
lpIOData->dSend = 0;
lpIOData->dRecv = 0;
lpIOData->wsBuffer.len = BUFFER_SIZE - 1;
lpIOData->wsBuffer.buf = lpIOData->szBuffer;
lpIOData->sState = OP_READ;
DWORD flags = 0;
if (WSARecv(acSocket, &(lpIOData->wsBuffer), 1, &dwRecvBytes, &flags, &(lpIOData->oOverlapped), NULL) == SOCKET_ERROR)
{
if (WSAGetLastError() != ERROR_IO_PENDING )
{
TRACE("Error ERROR_IO_PENDING\n");
return;
}
else
{
// 客户端按接受连接的顺序依次放入4个线程进行扫描处理
pIocp->m_AcceptClientIndex = (pIocp->m_AcceptClientIndex + 1) % pIocp->m_ThreadNums;
lpClientSocket->index = pIocp->m_AcceptClientIndex;
pIocp->m_Sockets[lpClientSocket->index].sockets[lpClientSocket->key] = lpClientSocket;
SetClientSocketCountText(++pIocp->m_SocketCount);
TRACE("客户端异步读取已完成,等待读取数据...\n");
OutPutLog("客户端异步读取已完成,等待读取数据...\n");
}
}
}
}
BOOL Iocp::ListenEx(UINT backlog)
{
if (SOCKET_ERROR == listen(m_ListenSocketID, backlog))
{
return FALSE;
}
/*创建监听线程*/
if (-1 == _beginthread(Iocp::AcceptEx, 0, (VOID *)this))
{
return FALSE;
}
return TRUE;
}
VOID Iocp:: ServerWorkThread( VOID * _this )
{
Iocp * lpIocp = (Iocp *)_this;
HANDLE hPlePort = (HANDLE)lpIocp->m_cpHandle;
DWORD dwBytes;
LPIOCPClient lpClientSocket = NULL;
LPPER_IO_DATA lpIOData = NULL;
LPOVERLAPPED lpOverlapped = NULL;
DWORD sendBytes = 0;
DWORD recvBytes = 0;
DWORD dwFlag = 0;
while (TRUE)
{
if (0 == GetQueuedCompletionStatus( hPlePort, &dwBytes, (PULONG_PTR)&lpClientSocket, &lpOverlapped, INFINITE ))
{
FreeClientSocket(lpIocp, lpClientSocket);
continue ;
}
lpIOData = (LPPER_IO_DATA)CONTAINING_RECORD(lpOverlapped, PER_IO_DATA, oOverlapped);
if (0 == dwBytes && (lpIOData->sState == OP_READ || lpIOData->sState == OP_WRITE))
{
TRACE("客户端断开了连接:%s\n", lpClientSocket->key);
OutPutLog("客户端断开了连接:%s\n", lpClientSocket->key);
closesocket(lpClientSocket->sSocket);
FreeClientSocket(lpIocp, lpClientSocket);
continue;
}
switch (lpIOData->sState) {
case OP_READ:
lpIOData->dRecv = dwBytes;
lpClientSocket->lastReceiveTime = time(NULL);
lpIocp->m_ReadFun(lpClientSocket, lpIOData);
lpIOData->dRecv = 0;
ZeroMemory( &(lpIOData->oOverlapped), sizeof( OVERLAPPED ) );
lpIOData->wsBuffer.len = BUFFER_SIZE - 1;
lpIOData->wsBuffer.buf = lpIOData->szBuffer;
lpIOData->sState = OP_READ;
if ( WSARecv( lpClientSocket->sSocket, &(lpIOData->wsBuffer), 1, &recvBytes, &dwFlag, &(lpIOData->oOverlapped), NULL ) == SOCKET_ERROR )
{
if ( WSAGetLastError() != ERROR_IO_PENDING )
{
return;
}
}
break;
case OP_WRITE:
// 什么也不用做
break;
case OP_DO_WORK:
lpIocp->m_ScanFun(lpClientSocket);
break;
case OP_CLOSE:
TRACE("主动断开长期无响应的客户端:%s\n", lpClientSocket->key);
OutPutLog("主动断开长期无响应的客户端:%s\n", lpClientSocket->key);
// 这里不能直接释放内存,因为还会触发一次GetQueuedCompletionStatus返回0,在返回0时释放内存
closesocket(lpClientSocket->sSocket);
break;
default:
break;
}
}
}
VOID Iocp::FreeClientSocket(Iocp *lpIocp, LPIOCPClient lpClientSocket)
{
if (NULL == lpIocp || NULL == lpClientSocket) {
return;
}
lpIocp->m_Sockets[lpClientSocket->index].sockets.RemoveKey(lpClientSocket->key);
SetClientSocketCountText(--lpIocp->m_SocketCount);
free(lpClientSocket->lpIOData);
free(lpClientSocket);
TRACE("内存已经释放!\n");
}
VOID Iocp::SetReadFunc(VOID *lprFun)
{
m_ReadFun = (ReadProc)lprFun;
}
VOID Iocp::SetScanFunc(VOID *lprFun)
{
m_ScanFun = (ScanProc)lprFun;
CreateScanThreads();
}
VOID Iocp::CreateScanThreads()
{
STRU_MAP_ClientSockets *sock;
for (int i = 0; i < m_ThreadNums; i++) {
sock = &m_Sockets[i];
_beginthread(Iocp::ServerScanThread, 0, (VOID *)sock);
}
}
VOID Iocp::ServerScanThread(VOID *s)
{
static PER_IO_DATA IOData;
POSITION pos;
CString key;
IOCPClient *lpClientSocket;
STRU_MAP_ClientSockets *mapSock = (STRU_MAP_ClientSockets*)s;
int index = mapSock->index;
int doCount = 0;
CMap<CString, LPCTSTR, IOCPClient*, IOCPClient*> *serverSockets = &mapSock->sockets;
while (1) {
Sleep(5000);
//OutPutLog("序号[%d]定时器开始处理...", index);
doCount = 0;
pos = serverSockets->GetStartPosition();
while (pos) {
doCount++;
serverSockets->GetNextAssoc(pos, key, lpClientSocket);
memset(&IOData, 0, sizeof(PER_IO_DATA));
IOData.sState = OP_DO_WORK;
PostQueuedCompletionStatus(lpClientSocket->pIocp->m_cpHandle, 0, (ULONG_PTR)lpClientSocket, &IOData.oOverlapped);
}
//OutPutLog("序号[%d]定时器处理了%d个客户端", index, doCount);
}
}
void Iocp::SetClientSocketCountText(unsigned int count)
{
CString countStr;
countStr.Format("客户端数量: %u", count);
CWnd *pWnd = AfxGetMainWnd();
HWND hHwnd = pWnd->m_hWnd;
::SetDlgItemText(hHwnd, IDC_CLIENT_COUNT, countStr);
}
void Iocp::OutPutLog(const char *szFormat, ...)
{
static char szLogBuffer[1024];
SYSTEMTIME curTime;
GetLocalTime(&curTime);
CString strTime;
strTime.Format(_T("[%04d-%02d-%02d %02d:%02d:%02d] "),
curTime.wYear,curTime.wMonth,curTime.wDay,
curTime.wHour,curTime.wMinute,curTime.wSecond);
strTime += szFormat;
va_list pArgList;
va_start(pArgList, szFormat);
int len = _vsntprintf(szLogBuffer, sizeof szLogBuffer-2, strTime, pArgList);
va_end(pArgList);
if (szLogBuffer[len-1] == '\n') {
if (szLogBuffer[len-2] != '\r') {
szLogBuffer[len-1] = '\r';
szLogBuffer[len] = '\n';
szLogBuffer[len+1] = '\0';
}
} else {
szLogBuffer[len] = '\r';
szLogBuffer[len+1] = '\n';
szLogBuffer[len+2] = '\0';
}
CWnd *pWnd = AfxGetMainWnd();
CEdit *pEdit = (CEdit*)pWnd->GetDlgItem(IDC_OUTLOG_EDIT);
if (NULL == pEdit) return;
int iTextLen = pEdit->GetWindowTextLength();
pEdit->SetRedraw(FALSE);
pEdit->SetReadOnly(FALSE);
pEdit->SetSel(iTextLen, iTextLen, TRUE);
pEdit->ReplaceSel(szLogBuffer); // 这个函数还是在光标的位置书写
int lineCount = pEdit->GetLineCount(); // m_prlog是绑定CEDIT控件的对象
if(lineCount > 100) // 如果输出日志行太多,则删第一行
{
pEdit->GetWindowText(szLogBuffer,1024 - 1);//只取前100个字符
CString tmp(szLogBuffer);
int it1 = tmp.Find("\r\n") + 2; // 查找第一行的回车换行位置
pEdit->SetSel(0, it1); // 选择要删除的首行
pEdit->ReplaceSel(""); // 用空串替换掉首行
}
pEdit->LineScroll(lineCount); //可用于水平滚动所有行最后一个字符,这只是设置edit进行滚动
pEdit->SetReadOnly(TRUE);
pEdit->SetRedraw(TRUE);
}
int Iocp::Send(SOCKET sockfd, const char *buff, const unsigned int size)
{
static PER_IO_DATA PerIOData;
memset(&PerIOData, 0, sizeof(PER_IO_DATA));
PerIOData.sState = OP_WRITE;
PerIOData.wsBuffer.len = size;
PerIOData.wsBuffer.buf = (char *)buff;
DWORD byteSend = 0;
int ErrorCode;
int result = WSASend(sockfd, &PerIOData.wsBuffer, 1, &byteSend, 0, &PerIOData.oOverlapped, NULL);
if (SOCKET_ERROR == result && ERROR_IO_PENDING != (ErrorCode = WSAGetLastError())) {
TRACE("发送数据出错,错误码: %d\n", ErrorCode);
} else {
TRACE("成功发送数据: %d字节,返回值:%d\n", byteSend, result);
}
return result;
}
// 回调1:客户端的发送的数据会在这个函数通知
void OnRead(LPIOCPClient lpClientSocket, LPPER_IO_DATA lpIOData)
{
if (NULL == lpClientSocket || NULL == lpIOData) {
return;
}
int RxCount = (int) lpIOData->dRecv;
char *RxBuff = lpIOData->szBuffer;
RxBuff[RxCount] = '\0'; // 务必保证接收时留1个字节余量给这个结尾的0
Iocp::OutPutLog("%s:%d: %s\n", lpClientSocket->szClientIP, lpClientSocket->uiClientPort, RxBuff);
Iocp::Send(lpClientSocket->sSocket, RxBuff, RxCount);
}
// 回调2:扫描套接字,目的是关闭闲置套接字,或定时发送心跳包(业务逻辑上要求对方回答)
// 其实这个函数可以直接关闭套接字,只不过通过单独的CLOSE通知会对业务处理更灵活和方便
// 如果你在业务中体会不到,可以直接调用closesocket即可。
VOID OnScan(LPIOCPClient lpClientSocket)
{
static PER_IO_DATA IOData;
if (NULL == lpClientSocket) {
return;
}
if (time(NULL) - lpClientSocket->lastReceiveTime > SOCK_TIMEOUT_SECONDS) {
memset(&IOData, 0, sizeof(PER_IO_DATA));
IOData.sState = OP_CLOSE;
PostQueuedCompletionStatus(lpClientSocket->pIocp->m_cpHandle, 0, (ULONG_PTR)lpClientSocket, &IOData.oOverlapped);
}
}
// 调用的位置,在MFC项目里
void CIOCPSocketDlg::OnRunServer()
{
static Iocp *g_IocpServer = NULL;
if (NULL == g_IocpServer) {
g_IocpServer = new Iocp("0.0.0.0", 8888);
g_IocpServer->SetReadFunc(OnRead); // 回调1,读取套接字发来的内容
g_IocpServer->SetScanFunc(OnScan); // 回调2,定期扫描套接字,可能是业务逻辑要求发心跳包,这个步骤可以免去
g_IocpServer->ListenEx(10);
}
}
代码在MFC项目里,所以还有设置窗体内容的逻辑,大家修改成自己的即可。
补充代码逻辑(书上没讲到的):如果要主动关闭套接字,直接调用closesocket函数即可,因为调用此函数会导致GetQueuedCompletionStatus函数返回0,在返回0的逻辑里释放两个malloc的变量即可。而如果是客户端断开了,GetQueuedCompletionStatus返回的不是0,但满足0 == dwBytes并且是读状态,在这里则除了调用closesocket以外,还要释放malloc的变量(书上讲到了)。
另外:对于同一个套接字,应该不会同时在多个线程里出发读取完成操作,但是很可能在多个线程里出发读取和扫描通知(OP_DO_WORK),所以在业务中,如有必要,要考虑给每一个客户端一个mutex,加锁处理。
使用时,只要实现read,scan两个方法,如果确定不要scan(无法处理从网络中消失的客户端,比如客户端突然死机或断网,或服务器断网一段时间,在此期间客户端主动断开了),那么就不挂载scan函数即可,看SetScanFunc的实现里有新建线程的操作哦:
VOID Iocp::SetScanFunc(VOID *lprFun)
{
m_ScanFun = (ScanProc)lprFun;
CreateScanThreads();
}