记录一下学习到的东西
一、客户端代码
(一)、客户端利用libpcap抓包,并发送至服务器端,同时开启心跳线程向服务器端定时发送心跳包
1、在入口函数需要有ip和端口以及网卡
2、发送前先以私有协议封装数据包,私有协议中包括了类型、数据包的长度以及数据包
3、心跳包也是利用了私有协议的结构体
4、同时记录了抓包的数量
#include "my_header.h"
int g_client_socket = 0;
pcap_t *g_p_handle = NULL; // 句柄 用于引用已经打开的网络接口
time_t g_current_time_catch, g_current_time_send, g_start_time_catch, g_start_time_send;
int g_count_catch = 0;
int g_flag = 0; // 关闭线程标识
pthread_t g_heartbeat_tid; // 线程的标识符
pthread_mutex_t g_mutex;
// 定义信号处理函数
void sigint_handler(int signum)
{
// 取消线程执行的请求
pthread_cancel(g_heartbeat_tid);
g_flag = 1;
// 等待线程退出
pthread_join(g_heartbeat_tid, NULL);
pcap_close(g_p_handle);
close(g_client_socket);
exit(signum);
}
// 线程函数:心跳包机制
void *heartbeat_thread(void *arg)
{
pri_protocol_t heart_beat = {0};
heart_beat.type = 0x1202abcd;
while (1)
{
if (g_flag == 1)
{
break;
}
// 发送心跳消息
pthread_mutex_lock(&g_mutex);
if (send(g_client_socket, &heart_beat, sizeof(pri_protocol_t), 0) < 0)
{
perror("Failed to send heart\n");
exit(1);
}
pthread_mutex_unlock(&g_mutex);
printf("send heart success\n");
sleep(5);
}
return NULL;
}
// 发送数据包到服务器
void process_packet(u_char *user_data, const struct pcap_pkthdr *packet_heaher, const unsigned char *packet_content)
{
g_count_catch += packet_heaher->caplen; // 记录抓包流量
g_current_time_catch = time(NULL); // 获取当前时间
if (g_current_time_catch - g_start_time_catch >= INTERVAL)
{
printf("Packet capture traffic(/s): %d\n", g_count_catch / INTERVAL);
g_count_catch = 0;
g_start_time_catch = g_current_time_catch; // 更新开始时间
}
private_protocol_t new_packet = {0};
// 填充数据类型、数据长度和数据内容
new_packet.type = 0x12abcdef;
new_packet.length = packet_heaher->caplen;
memcpy(new_packet.data_buffer, packet_content, packet_heaher->caplen);
pthread_mutex_lock(&g_mutex);
if (send(g_client_socket, &new_packet, packet_heaher->caplen + sizeof(pri_protocol_t), 0) < 0)
{
perror("Failed to send data\n");
}
pthread_mutex_unlock(&g_mutex);
}
int main(int argc, char *argv[])
{
char error_buffer[SIZE] = {0}; // 错误信息
struct bpf_program fp;
struct sockaddr_in server_address;
signal(SIGINT, sigint_handler);
pthread_mutex_init(&g_mutex, NULL);
// 输入标准
if (argc < 4)
{
printf("Usage: %s <ip> <port> <device> \n", argv[0]);
}
// 打开网络接口
g_p_handle = pcap_open_live(argv[3], SIZE, 1, 1000, error_buffer); // 开启混杂模式
if (g_p_handle == NULL)
{
fprintf(stderr, "The network interface cannot be opened: %s\n", error_buffer); // to do stderr是标准错误!
exit(EXIT_FAILURE);
}
g_client_socket = socket(AF_INET, SOCK_STREAM, 0);
if (g_client_socket < 0)
{
perror("Failed to create a socket\n"); // 打印错误输出一个描述当前errno代表的错误的字符串
exit(EXIT_FAILURE);
}
// 设置服务器地址参数
server_address.sin_family = AF_INET;
server_address.sin_addr.s_addr = inet_addr(argv[1]);
server_address.sin_port = htons(atoi(argv[2]));
// 建立连接
if (connect(g_client_socket, (struct sockaddr *)&server_address, sizeof(server_address)) < 0)
{
perror("Failed to connect to the server\n");
exit(EXIT_FAILURE);
}
else
{
printf("success connection\n");
if (pthread_create(&g_heartbeat_tid, NULL, heartbeat_thread, NULL) != 0) // 启动线程
{
perror("pthread_create error\n");
exit(EXIT_FAILURE);
}
else
{
printf("success create\n");
}
}
char filter_expression[] = "host 192.168.1.24 and port 1523 and tcp"; // 过滤规则表达式
// 编译过滤规则
if (pcap_compile(g_p_handle, &fp, filter_expression, 0, PCAP_NETMASK_UNKNOWN) == -1)
{
fprintf(stderr, "Failed to compile filter rules: %s\n", pcap_geterr(g_p_handle));
exit(EXIT_FAILURE);
}
// 设置过滤规则
if (pcap_setfilter(g_p_handle, &fp) == -1)
{
fprintf(stderr, "Failed to set filter rules: %s\n", pcap_geterr(g_p_handle));
exit(EXIT_FAILURE);
}
pcap_freecode(&fp);
g_start_time_catch = time(NULL); // 记录程序开始时间
pcap_loop(g_p_handle, -1, process_packet, NULL);
// 关闭网络接口
pthread_cancel(g_heartbeat_tid);
pthread_join(g_heartbeat_tid, NULL);
pcap_close(g_p_handle);
close(g_client_socket);
return 0;
}
二、服务器端代码
(一)、服务器端收客户端发来的数据包以及心跳包,同时需要去处理心跳超时,解析四元组以及源和目标MAC地址。
1、利用epoll处理并发
2、收到包时先取出私有协议,判断类型是心跳还是数据,以及取出数据包的长度方便后续接收数据包部分。如果是心跳包要去重置该客户端的时间戳
3、收到数据包后,存放至环形缓冲区(目前对于多读多写会有问题),读缓冲区的线程负责,从缓冲区中读取出来数据进行解析,并写入文件。
#include "my_header.h"
unsigned char *g_p_recv_buffer = NULL;
circular_buffer_t *g_p_circular_buffer = NULL;
pthread_t g_read_tid; // 线程的标识符
int g_server_fd = 0;
int g_epoll_fd = 0;
int g_flag = 0;
Node_t *g_p_head = NULL;
// 定义信号处理函数
void sigint_handler(int signum)
{
g_flag = 1;
// 发送取消信号给线程
pthread_cancel(g_read_tid);
// 等待线程退出
pthread_join(g_read_tid, NULL);
if (g_p_recv_buffer != NULL)
{
free(g_p_recv_buffer);
g_p_recv_buffer = NULL;
}
if (g_p_circular_buffer != NULL)
{
free(g_p_circular_buffer);
g_p_circular_buffer = NULL;
}
while (g_p_head != NULL)
{
Node_t *p_del = g_p_head;
g_p_head = g_p_head->p_next;
close(g_p_head->args_heart.fd);
free(p_del);
p_del = NULL;
}
close(g_server_fd);
close(g_epoll_fd);
exit(signum);
}
// 创建新节点
Node_t *create_node(heart_beat_t args)
{
Node_t *p_new_node = (Node_t *)malloc(sizeof(Node_t));
if (p_new_node == NULL)
{
printf("memory allocation failed\n");
return NULL; // to do
}
p_new_node->args_heart = args;
p_new_node->p_next = NULL;
return p_new_node;
}
// 在链表末尾插入节点
void insert_node(Node_t **p_current, heart_beat_t args)
{
Node_t *p_new_node = create_node(args);
if (p_new_node == NULL)
{
return;
}
else
{
if (*p_current == NULL)
{
*p_current = p_new_node;
}
else
{
Node_t *p_node = *p_current;
while (p_node->p_next != NULL)
{
p_node = p_node->p_next;
}
p_node->p_next = p_new_node;
}
}
}
// 删除节点
void delete_node(Node_t **p_del_node)
{
Node_t *p_head_left = *p_del_node; // 指向左节点的指针
Node_t *p_current = NULL; // 指向当前节点的指针
p_current = p_head_left->p_next;
p_head_left->p_next = p_current->p_next;
free(p_current);
p_current = NULL;
}
// 信号处理函数,用于定时检测心跳超时
int handle_alarm(time_t heartbeat_time)
{
time_t current_time = time(NULL);
if ((current_time - heartbeat_time) >= (HEARTBEAT_TIMEOUT))
{
printf("Heartbeat timeout. Disconnecting client.\n");
return 1; // 关闭客户端连接
}
return 0;
}
// 初始化缓冲区
void init_circular_buffer()
{
memset(g_p_circular_buffer->buffer, 0, BUFFER_SIZE);
g_p_circular_buffer->read_index = 0;
g_p_circular_buffer->write_index = 0;
g_p_circular_buffer->count = 0;
pthread_mutex_init(&g_p_circular_buffer->mutex, NULL);
}
int close_client(int g_epoll_fd, int client_fd)
{
// 从epoll监听列表中删除套接字
epoll_ctl(g_epoll_fd, EPOLL_CTL_DEL, client_fd, NULL);
close(client_fd);
printf("Client connection closed\n");
continue;
}
// 解析包
void parse_package(const unsigned char *package, int len, char *p_file_name)
{
FILE *p_file = NULL;
p_file = fopen(p_file_name, "a");
if (p_file != NULL)
{
ethhdr_t *ethh = (ethhdr_t *)package; // 以太网头结构体指针
iphdr_t *p_ip_header = (iphdr_t *)(package + 14); // ip头结构体指针
tcp_header_t *p_tcp_header = (tcp_header_t *)(package + 14 + p_ip_header->ihl * 4); // tcp头结构体指针
unsigned char *data = (unsigned char *)(package + 14 + p_ip_header->ihl * 4 + p_tcp_header->doff * 4); // 指向数据的指针
int data_length = len - (14 + p_ip_header->ihl * 4 + p_tcp_header->doff * 4); // 数据的长度
// 解析以太网帧和 IP 数据报中的 MAC 地址和 IP 地址
uint32_t source_ip = ntohl(p_ip_header->saddr);
uint32_t dest_ip = ntohl(p_ip_header->daddr);
// 提取源端口和目标端口
unsigned short source_port = ntohs(p_tcp_header->source_port);
unsigned short dest_port = ntohs(p_tcp_header->dest_port);
fprintf(p_file, "\n");
fprintf(p_file, "源MAC地址:%02X:%02X:%02X:%02X:%02X:%02X\n", ethh->source_mac[0], ethh->source_mac[1], ethh->source_mac[2], ethh->source_mac[3], ethh->source_mac[4], ethh->source_mac[5]);
fprintf(p_file, "目标MAC地址:%02X:%02X:%02X:%02X:%02X:%02X\n", ethh->dest_mac[0], ethh->dest_mac[1], ethh->dest_mac[2], ethh->dest_mac[3], ethh->dest_mac[4], ethh->dest_mac[5]);
fprintf(p_file, "源IP地址: %u.%u.%u.%u\n", (source_ip >> 24) & 0xFF, (source_ip >> 16) & 0xFF, (source_ip >> 8) & 0xFF, source_ip & 0xFF);
fprintf(p_file, "目标IP地址: %u.%u.%u.%u\n", (dest_ip >> 24) & 0xFF, (dest_ip >> 16) & 0xFF, (dest_ip >> 8) & 0xFF, dest_ip & 0xFF);
fprintf(p_file, "源端口:%hu \n", source_port);
fprintf(p_file, "目标端口:%hu \n", dest_port);
fprintf(p_file, "数据:\n");
int i;
for (i = 0; i < data_length; i++)
{
fprintf(p_file, "%02x ", data[i]); // 打印十六进制格式数据
}
fprintf(p_file, "\n");
printf("\n------------------\n");
printf("源MAC地址:%02X:%02X:%02X:%02X:%02X:%02X\n", ethh->source_mac[0], ethh->source_mac[1], ethh->source_mac[2], ethh->source_mac[3], ethh->source_mac[4], ethh->source_mac[5]);
printf("目标MAC地址:%02X:%02X:%02X:%02X:%02X:%02X\n", ethh->dest_mac[0], ethh->dest_mac[1], ethh->dest_mac[2], ethh->dest_mac[3], ethh->dest_mac[4], ethh->dest_mac[5]);
printf("源IP地址: %u.%u.%u.%u\n", (source_ip >> 24) & 0xFF, (source_ip >> 16) & 0xFF, (source_ip >> 8) & 0xFF, source_ip & 0xFF);
printf("目标IP地址: %u.%u.%u.%u\n", (dest_ip >> 24) & 0xFF, (dest_ip >> 16) & 0xFF, (dest_ip >> 8) & 0xFF, dest_ip & 0xFF);
printf("源端口:%hu \n", source_port);
printf("目标端口:%hu \n", dest_port);
printf("数据:\n");
int j;
for (j = 0; j < data_length; j++)
{
printf("%02x ", data[j]); // 打印十六进制格式数据
}
printf("\n------------------\n");
fclose(p_file);
}
else
{
perror("Failed to open file");
}
}
// 判断读指针是否需要回头
void read_turn_head(char *turn_buffer, int data_len)
{
int remaining_data_size = BUFFER_SIZE - g_p_circular_buffer->read_index;
if (remaining_data_size >= data_len)
{
memcpy(turn_buffer, g_p_circular_buffer->buffer + g_p_circular_buffer->read_index, data_len); // 将类私有协议拷贝给read_buffer
}
else
{
// 先读一部分
memcpy(turn_buffer, g_p_circular_buffer->buffer + g_p_circular_buffer->read_index, remaining_data_size);
// 再读开头一部分
memcpy(turn_buffer + remaining_data_size, g_p_circular_buffer->buffer, data_len - remaining_data_size);
}
g_p_circular_buffer->read_index = (g_p_circular_buffer->read_index + data_len) % (BUFFER_SIZE); // 偏移写指针
g_p_circular_buffer->count -= data_len;
}
// 将缓冲区的数据写入文件
void *read_buffer_to_file(void *arg)
{
while (1)
{
pri_protocol_t protocol_buffer = {0}; // 私有协议结构体
char file_name[30] = {0};
unsigned char data_buffer[SIZE]; // 存数据的
if (g_flag == 1)
{
break;
}
if (g_p_circular_buffer->write_index == g_p_circular_buffer->read_index)
{
continue;
}
else
{
pthread_mutex_lock(&g_p_circular_buffer->mutex);
if (g_p_circular_buffer->count > 0)
{
// 判断协议头回头
read_turn_head((char *)&protocol_buffer, sizeof(pri_protocol_t));
// 判断文件名回头
read_turn_head(file_name, sizeof(file_name));
// 判断数据回头
read_turn_head((char *)data_buffer, protocol_buffer.length);
}
pthread_mutex_unlock(&g_p_circular_buffer->mutex);
parse_package(data_buffer, protocol_buffer.length, file_name); // 解析数据包并写入文件
}
}
return NULL;
}
// 处理客户端消息
void write_packet_to_buffer(char *packet, int packet_size)
{
// 缓冲区已经满啦
if (g_p_circular_buffer->count == BUFFER_SIZE) // to do 考虑多读多写情况
{
return;
}
else if (BUFFER_SIZE - g_p_circular_buffer->count >= packet_size)
{
int remaining_size = BUFFER_SIZE - g_p_circular_buffer->write_index;
pthread_mutex_lock(&g_p_circular_buffer->mutex);
// 如果数据包大小小于等于剩余的缓冲区大小,则可以直接复制数据包到缓冲区
if (packet_size <= remaining_size)
{
memcpy(g_p_circular_buffer->buffer + g_p_circular_buffer->write_index, packet, packet_size);
}
// 如果数据包大小大于剩余的缓冲区大小,则需要分两次复制数据包 判断回头
else
{
// 先复制部分数据到末尾的剩余缓冲区
memcpy(g_p_circular_buffer->buffer + g_p_circular_buffer->write_index, packet, remaining_size);
// 再复制剩余的数据到缓冲区的开头
memcpy(g_p_circular_buffer->buffer, packet + remaining_size, packet_size - remaining_size);
}
// 更新写指针的位置
g_p_circular_buffer->write_index = (g_p_circular_buffer->write_index + packet_size) % (BUFFER_SIZE);
g_p_circular_buffer->count += packet_size;
pthread_mutex_unlock(&g_p_circular_buffer->mutex);
}
else
{
return;
}
}
void handle_client(pri_protocol_t protocol_buffer, args_t *args)
{
char filename[30] = {0};
sprintf(filename, "client_%s_%d.txt", args->ip, args->port);
char tmp[sizeof(pri_protocol_t) + sizeof(filename) + protocol_buffer.length]; // 临时存私有协议和文件名和数据的buffer
memcpy(tmp, &protocol_buffer, sizeof(pri_protocol_t));
memcpy(tmp + sizeof(pri_protocol_t), filename, sizeof(filename));
memcpy(tmp + sizeof(pri_protocol_t) + sizeof(filename), g_p_recv_buffer, protocol_buffer.length);
write_packet_to_buffer(tmp, sizeof(pri_protocol_t) + sizeof(filename) + protocol_buffer.length);
}
int main(int argc, char *argv[])
{
int client_fd = 0;
int num_events = 0; // 已经发生的事件数量
int sum = 0; // 计总流量 用于测性能
int recv_count = 0; // 计10s内的包流量
struct sockaddr_in server_address; // 用于表示IPv4的服务器端的socket地址
g_p_recv_buffer = (unsigned char *)malloc(SIZE * sizeof(unsigned char));
memset(g_p_recv_buffer, 0, SIZE * sizeof(unsigned char));
time_t start_time, current_time;
int lost_packet = 0; // 丢包数量
struct epoll_event event = {0}; // 用于注册事件
struct epoll_event events[MAX_EVENTS] = {{0}}; // 存放返回就绪事件
g_p_head = (Node_t *)malloc(sizeof(Node_t));
g_p_circular_buffer = (circular_buffer_t *)malloc(sizeof(circular_buffer_t));
// 初始化
init_circular_buffer();
// 创建写文件线程
pthread_create(&g_read_tid, NULL, read_buffer_to_file, NULL); // 默认属性
// Ctrl C 信号处理
signal(SIGINT, sigint_handler);
// 输入标准
if (argc < 3)
{
printf("Usage: %s <ip> <port>\n", argv[0]);
}
// 创建服务器套接字
g_server_fd = socket(AF_INET, SOCK_STREAM, 0);
if (g_server_fd < 0)
{
perror("Socket creation failed\n");
exit(EXIT_FAILURE);
}
// 设置地址重用选项
int reuse = 1;
if (setsockopt(g_server_fd, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)) == -1)
{
perror("Failed to set socket options\n");
close(g_server_fd);
return 1;
}
// 设置服务器地址参数
server_address.sin_family = AF_INET;
server_address.sin_addr.s_addr = inet_addr(argv[1]); // inet_addr函数将ipv4的字符串转换为网络字节序 //to do
server_address.sin_port = htons(atoi(argv[2])); // htons函数将16位无符号整型 主机字节序转为网络字节序
// 绑定服务器地址
if (bind(g_server_fd, (struct sockaddr *)&server_address, sizeof(server_address)) < 0)
{
perror("Binding failed\n");
exit(EXIT_FAILURE);
}
// 监听连接
if (listen(g_server_fd, CONCURRENT_MAX) < 0)
{
perror("Listening failed\n");
exit(EXIT_FAILURE);
}
// 创建 epoll 实例
g_epoll_fd = epoll_create1(0);
if (g_epoll_fd < 0)
{
perror("Failed to create epoll instance\n");
close(g_server_fd);
return 1;
}
// 添加监听套接字到 epoll 实例
event.events = EPOLLIN; // 设置事件类型为读事件 接收数据
event.data.fd = g_server_fd;
// 事件注册
if (epoll_ctl(g_epoll_fd, EPOLL_CTL_ADD, g_server_fd, &event) == -1)
{
perror("Failed to add listen socket to epoll\n");
close(g_server_fd);
close(g_epoll_fd);
return 1;
}
printf("Waiting for client connection...\n");
start_time = time(NULL); // 记录程序开始时间
// 等待事件发生
while (1)
{
// num_events是已经发生的事件数
num_events = epoll_wait(g_epoll_fd, events, MAX_EVENTS, -1); // 多进程中,一个进程中创建的文件描述符默认情况下是继承到子进程的。如果你不希望这种继承发生,可以使用 EPOLL_CLOEXEC 标志
if (num_events == -1)
{
perror("Failed to wait\n");
close(g_server_fd);
close(g_epoll_fd);
return 1;
}
// 处理事件
int i;
for (i = 0; i < num_events; i++)
{
// 监听服务器套接字有可读事件
if (events[i].data.fd == g_server_fd)
{
heart_beat_t heart_beat = {0};
args_t *p_new_node = (args_t *)malloc(sizeof(args_t));
// 接受新的连接
struct sockaddr_in client_addr;
socklen_t client_len = sizeof(client_addr);
// 新的客户端套接字
client_fd = accept(g_server_fd, (struct sockaddr *)&client_addr, &client_len);
if (client_fd == -1)
{
perror("Failed to accept client connection\n");
continue;
}
printf("New client has joined successfully,Socket is %d, IP address is: %s, Port is:%d\n", client_fd, inet_ntoa(client_addr.sin_addr), ntohs(client_addr.sin_port));
// 给节点结构体赋值
heart_beat.fd = client_fd;
heart_beat.last_heartbeat = time(NULL);
// 使用尾插法创建新节点
insert_node(&g_p_head, heart_beat);
// 给存ip和端口的结构体赋值
p_new_node->fd = client_fd;
strcpy(p_new_node->ip, inet_ntoa(client_addr.sin_addr)); // 将网络字节序的IPv4地址转换为点分十进制的字符串
p_new_node->port = ntohs(client_addr.sin_port); // 网络字节序转换为主机字节序
// 将新的客户端连接套接字添加到 epoll 实例
event.data.ptr = p_new_node;
if (epoll_ctl(g_epoll_fd, EPOLL_CTL_ADD, client_fd, &event) == -1)
{
perror("Failed to add client socket to epoll\n");
close(client_fd);
continue;
}
}
// 监听客户端套接字有可读事件
else
{
pri_protocol_t re_buffer = {0};
client_fd = ((args_t *)(events[i].data.ptr))->fd;
// 读取数据
ssize_t pri_protocol_byte = recv(client_fd, &re_buffer, sizeof(pri_protocol_t), 0);
if (pri_protocol_byte == sizeof(pri_protocol_t))
{
start_pre:
if (re_buffer.type == 0x12abcdef)
{
if (SIZE >= re_buffer.length)
{
ssize_t packet_byte = recv(client_fd, g_p_recv_buffer, re_buffer.length, 0); // to od htons
if (packet_byte == re_buffer.length)
{
handle_client(re_buffer, events[i].data.ptr); // 处理客户端数据
recv_count += packet_byte;
sum += packet_byte; // 计总数
current_time = time(NULL); // 获取当前时间
if (current_time - start_time >= INTERVAL)
{
printf("Packet receiving traffic(/s):%d\n", recv_count / INTERVAL);
printf("loss packet :%d\n", lost_packet);
recv_count = 0;
start_time = current_time; // 更新开始时间
}
}
else if (packet_byte > 0 && packet_byte < re_buffer.length)
{
int byte = 0;
while (packet_byte == re_buffer.length)
{
byte = recv(client_fd, g_p_recv_buffer + packet_byte, re_buffer.length - packet_byte, 0);
if (byte == 0)
{
close_client(g_epoll_fd, client_fd);
}
else if (byte > 0 || byte < re_buffer.length - packet_byte)
{
packet_byte += byte;
}
else
{
lost_packet++;
perror("Failed to read data from client\n");
continue;
}
}
handle_client(re_buffer, events[i].data.ptr); // 处理客户端数据
recv_count += packet_byte;
sum += packet_byte; // 计总数
current_time = time(NULL); // 获取当前时间
if (current_time - start_time >= INTERVAL)
{
printf("Packet receiving traffic(/s):%d\n", recv_count / INTERVAL);
printf("loss packet :%d\n", lost_packet);
recv_count = 0;
start_time = current_time; // 更新开始时间
}
}
else if (packet_byte == 0)
{
close_client(g_epoll_fd, client_fd);
continue;
}
else
{
lost_packet++;
perror("Failed to read data from client\n");
continue;
}
}
else
{
lost_packet++;
perror("Failed to read data from client\n");
continue;
}
}
else
{
if (re_buffer.type == 0x1202abcd)
{
Node_t *p_current = g_p_head;
while (p_current->args_heart.fd != client_fd && p_current != NULL)
{
p_current = p_current->p_next;
}
p_current->args_heart.last_heartbeat = time(NULL);
}
else
{
lost_packet++;
perror("Failed to read data from client\n");
continue;
}
}
}
else if (pri_protocol_byte > 0 && pri_protocol_byte < sizeof(pri_protocol_t))
{
int byte = 0;
while (pri_protocol_byte == re_buffer.length)
{
byte = recv(client_fd, g_p_recv_buffer + pri_protocol_byte, sizeof(pri_protocol_t) - pri_protocol_byte, 0);
if (byte == 0)
{
close_client(g_epoll_fd, client_fd);
}
else if (byte > 0 || byte < sizeof(pri_protocol_t) - pri_protocol_byte)
{
pri_protocol_byte += byte;
}
else
{
lost_packet++;
perror("Failed to read data from client\n");
continue;
}
}
goto start_pre;
}
else if (pri_protocol_byte == 0)
{
close_client(g_epoll_fd, client_fd);
continue;
}
else
{
lost_packet++;
perror("Failed to read data from client\n");
continue;
}
}
}
// 判断超时
Node_t *p_heart = g_p_head;
while (p_heart->p_next != NULL)
{
int client_heart_fd = p_heart->p_next->args_heart.fd;
time_t last_heartbeat = p_heart->p_next->args_heart.last_heartbeat;
if (client_heart_fd != 0)
{
// 定时器
if (handle_alarm(last_heartbeat) == 1)
{
close_client(g_epoll_fd, client_fd);
delete_node(&p_heart);
}
else
{
p_heart = p_heart->p_next;
}
}
else
{
// 关闭客户端连接
close_client(g_epoll_fd, client_fd);
delete_node(&p_heart);
}
}
}
// 发送取消信号给线程
pthread_cancel(g_read_tid);
// 等待线程退出
pthread_join(g_read_tid, NULL);
if (g_p_recv_buffer != NULL)
{
free(g_p_recv_buffer);
g_p_recv_buffer = NULL;
}
if (g_p_circular_buffer != NULL)
{
free(g_p_circular_buffer);
g_p_circular_buffer = NULL;
}
// 关闭套接字和 epoll 实例
while (g_p_head != NULL)
{
Node_t *p_del = g_p_head;
g_p_head = g_p_head->p_next;
close(g_p_head->args_heart.fd);
free(p_del);
p_del = NULL;
}
close(g_server_fd);
close(g_epoll_fd);
// 等待所有子线程结束
return 0;
}
三、头文件
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <string.h>
#include <arpa/inet.h>
#include <pthread.h>
#include <time.h>
#include <sys/select.h>
#include <netinet/in.h>
#include <stdint.h>
#include <errno.h>
#include <netinet/ether.h>
#include <netinet/tcp.h>
#include <netinet/ip.h>
#include <math.h>
#include <signal.h>
#include <pcap.h>
#include <netinet/if_ether.h>
#include <sys/epoll.h>
#include <sys/socket.h>
#include <sys/epoll.h>
#include <stdbool.h>
#define BUFFER_SIZE 1024 * 1024 * 6
#define SIZE 1024 * 16
#define INTERVAL 10 // 监听的时间间隔,单位为秒
#define CONCURRENT_MAX 1024 // 应用层同时可以处理的连接
#define MAX_EVENTS 10
#define HEARTBEAT_TIMEOUT 20 // 心跳超时时间(秒)
// 以太网
typedef struct
{
unsigned char dest_mac[6];
unsigned char source_mac[6];
unsigned short protocol;
} ethhdr_t;
// ip
typedef struct
{
uint8_t ihl : 4; // IP
uint8_t version : 4;
uint8_t tos; // 服务类型
uint16_t tot_len; // 总长度
uint16_t id; // 标识
uint16_t frag_off; // 分片偏移 3位的 "标志" 和 13 位的 "片段偏移"
uint8_t ttl; // 生存时间
uint8_t protocol; // 协议
uint16_t check; // 首部校验和
uint32_t saddr; // 源IP地址
uint32_t daddr; // 目的IP地址
} iphdr_t;
// tcp
typedef struct
{
uint16_t source_port; // 源端口号
uint16_t dest_port; // 目的端口号
uint32_t seq; // 序列号
uint32_t ack_seq; // 确认序列号
uint8_t reserved : 4; // 保留字段
uint8_t doff : 4; // 数据偏移
uint8_t flags; // 标志位
uint16_t window; // 窗口大小
uint16_t check_sum; // 校验和
uint16_t urgent_ptr; // 紧急指针
} tcp_header_t;
// 包括数据的私有协议
typedef struct
{
int type; // 数据包类型
int length; // 总长度
const unsigned char data_buffer[SIZE];
} private_protocol_t;
// 不包括数据的私有协议
typedef struct
{
int type; // 数据包类型
int length; // 总长度
} pri_protocol_t;
typedef struct
{
int fd;
char ip[INET_ADDRSTRLEN]; // ip的
unsigned short port; // 端口号
} args_t; // 传参
// 环形缓冲区
typedef struct
{
unsigned char buffer[BUFFER_SIZE]; // 缓冲区buffer
int read_index; // 读的偏移量
int write_index; // 写偏移量
pthread_mutex_t mutex; // 互斥锁
int count; // 队列中当前元素的数量
} circular_buffer_t;
// 心跳机制
typedef struct
{
int fd;
time_t last_heartbeat;
} heart_beat_t;
// 单链表
typedef struct Node_s
{
heart_beat_t args_heart;
struct Node_s *p_next;
} Node_t;