Linux 聊天室(epoll + 线程池):在上一版思路的基础上改进的简化版本

 客户端:

#ifndef _CHAT_C_H_
#define _CHAT_C_H_

/* Client-side declarations for the epoll + thread-pool chat room.
 * NOTE(review): struct options and struct msg_buff are written to the
 * sockets as raw bytes, so their layout must stay identical to the copies
 * in the server header (chat_s.h). */

/* _DEBUG_ is defined unconditionally, so debug_msg() always expands to printf. */
#define _DEBUG_
#ifndef _DEBUG_
#define debug_msg(fmt, args...)
#else
#define debug_msg(fmt, args...) printf(fmt, ##args)
#endif

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <sys/epoll.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <unistd.h>
#include <fcntl.h>

#include <netinet/in.h>
#include <ctype.h>
#include <sys/socket.h>
#include <arpa/inet.h>
#include <pthread.h>
#include <sqlite3.h>
#include <errno.h>



/* Server address. The TCP port carries account requests (register, login,
 * password recovery); the UDP port carries chat datagrams. */
#define IPADDR "127.0.0.1"

#define TCP_PORT 9000
#define UDP_PORT 9000
/* NOTE(review): MAX_EVENTS, MODE_SIZE, MALLOC_*, SQL_SIZE, MAX_GROUP and
 * MUSHIN are not referenced by the client code visible here. */
#define MAX_EVENTS 20
/* Array sizes in bytes; a stored string holds at most size-1 characters. */
#define MSG_BUFFSIZE 100
#define MODE_SIZE 5
#define PASSWORD_SIZE 10
#define NICKNAME_SIZE 10
#define SC_PRO 10
#define ANSWER 10
#define MALLOC_OK 0
#define MALLOC_NO -1
#define SQL_SIZE 300
#define MAX_GROUP 20
#define MUSHIN -1
/* Values the server puts in forgetpswdinfo.flag. */
#define ACCOUNT_ERROR -1
#define ANSWER_ERROR -2

/* Menu choice / request type stored in options.option:
 * 0 quit, 1 register, 2 log in, 3 recover password. */
enum option
{
    EXIT,
    ERO,
    LOG,
    FGPD
};
/* Chat mode stored in msg_buff.chat_mode: 1 private chat, 2 group chat.
 * NOTE(review): chat_s.h names these values the other way round
 * (STOA = 1, STOO = 2); the server dispatches value 1 to say_to_one, which
 * matches the meaning used here. Only the numbers travel over the wire. */
enum chatmode
{
    STOO = 1,
    STOA,
};


// Registration request; the server assigns the account number.
struct enrollinfo
{
    char nickname[NICKNAME_SIZE];
    char password[PASSWORD_SIZE];
    char sc_protect[SC_PRO];
    char answer[ANSWER];
    int account;
};

// Login request.
struct logininfo
{
    int account;
    char password[PASSWORD_SIZE];
};

// Password-recovery request/reply; the server fills question, password
// and flag (ACCOUNT_ERROR / ANSWER_ERROR on failure).
struct forgetpswdinfo
{
    char nickname[NICKNAME_SIZE];
    int account;
    char question[SC_PRO];
    char answer[ANSWER];
    char password[PASSWORD_SIZE];
    int flag;
};

// Account-request packet sent over TCP: option selects which union member is valid.
struct options
{
    int option;
    union
    {
        struct enrollinfo eninfo;
        struct logininfo loginfo;
        struct forgetpswdinfo pswdinfo;
    } info;
};

// Chat datagram sent over UDP. chat_mode -1 registers the sender's
// address with the server (see write_thread).
struct msg_buff
{
    int chat_mode;
    int my_account;
    int account;                     // peer account (target of a private message)
    char nickname[NICKNAME_SIZE];    // own nickname; NOTE(review): not filled in by the client code shown
    char message[MSG_BUFFSIZE];
};

// Arguments for read_thread/write_thread: socket, own account number and
// server address (presumably the UDP chat socket — set up in chat_start_c,
// which is not shown here).
struct cp
{
    int confd;
    int account;
    struct sockaddr_in s_addr;
};


// Create a TCP/UDP socket and fill in the server address.
void create_tcp_client(int *sockfd, struct sockaddr_in *s_addr);
void create_udp_client(int *sockfd, struct sockaddr_in *s_addr);

// Print the main menu.
void options_menu();

// Main menu loop.
void user_options(const int confd, struct sockaddr_in s_addr);

// Registration.
void ero_account_c(int confd, struct options *option);

// Login; chat_start_c runs the chat session after a successful login.
void log_opreation_c(int confd, struct sockaddr_in s_addr, struct options *option);
void log_infomation(struct options *option);
void chat_start_c(struct options *option, struct sockaddr_in s_addr);

// Password recovery.
void get_password(int confd, struct options *option);

// Separate reader and writer threads for the chat session.
void *read_thread(void *arg);
void *write_thread(void *arg);

void say_to_all(int confd, struct sockaddr_in s_addr, struct msg_buff *msg);
void say_to_one(int confd, struct sockaddr_in s_addr, struct msg_buff *msg);



#endif
#include "chat_c.h"

/* Shared helper for the two public constructors below: open a socket of
 * the given type and fill *s_addr with the server address IPADDR:port.
 * Exits the process if socket() fails, as both constructors always did. */
static void create_client_socket(int type, unsigned short port,
                                 int *sockfd, struct sockaddr_in *s_addr)
{
    *sockfd = socket(AF_INET, type, 0);
    if (*sockfd < 0)
    {
        perror("sockfd creat error:");
        exit(1);
    }

    /* memset rather than bzero: bzero lives in <strings.h> (not included)
     * and was removed from POSIX.1-2008. */
    memset(s_addr, 0, sizeof(*s_addr));
    s_addr->sin_family = AF_INET;
    s_addr->sin_port = htons(port);
    s_addr->sin_addr.s_addr = inet_addr(IPADDR);
}

/* Create a UDP socket and the server's UDP (chat) address. */
void create_udp_client(int *sockfd, struct sockaddr_in *s_addr)
{
    create_client_socket(SOCK_DGRAM, UDP_PORT, sockfd, s_addr);
}

/* Create a TCP socket and the server's TCP (account request) address. */
void create_tcp_client(int *sockfd, struct sockaddr_in *s_addr)
{
    create_client_socket(SOCK_STREAM, TCP_PORT, sockfd, s_addr);
}


/* Client entry point: connect to the chat server over TCP and run the
 * interactive account menu until the user quits. */
int main(void)
{
    int sockfd_tcp;
    struct sockaddr_in s_addr_tcp;

    create_tcp_client(&sockfd_tcp, &s_addr_tcp);

    if (connect(sockfd_tcp, (struct sockaddr *)&s_addr_tcp, sizeof(s_addr_tcp)) < 0)
    {
        perror("connect error");
        close(sockfd_tcp);
        exit(1);
    }

    user_options(sockfd_tcp, s_addr_tcp);

    /* NOTE(review): after a successful login log_opreation_c has already
     * closed this descriptor, so this may be a second close just before
     * the process exits. */
    close(sockfd_tcp);

    return 0;
}
#include "chat_c.h"

/* Print the account menu. Every entry listed here is handled by
 * user_options, including 3 (password recovery), which was previously
 * reachable but never shown. */
void options_menu(void)
{
    printf("********\n");
    printf("1.注册\n");
    printf("2.登录\n");
    printf("3.找回密码\n");
    printf("0.退出\n");
    printf("********\n");
}


/* Interactive account menu: register, log in, recover a password or quit.
 * Each request is sent to the server over the TCP connection confd.
 * Loops until the user chooses 0 or stdin reaches EOF. */
void user_options(const int confd, struct sockaddr_in s_addr)
{
    struct options option;

    do
    {
        int rc;
        int c;

        memset(&option, 0, sizeof(option));
        options_menu();
        printf("请选择:> ");
        fflush(stdout);

        rc = scanf("%d", &option.option);
        if (rc == EOF)
        {
            /* No more input: behave like choosing 0 so the server is told. */
            option.option = EXIT;
        }
        else
        {
            /* Consume the rest of the line: the newline after the number, or
             * the non-numeric input that made scanf fail. */
            while ((c = getchar()) != '\n' && c != EOF)
                ;
            if (rc != 1)
            {
                /* Previously a failed scanf left 0 (EXIT) and quit silently. */
                printf("input error:\n");
                option.option = -1; /* any non-zero value keeps the loop running */
                continue;
            }
        }

        switch (option.option)
        {
        case ERO:
            ero_account_c(confd, &option);
            break;

        case LOG:
            log_opreation_c(confd, s_addr, &option);
            break;

        case FGPD:
            get_password(confd, &option);
            break;

        case EXIT:
            option.option = EXIT;
            write(confd, &option, sizeof(option));
            break;

        default:
            printf("input error:\n");
            break;
        }
    } while (option.option);
}

void ero_account_c(const int confd, struct options *option)
{

    int account;
    printf("请输入昵称: ");
    scanf("%s", option->info.eninfo.nickname);
    printf("请输入密码: ");
    scanf("%s", option->info.eninfo.password);
    printf("密保问题 : ");
    scanf("%s", option->info.eninfo.sc_protect);
    printf("密保答案 : ");
    scanf("%s", option->info.eninfo.answer);
    write(confd, option, sizeof(struct options));
    read(confd, &account, sizeof(int));
    printf("你的账号是 :%d\n", account);
    return;
}

void log_infomation(struct options *option)
{
    int i = 0;
    char a[20];
    printf("请输入账号: ");
    scanf("%d", &option->info.loginfo.account);
    getchar();
    printf("请输入密码: ");
    fflush(stdout);
    while (1)
    {
        system("stty -echo");
        a[i] = getchar();
        if (a[i] == '\n')
        {
            break;
        }
        system("stty echo");
        printf("*");
        fflush(stdout);
        i++;
    }
    system("stty echo");
    a[i] = '\0';
    strcpy(option->info.loginfo.password, a);
}



/* Login: collect credentials, send them and wait for the verdict.
 * On "success" the TCP connection is closed and chat_start_c runs the chat
 * session; on "failed" (or a broken connection) a failure is reported. */
void log_opreation_c(int confd, struct sockaddr_in s_addr, struct options *option)
{
    char buff[10];
    ssize_t n;

    log_infomation(option);

    if (write(confd, option, sizeof(struct options)) != (ssize_t)sizeof(struct options))
    {
        perror("write error");
        return;
    }

    /* The server sends the bare bytes "success"/"failed" without a NUL,
     * so keep one byte free for the terminator. */
    n = read(confd, buff, sizeof(buff) - 1);
    if (n <= 0)
    {
        if (n < 0)
            perror("read error");
        printf("登录失败!!\n");
        return;
    }
    buff[n] = '\0';

    if (0 == strcmp(buff, "failed"))
    {
        printf("登录失败!!\n");
        sleep(2);
        return;
    }
    if (0 == strcmp(buff, "success"))
    {
        close(confd);
        chat_start_c(option, s_addr);
    }
}

void get_password(int confd, struct options *option)
{
    printf("请输入账号: ");
    scanf("%d", &option->info.pswdinfo.account);
    write(confd, option, sizeof(struct options));

    read(confd, option, sizeof(struct options));
    if (option->info.pswdinfo.flag == ACCOUNT_ERROR)
    {
        printf("账号输入错误:\n");
        return;
    }
    printf("密保问题: %s\n", option->info.pswdinfo.question);

    printf("请输入答案:");
    scanf("%s", option->info.pswdinfo.answer);
    write(confd, option, sizeof(struct options));
    read(confd, option, sizeof(struct options));
    read(confd, option, sizeof(struct options));
    read(confd, option, sizeof(struct options));

    if (option->info.pswdinfo.flag == ANSWER_ERROR)
    {
        printf("答案错误!!!\n");
        return;
    }
    printf("您的密码是:%s", option->info.pswdinfo.password);
}
#include "chat_c.h"

/* Sending side of the chat session.
 * First registers this client's address with the server (chat_mode -1),
 * then loops on the chat menu: 1 private message, 2 group message,
 * 0 quit. Quitting notifies the server and terminates the whole process.
 * arg points to a struct cp; it is copied at start. */
void *write_thread(void *arg)
{
    struct cp args = *((struct cp *)arg);
    int confd = args.confd;
    struct sockaddr_in s_addr = args.s_addr;
    struct msg_buff msg;

    memset(&msg, 0, sizeof(msg));
    msg.my_account = args.account;

    /* Registration packet: the server records the address it arrives from
     * as this account's chat address. */
    msg.chat_mode = -1;
    sendto(confd, &msg, sizeof(msg), 0, (struct sockaddr *)&s_addr, sizeof(s_addr));

    do
    {
        int rc;
        int c;

        printf("\n**************\n");
        printf("* 1.私    聊 *\n");
        printf("* 2.群    聊 *\n");
        printf("* 0.退    出 *\n");
        printf("**************\n");

        rc = scanf("%d", &msg.chat_mode);
        if (rc == EOF)
        {
            /* stdin closed: log out instead of re-printing the menu forever. */
            msg.chat_mode = EXIT;
        }
        else
        {
            /* Drop the rest of the line (newline or non-numeric input). */
            while ((c = getchar()) != '\n' && c != EOF)
                ;
            if (rc != 1)
            {
                msg.chat_mode = -1; /* invalid choice: show the menu again */
                continue;
            }
        }

        switch (msg.chat_mode)
        {
        case STOO:
            say_to_one(confd, s_addr, &msg);
            break;
        case STOA:
            say_to_all(confd, s_addr, &msg);
            break;

        case EXIT:
            sendto(confd, &msg, sizeof(msg), 0, (struct sockaddr *)&s_addr, sizeof(s_addr));
            exit(0);
        default:
            break;
        }
    } while (msg.chat_mode);

    return NULL;
}

/* Read one line of chat text into dst (size bytes): the newline is
 * stripped and anything beyond size-1 characters is discarded, so the
 * buffer cannot overflow and whole sentences with spaces are sent intact.
 * Returns 0 on EOF/error, 1 otherwise. */
static int read_message_line(char *dst, size_t size)
{
    size_t len;

    fflush(stdout);
    if (fgets(dst, (int)size, stdin) == NULL)
        return 0;

    len = strcspn(dst, "\n");
    if (dst[len] == '\n')
    {
        dst[len] = '\0';
    }
    else
    {
        int c;
        while ((c = getchar()) != '\n' && c != EOF)
            ;
    }
    return 1;
}

/* Group chat: read a message and send it to the server for broadcast. */
void say_to_all(int confd, struct sockaddr_in s_addr, struct msg_buff *msg)
{
    printf("请输入发送内容: ");
    if (!read_message_line(msg->message, sizeof(msg->message)))
        return;
    sendto(confd, msg, sizeof(struct msg_buff), 0, (struct sockaddr *)&s_addr, sizeof(s_addr));
}

/* Private chat: read the peer's account and a message, then send it to
 * the server for delivery to that account. */
void say_to_one(int confd, struct sockaddr_in s_addr, struct msg_buff *msg)
{
    int c;
    int rc;

    printf("请输入对方账号: ");
    rc = scanf("%d", &msg->account);
    /* Finish the account line so the message prompt starts on fresh input. */
    while ((c = getchar()) != '\n' && c != EOF)
        ;
    if (rc != 1)
    {
        printf("input error:\n");
        return;
    }

    printf("请输入发送内容: ");
    if (!read_message_line(msg->message, sizeof(msg->message)))
        return;
    sendto(confd, msg, sizeof(struct msg_buff), 0, (struct sockaddr *)&s_addr, sizeof(s_addr));
}


#include "chat_c.h"

void *read_thread(void *arg)
{
    int bytes_recv;
    int confd;
    int choice;
    int account;
    struct sockaddr_in s_addr;
    socklen_t s_len = sizeof(s_addr);

    sqlite3 *pdb;

    struct cp args = *((struct cp *)arg);

    confd = args.confd;
    s_addr = args.s_addr;
    struct msg_buff msg;

    while (1)
    {
        if ((bytes_recv = recvfrom(confd, &msg, sizeof(msg), 0, (struct sockaddr *)&s_addr, &s_len)) != sizeof(msg))
        {
            printf("bytes recv %d\n", bytes_recv);
            continue;
        }
        switch (msg.chat_mode)
        {
        case STOO:
            printf("one: recv mgs: %s\n", msg.message);
            break;
        case STOA:
            printf("all: recv mgs: %s\n", msg.message);
            break;

        default:
            break;
        }
    }
}

 服务器:

#ifndef _CHAT_S_H_
#define _CHAT_S_H_

/* Server-side declarations for the epoll + thread-pool chat room.
 * NOTE(review): struct options and struct msg_buff are exchanged with the
 * client as raw bytes; their layout must stay identical to chat_c.h. */

/* _DEBUG_ is defined unconditionally, so debug_msg() always expands to printf. */
#define _DEBUG_
#ifndef _DEBUG_
#define debug_msg(fmt, args...)
#else
#define debug_msg(fmt, args...) printf(fmt, ##args)
#endif

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <sys/epoll.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <unistd.h>
#include <fcntl.h>

#include <netinet/in.h>
#include <ctype.h>
#include <sys/socket.h>
#include <arpa/inet.h>
#include <sqlite3.h>
#include <time.h>
#include <errno.h>
#include <pthread.h>

#include "thread_pool.h"

/* TCP carries account requests (do_use_fd); UDP carries chat (chat_start_s). */
#define TCP_PORT 9000
#define UDP_PORT 9000
/* Maximum events fetched per epoll_wait call. */
#define MAX_EVENTS 20
/* Array sizes in bytes; a stored string holds at most size-1 characters. */
#define MSG_BUFFSIZE 100
#define MODE_SIZE 5
#define PASSWORD_SIZE 10
#define NICKNAME_SIZE 10
#define SC_PRO 10
#define ANSWER 10
#define EXP_SIZE 10
/* Expected return value of cread_list (checked in main). */
#define MALLOC_OK 0
#define MALLOC_NO -1
/* Returned by do_use_fd after a successful login. */
#define LOG_SUCCESS 1
#define SQL_SIZE 300
#define MAX_GROUP 20
#define MUSHIN -1
/* Values written to forgetpswdinfo.flag by search_password. */
#define ACCOUNT_ERROR -1
#define ANSWER_ERROR -2

/* Request type in options.option: 0 quit, 1 register, 2 log in, 3 recover password. */
enum option
{
    EXIT,
    ERO,
    LOG,
    FGPD
};
/* Chat mode in msg_buff.chat_mode. chat_start_s routes value 1 (STOA) to
 * say_to_one and value 2 (STOO) to say_to_all.
 * NOTE(review): the names look swapped — chat_c.h calls value 1 STOO
 * (private) and value 2 STOA (group); only the numbers travel over the
 * wire. ONLINE is not referenced by the code shown. */
enum chatmode
{
    STOA = 1,
    STOO,
    ONLINE
};


/* Node of the online-user list. The list has a dummy head node:
 * traversals start at head->next. */
struct online
{
    struct sockaddr_in c_addr;
    int account;
    char nickname[NICKNAME_SIZE];
    struct online *next;
};

/* A chat client as seen by a chat task, plus the shared online list. */
struct send_client
{
    struct sockaddr_in c_addr;             // sender's address
    int account;
    char nickname[NICKNAME_SIZE];
    struct online *head;                   // list of the addresses of all online clients
};

/* Argument block handed to chat_start_s through the thread pool. */
struct arg
{
    int ufd;
    sqlite3 *pdb;
    struct  send_client cli;
    ThreadPool *pool;
};


// Registration request; the server assigns the account number.
struct enrollinfo
{
    char nickname[NICKNAME_SIZE];
    char password[PASSWORD_SIZE];
    char sc_protect[SC_PRO];
    char answer[ANSWER];
    int account;
};

// Login request.
struct logininfo
{
    int account;
    char password[PASSWORD_SIZE];
};

// Password-recovery request/reply; flag reports ACCOUNT_ERROR / ANSWER_ERROR.
struct forgetpswdinfo
{
    char nickname[NICKNAME_SIZE];
    int account;
    char question[SC_PRO];
    char answer[ANSWER];
    char password[PASSWORD_SIZE];
    int flag;
};

// Account-request packet: option selects which union member is valid.
struct options
{
    int option;
    union
    {
        struct enrollinfo eninfo;
        struct logininfo loginfo;
        struct forgetpswdinfo pswdinfo;
    } info;
};


// Chat datagram; chat_mode -1 registers the sender's address (chat_start_s).
struct msg_buff
{
    int chat_mode;
    int my_account;
    int account;                     // peer account (target of a private message)
    char nickname[NICKNAME_SIZE]; // nickname
    char message[MSG_BUFFSIZE];
};


// Online-user list operations (implementations not shown here).
int cread_list(struct online **head);
int creat_node(struct online **node);
void insert_node(struct online *head, struct online *node);
void delete_node(struct send_client *cli);

// Open the database.
int creat_sqlite(sqlite3 **pdb);

// Table creation (creat_table_chat creates the chat_account table).
int creat_table(sqlite3 *pdb, char *sql);
int creat_table_chat(sqlite3 *pdb);
int update_sqlite3_ip(int IP, int account, sqlite3 *pdb);
int update_sqlite3_statu(int account, sqlite3 *pdb);

// Super-user creation.
int super_root(sqlite3 *pdb);

// Create TCP/UDP sockets; register an fd with epoll.
void create_tcp_sever(int *sockfd, struct sockaddr_in *s_addr);
void create_udp_sever(int *sockfd, struct sockaddr_in *s_addr);
void addfd(int epollfd, int fd);

// Make an fd non-blocking; returns the previous flags.
int setnonblock(int confd);

// Register an account.
void ero_account_s(const int *confd, struct enrollinfo *eninfo, sqlite3 *pdb);
// Log in; returns 0 on success.
int log_operation_s(int confd, const struct logininfo *loginfo, sqlite3 *pdb);
// void chat_start_s(int confd, struct sockaddr_in c_addr, sqlite3 *pdb);
void chat_start_s(void *args);

// Handle a ready TCP connection; LOG_SUCCESS after a login, EXIT otherwise.
int do_use_fd(int confd, const int *epollfd, sqlite3 *pdb);
void say_to_one(int confd, struct send_client *cli, struct msg_buff *msg);
void say_to_all(int confd, struct send_client *cli, struct msg_buff *msg);

void search_online(int ufd, struct send_client *cli, struct msg_buff *msg, sqlite3 *pdb);
void search_nickname(struct online *node,sqlite3 *pdb);
// Password recovery.
int search_password(int confd, struct options *option, sqlite3 *pdb);

#endif
#include "chat_s.h"

/* Shared helper for the two public constructors below: create a socket of
 * the given type, enable SO_REUSEADDR and bind it to INADDR_ANY:port,
 * filling *s_addr with the bound address. Exits the process on failure. */
static void create_bound_socket(int type, unsigned short port,
                                int *sockfd, struct sockaddr_in *s_addr)
{
    int opt = 1;

    *sockfd = socket(AF_INET, type, 0);
    if (*sockfd < 0)
    {
        perror("sockfd creat error:");
        exit(1);
    }

    /* Allow re-binding the port right after a restart. */
    if (setsockopt(*sockfd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) == -1)
        perror("setsockopt error:");

    /* memset rather than bzero: bzero lives in <strings.h> (not included)
     * and was removed from POSIX.1-2008. */
    memset(s_addr, 0, sizeof(*s_addr));
    s_addr->sin_family = AF_INET;
    s_addr->sin_port = htons(port);
    s_addr->sin_addr.s_addr = htonl(INADDR_ANY);

    if (bind(*sockfd, (struct sockaddr *)s_addr, sizeof(*s_addr)) == -1)
    {
        perror("bind error:");
        close(*sockfd);
        exit(1);
    }
}

/* Create and bind the UDP socket used for chat datagrams. */
void create_udp_sever(int *sockfd, struct sockaddr_in *s_addr)
{
    create_bound_socket(SOCK_DGRAM, UDP_PORT, sockfd, s_addr);
}

/* Create and bind the TCP socket for account requests (the caller calls
 * listen()). */
void create_tcp_sever(int *sockfd, struct sockaddr_in *s_addr)
{
    create_bound_socket(SOCK_STREAM, TCP_PORT, sockfd, s_addr);
}

/* Make fd non-blocking and register it with epollfd for edge-triggered
 * read events. Non-blocking mode is set first, because edge-triggered
 * epoll expects non-blocking descriptors. A registration failure is
 * reported instead of being silently ignored. */
void addfd(int epollfd, int fd)
{
    struct epoll_event event;

    setnonblock(fd);

    memset(&event, 0, sizeof(event));
    event.data.fd = fd;
    event.events = EPOLLIN | EPOLLET;
    if (epoll_ctl(epollfd, EPOLL_CTL_ADD, fd, &event) == -1)
        perror("epoll add error:");
}

/* Set O_NONBLOCK on confd.
 * Returns the file status flags from before the change, or -1 if they
 * could not be read or updated. A failed F_GETFL is no longer turned into
 * F_SETFL(-1 | O_NONBLOCK), which would have tried to set every flag bit. */
int setnonblock(int confd)
{
    int old_flags = fcntl(confd, F_GETFL);

    if (old_flags == -1)
    {
        perror("fcntl F_GETFL error:");
        return -1;
    }
    if (fcntl(confd, F_SETFL, old_flags | O_NONBLOCK) == -1)
    {
        perror("fcntl F_SETFL error:");
        return -1;
    }
    return old_flags;
}
/* Accept every connection pending on the edge-triggered, non-blocking
 * listening socket and register each client for edge-triggered reads.
 * With EPOLLET one event may stand for several connections, so accepting
 * only one per event would strand the rest until another client arrives. */
static void accept_pending(int epollfd, int listenfd)
{
    for (;;)
    {
        struct sockaddr_in c_addr;
        socklen_t c_len = sizeof(c_addr);
        struct epoll_event event;
        int confd = accept(listenfd, (struct sockaddr *)&c_addr, &c_len);

        if (confd < 0)
        {
            if (errno == EINTR || errno == ECONNABORTED)
                continue;
            if (errno != EAGAIN && errno != EWOULDBLOCK)
                perror("accept error:");
            return; /* backlog drained, or an error already reported */
        }

        setnonblock(confd);
        memset(&event, 0, sizeof(event));
        event.data.fd = confd;
        /* EPOLLIN only: EPOLLOUT reports a fresh socket as ready at once and
         * do_use_fd would then be called with nothing to read. */
        event.events = EPOLLIN | EPOLLET;
        if (epoll_ctl(epollfd, EPOLL_CTL_ADD, confd, &event) == -1)
        {
            perror("epoll add error:");
            close(confd);
        }
    }
}

/* Server entry point: open the account database, the TCP listener for
 * account requests and the UDP socket for chat, then dispatch TCP events.
 * Each successful login queues one chat_start_s task on the thread pool. */
int main(void)
{
    int listenfd;
    int udpfd;
    int epollfd;

    sqlite3 *pdb;
    struct online *head;
    ThreadPool *pool;
    struct arg args;

    struct sockaddr_in s_addr_tcp;
    struct sockaddr_in s_addr_udp;
    struct epoll_event events[MAX_EVENTS];

    srand((unsigned)time(NULL));

    if (cread_list(&head) != MALLOC_OK)
    {
        /* Without a list head every chat task would dereference garbage. */
        printf("creat list error\n");
        exit(1);
    }

    // Open the database
    if (creat_sqlite(&pdb) != SQLITE_OK)
    {
        sqlite3_close(pdb);
        exit(1);
    }
    // Create the registered-user table
    if (creat_table_chat(pdb) != SQLITE_OK)
    {
        sqlite3_close(pdb);
        exit(1);
    }

    create_tcp_sever(&listenfd, &s_addr_tcp);
    if (listen(listenfd, 1024) == -1)
    {
        perror("listen error:");
        close(listenfd);
        sqlite3_close(pdb);
        exit(1);
    }

    create_udp_sever(&udpfd, &s_addr_udp);

    if ((epollfd = epoll_create(5)) == -1)
    {
        perror("epoll create error:");
        exit(1);
    }

    addfd(epollfd, listenfd);

    pool = threadPoolCreate(3, 10, 100);
    if (pool == NULL)
    {
        printf("creat thread pool error\n");
        exit(1);
    }

    /* One argument block shared by every chat task. The pool stores only the
     * pointer and a worker copies the struct later, so it must live as long
     * as main — a loop-body local would already be out of scope. The fields
     * are the same for every login; each task fills its own copy of cli. */
    memset(&args, 0, sizeof(args));
    args.ufd = udpfd;
    args.pdb = pdb;
    args.cli.head = head;
    args.pool = pool;

    while (1)
    {
        int nfds = epoll_wait(epollfd, events, MAX_EVENTS, -1);
        if (nfds == -1)
        {
            if (errno == EINTR)
                continue; /* interrupted by a signal, not a failure */
            perror("epoll wait error:");
            break;
        }
        for (int i = 0; i < nfds; ++i)
        {
            int fd = events[i].data.fd;

            if (fd == listenfd)
                accept_pending(epollfd, listenfd);
            else if (LOG_SUCCESS == do_use_fd(fd, &epollfd, pdb))
                threadPoolAdd(pool, chat_start_s, (void *)&args);
        }
    }

    /* Chat tasks run until their client logs out and keep using udpfd, pdb
     * and the online list. Destroying the pool or closing those here would
     * pull resources from under running threads, so just end the process. */
    return 1;
}
#include "chat_s.h"

/* Stop watching a client's TCP connection and release its descriptor. */
static void drop_client(int epollfd, int confd)
{
    epoll_ctl(epollfd, EPOLL_CTL_DEL, confd, NULL);
    close(confd);
}

/* Handle one request on a ready account (TCP) connection.
 * confd    client socket (non-blocking, edge-triggered)
 * epollfd  epoll instance the socket is registered with
 * pdb      account database
 * Returns LOG_SUCCESS after a successful login, EXIT otherwise. The
 * connection is closed when the peer disconnects, fails, sends a malformed
 * request or sends EXIT (previously it was only removed from epoll, which
 * leaked the descriptor). */
int do_use_fd(int confd, const int *epollfd, sqlite3 *pdb)
{
    struct options option;
    ssize_t bytes_read;

    memset(&option, 0, sizeof(option));

    do
        bytes_read = recv(confd, &option, sizeof(option), 0);
    while (bytes_read < 0 && errno == EINTR);

    if (bytes_read < 0)
    {
        /* Nothing queued: a spurious wake-up on a non-blocking socket. */
        if (errno == EAGAIN || errno == EWOULDBLOCK)
            return EXIT;
        perror("read option error:");
        drop_client(*epollfd, confd);
        return EXIT;
    }
    if (bytes_read == 0)
    {
        /* Orderly shutdown by the peer (the client closes this socket after
         * logging in). errno carries no meaning here. */
        debug_msg("do_use  %d closed\n", confd);
        drop_client(*epollfd, confd);
        return EXIT;
    }
    if ((size_t)bytes_read != sizeof(option))
    {
        /* Requests are always one whole struct options; a fragment would
         * leave the union half-filled and desynchronise the stream. */
        debug_msg("do_use  %d short request (%d bytes)\n", confd, (int)bytes_read);
        drop_client(*epollfd, confd);
        return EXIT;
    }

    switch (option.option)
    {
    case ERO:
        ero_account_s(&confd, &option.info.eninfo, pdb);
        break;

    case LOG:
        if (log_operation_s(confd, &option.info.loginfo, pdb) != 0)
        {
            break;
        }
        return LOG_SUCCESS;

    case FGPD:
        search_password(confd, &option, pdb);
        break;

    case EXIT:
        debug_msg("do_use  %d offline\n", confd);
        drop_client(*epollfd, confd);
        break;

    default:
        break;
    }
    return EXIT;
}

/* Register a new account.
 * Inserts the client's nickname, password and security question/answer
 * under a random account number in [100000, 199999]. If that number is
 * taken it retries with a new one. The old retry compared errmsg with
 * "...chat_count.Account" (wrong table name), so it never matched.
 * Always writes one int back on *confd: the new account number, or
 * ACCOUNT_ERROR on failure, so the waiting client never blocks forever. */
void ero_account_s(const int *confd, struct enrollinfo *eninfo, sqlite3 *pdb)
{
    /* Upper bound on attempts to find a free random account number. */
    enum { MAX_ACCOUNT_TRIES = 32 };
    static const char sql[] =
        "insert into chat_account (Account, Nickname, Password, Question, Answer) "
        "values (?, ?, ?, ?, ?);";
    sqlite3_stmt *stmt = NULL;
    int reply = ACCOUNT_ERROR;

    debug_msg("ero_account  %d\n", __LINE__);

    /* The strings arrive as raw fixed-size arrays: terminate them before
     * they are handed to SQLite. */
    eninfo->nickname[sizeof(eninfo->nickname) - 1] = '\0';
    eninfo->password[sizeof(eninfo->password) - 1] = '\0';
    eninfo->sc_protect[sizeof(eninfo->sc_protect) - 1] = '\0';
    eninfo->answer[sizeof(eninfo->answer) - 1] = '\0';

    if (sqlite3_prepare_v2(pdb, sql, -1, &stmt, NULL) != SQLITE_OK)
    {
        debug_msg("insert new account error:%s\n", sqlite3_errmsg(pdb));
    }
    else
    {
        /* Bound parameters keep client text out of the SQL (no injection). */
        sqlite3_bind_text(stmt, 2, eninfo->nickname, -1, SQLITE_TRANSIENT);
        sqlite3_bind_text(stmt, 3, eninfo->password, -1, SQLITE_TRANSIENT);
        sqlite3_bind_text(stmt, 4, eninfo->sc_protect, -1, SQLITE_TRANSIENT);
        sqlite3_bind_text(stmt, 5, eninfo->answer, -1, SQLITE_TRANSIENT);

        for (int attempt = 0; attempt < MAX_ACCOUNT_TRIES; ++attempt)
        {
            int rc;

            eninfo->account = rand() % 100000 + 100000;
            sqlite3_bind_int(stmt, 1, eninfo->account);

            rc = sqlite3_step(stmt);
            if (rc == SQLITE_DONE)
            {
                reply = eninfo->account;
                break;
            }
            debug_msg("insert new account error:%s\n", sqlite3_errmsg(pdb));
            sqlite3_reset(stmt); /* bindings are kept across a reset */

            /* Only a constraint violation (duplicate primary key — the text
             * columns are always bound, so NOT NULL cannot fail) is worth
             * retrying. Masking also accepts extended result codes. */
            if ((rc & 0xff) != SQLITE_CONSTRAINT)
                break;
        }
    }
    sqlite3_finalize(stmt); /* harmless no-op when prepare failed (NULL) */

    write(*confd, &reply, sizeof(int));
}

/* Check a login request against chat_account.
 * Writes the bare bytes "success" or "failed" (no NUL) to confd and
 * returns 0 on success, -1 otherwise. A database error now also answers
 * "failed", so the client is not left waiting. The password is bound as a
 * parameter; previously it was spliced into the SQL, so a password like
 * ' OR '1'='1 logged in without credentials. */
int log_operation_s(int confd, const struct logininfo *loginfo, sqlite3 *pdb)
{
    static const char sql[] =
        "select Account from chat_account where Account = ? AND Password = ?;";
    char password[PASSWORD_SIZE];
    sqlite3_stmt *stmt = NULL;
    int found = 0;
    int rc;

    /* The client's array need not be NUL-terminated: copy at most
     * PASSWORD_SIZE-1 bytes (the precision bounds how much is read). */
    snprintf(password, sizeof(password), "%.*s",
             (int)(sizeof(password) - 1), loginfo->password);

    if (sqlite3_prepare_v2(pdb, sql, -1, &stmt, NULL) != SQLITE_OK)
    {
        debug_msg("log error:%s\n", sqlite3_errmsg(pdb));
        write(confd, "failed", strlen("failed"));
        return -1;
    }

    sqlite3_bind_int(stmt, 1, loginfo->account);
    sqlite3_bind_text(stmt, 2, password, -1, SQLITE_TRANSIENT);

    rc = sqlite3_step(stmt);
    if (rc == SQLITE_ROW)
        found = 1;
    else if (rc != SQLITE_DONE)
        debug_msg("log error:%s\n", sqlite3_errmsg(pdb));
    sqlite3_finalize(stmt);

    if (!found)
    {
        debug_msg("no account\n");
        write(confd, "failed", strlen("failed"));
        return -1;
    }

    write(confd, "success", strlen("success"));
    return 0;
}

/* Password recovery for option->info.pswdinfo.account.
 * Unknown account (or a database error): one reply with
 * flag = ACCOUNT_ERROR, return -1.
 * Otherwise two replies (the protocol the client expects):
 *   1. the packet with the security question filled in;
 *   2. the verdict for the answer carried by THIS request — the password
 *      filled in, or flag = ANSWER_ERROR.
 * Returns SQLITE_OK.
 * One query replaces the previous three, and every result is released.
 * Copies are bounded and NULL columns are handled. An empty or NULL stored
 * answer never matches, so an empty answer cannot reveal a password. */
int search_password(int confd, struct options *option, sqlite3 *pdb)
{
    static const char sql[] =
        "select Question, Answer, Password from chat_account where Account = ?;";
    struct forgetpswdinfo *info = &option->info.pswdinfo;
    sqlite3_stmt *stmt = NULL;
    const char *question;
    const char *answer;
    const char *password;

    /* The answer arrives as a raw array: make it a valid C string. */
    info->answer[sizeof(info->answer) - 1] = '\0';

    if (sqlite3_prepare_v2(pdb, sql, -1, &stmt, NULL) != SQLITE_OK)
    {
        debug_msg("search password error:%s\n", sqlite3_errmsg(pdb));
        info->flag = ACCOUNT_ERROR;
        write(confd, option, sizeof(struct options));
        return -1;
    }
    sqlite3_bind_int(stmt, 1, info->account);

    if (sqlite3_step(stmt) != SQLITE_ROW)
    {
        sqlite3_finalize(stmt);
        info->flag = ACCOUNT_ERROR;
        write(confd, option, sizeof(struct options));
        return -1;
    }

    /* Column pointers stay valid until the statement is finalized below. */
    question = (const char *)sqlite3_column_text(stmt, 0);
    answer = (const char *)sqlite3_column_text(stmt, 1);
    password = (const char *)sqlite3_column_text(stmt, 2);

    // Reply 1: the security question
    snprintf(info->question, sizeof(info->question), "%s", question ? question : "");
    write(confd, option, sizeof(struct options));

    // Reply 2: the verdict
    if (answer != NULL && answer[0] != '\0' && 0 == strcmp(info->answer, answer))
    {
        snprintf(info->password, sizeof(info->password), "%s", password ? password : "");
        write(confd, option, sizeof(struct options));
    }
    else
    {
        info->flag = ANSWER_ERROR;
        write(confd, option, sizeof(struct options));
    }

    sqlite3_finalize(stmt);
    return SQLITE_OK;
}
#include "chat_s.h"

/* Open (creating if necessary) chat_room.db into *pdb.
 * Returns SQLITE_OK; on failure the error is reported and the process
 * exits. SQLite does not set errno, so its own message is printed instead
 * of perror's. The handle SQLite may still return on failure is closed. */
int creat_sqlite(sqlite3 **pdb)
{
    if (sqlite3_open("chat_room.db", pdb) != SQLITE_OK)
    {
        fprintf(stderr, "creat sqlite error: %s\n",
                *pdb ? sqlite3_errmsg(*pdb) : "out of memory");
        sqlite3_close(*pdb); /* no-op for NULL */
        exit(1);
    }
    return SQLITE_OK;
}


/* Create the chat_account table (registered users) if it does not exist.
 * Account is the integer primary key; Nickname and Password are required,
 * while Question and Answer may be NULL.
 * Returns SQLITE_OK on success, -1 on failure. */
int creat_table_chat(sqlite3 *pdb)
{
    char *ddl = "create table if not exists chat_account   (Account integer primary key,         \
                                                            Nickname text     NOT NULL,           \
                                                            Password text     NOT NULL,           \
                                                            Question text,                        \
                                                            Answer   text                        );";

    if (creat_table(pdb, ddl) == SQLITE_OK)
    {
        return SQLITE_OK;
    }

    debug_msg("creat table account error\n");
    return -1;
}

/* Execute the DDL statement(s) in sql on pdb.
 * Returns SQLITE_OK on success, -1 on failure. SQLite reports errors
 * through errmsg, not errno: that message is printed and then freed
 * (it previously leaked and perror showed an unrelated errno). */
int creat_table(sqlite3 *pdb, char *sql)
{
    char *errmsg = NULL;

    if (sqlite3_exec(pdb, sql, NULL, NULL, &errmsg) != SQLITE_OK)
    {
        fprintf(stderr, "creat table error: %s\n", errmsg ? errmsg : "unknown error");
        sqlite3_free(errmsg);
        return -1;
    }

    return SQLITE_OK;
}


#include "chat_s.h"

void chat_start_s(void *arg)
{
    int bytes_read;

    struct msg_buff msg;
    struct arg args = *((struct arg *)arg);
    struct sockaddr_in c_addr;
    struct online *node;
    ThreadPool *pool;

    socklen_t c_len = sizeof(struct sockaddr_in);

    int ufd = args.ufd;
    sqlite3 *pdb = args.pdb;
    pool = args.pool;
    struct send_client cli = args.cli;

    while (1)
    {
        memset(&msg, 0, sizeof(msg));
        bytes_read = recvfrom(ufd, &msg, sizeof(msg), 0, (struct sockaddr *)&c_addr, &c_len);
        if (bytes_read < 0)
        {
            perror("read error");
        }
        if (msg.chat_mode == -1)
        {
            cli.c_addr = c_addr;
            cli.account = msg.my_account;
            creat_node(&node);
            node->c_addr = c_addr;
            node->account = msg.my_account;
            search_nickname(node, pdb);
            strcpy(cli.nickname, node->nickname);
            insert_node(cli.head, node);
            struct online *p;
            p = cli.head->next;
        }
        else
        {
            switch (msg.chat_mode)
            {
            case STOA:
                say_to_one(ufd, &cli, &msg);
            case STOO:
                say_to_all(ufd, &cli, &msg);
                break;
            case EXIT:
                delete_node(&cli);
                return;
            default:
                break;
            }
        }
    }
    return;
}

void say_to_all(int confd, struct send_client *cli, struct msg_buff *msg)
{
    char name[NICKNAME_SIZE];

    struct online *p;
    p = cli->head->next;
    while (p)
    {
        // 遍历链表
        if (p->account != msg->my_account)
            sendto(confd, msg, sizeof(struct msg_buff), 0, (struct sockaddr *)&p->c_addr, sizeof(struct sockaddr_in));

        p = p->next;
    }
}

/* Forward msg to the online client whose account is msg->account.
 * If that account is not online the message is dropped. Previously the
 * loop ended with p == NULL and &p->c_addr was dereferenced. */
void say_to_one(int confd, struct send_client *cli, struct msg_buff *msg)
{
    struct online *p = (cli != NULL && cli->head != NULL) ? cli->head->next : NULL;

    while (p != NULL && p->account != msg->account)
        p = p->next;

    if (p == NULL)
    {
        debug_msg("say_to_one: account %d is not online\n", msg->account);
        return;
    }
    sendto(confd, msg, sizeof(struct msg_buff), 0, (struct sockaddr *)&p->c_addr, sizeof(struct sockaddr_in));
}


/* Look up the nickname of node->account in chat_account and store it in
 * node->nickname (truncated to fit). The nickname is left empty if the
 * query fails or the account does not exist. The result table is now
 * freed on every path; it leaked when no row matched. */
void search_nickname(struct online *node, sqlite3 *pdb)
{
    char sql[SQL_SIZE];
    char **result = NULL;
    char *errmsg = NULL;
    int row = 0;
    int col = 0;

    node->nickname[0] = '\0';

    snprintf(sql, sizeof(sql), "select Nickname from chat_account where Account = %d;", node->account);
    if (SQLITE_OK != sqlite3_get_table(pdb, sql, &result, &row, &col, &errmsg))
    {
        debug_msg("chat_s %d  %s\n", __LINE__, errmsg ? errmsg : "");
        sqlite3_free(errmsg);
        sqlite3_free_table(result);
        return;
    }

    /* result[0] is the column name; result[1] the first row's value. */
    if (0 == row || result[1] == NULL)
    {
        debug_msg("no account\n");
    }
    else
    {
        snprintf(node->nickname, sizeof(node->nickname), "%s", result[1]);
        printf("node nickname:%s\n", node->nickname);
    }

    sqlite3_free_table(result);
}
#include "thread_pool.h"

#define NUMBER 2

ThreadPool *threadPoolCreate(int min, int max, int queueSize)
{
    ThreadPool *pool = (ThreadPool *)malloc(sizeof(ThreadPool));
    do
    {
        if (pool == NULL)
        {
            printf("malloc threadpool fail!\n");
            break;
        }

        pool->threadIDS = (pthread_t *)malloc(sizeof(pthread_t) * max);
        if (pool->threadIDS == NULL)
        {
            printf("malloc threadIDs fail!\n");
            break;
        }

        memset(pool->threadIDS, 0, sizeof(pthread_t) * max);
        pool->minNum = min;
        pool->maxNum = max;
        pool->busyNum = 0;
        pool->liveNum = min; // 和最小个数相等
        pool->exitNum = 0;

        if (pthread_mutex_init(&pool->mutexBusy, NULL) != 0 ||
            pthread_mutex_init(&pool->mutexBusy, NULL) != 0 ||
            pthread_cond_init(&pool->notEmpty, NULL) != 0 ||
            pthread_cond_init(&pool->notFull, NULL) != 0)
        {
            printf("mutex or condition init fail!\n");
            break;
        }

        // 任务队列
        pool->taskQ = (Task *)malloc(sizeof(Task) * queueSize);
        pool->queueCapacity = queueSize;
        pool->queueSize = 0;
        pool->queueFront = 0;
        pool->queueRear = 0;

        pool->shutdown = 0;

        // 创建线程
        pthread_create(&pool->managerID, NULL, manager, pool);

        for (int i = 0; i < min; ++i)
        {
            pthread_create(&pool->threadIDS[i], NULL, worker, pool);
        }

        return pool;

    } while (0); // 目的是使用break关键字

    // 释放资源
    if (pool && pool->threadIDS)
        free(pool->threadIDS);
    if (pool &&pool->taskQ)
        free(pool->taskQ);
    if (pool)
        free(pool);

    return NULL;
}

/* Worker thread: repeatedly take a task from the queue and run it.
 * Leaves through threadExit when the pool shuts down, or when the manager
 * has asked idle threads to exit and more than minNum threads are alive. */
void *worker(void *arg)
{
    ThreadPool *pool = (ThreadPool *)arg;
    while (1)
    {
        pthread_mutex_lock(&pool->mutexPool);
        // Wait while the task queue is empty
        while (pool->queueSize == 0 && !pool->shutdown)
        {
            // Block until a task is queued (or this thread is woken to exit)
            pthread_cond_wait(&pool->notEmpty, &pool->mutexPool);

            // Has the manager asked idle threads to exit?
            if (pool->exitNum > 0)
            {
                pool->exitNum--;
                if (pool->liveNum > pool->minNum)
                {
                    pool->liveNum--;
                    pthread_mutex_unlock(&pool->mutexPool);
                    threadExit(pool);
                }
            }
        }

        // The pool is being destroyed: leave without taking a task
        if (pool->shutdown)
        {
            pthread_mutex_unlock(&pool->mutexPool);
            threadExit(pool);
        }
        // Take one task from the front of the ring buffer
        Task task;
        task.function = pool->taskQ[pool->queueFront].function;
        task.arg = pool->taskQ[pool->queueFront].arg;

        pool->queueFront = (pool->queueFront + 1) % pool->queueCapacity;
        pool->queueSize--;

        // A slot is free again: wake a blocked producer, then release the lock
        pthread_cond_signal(&pool->notFull);
        pthread_mutex_unlock(&pool->mutexPool);

        /* pthread_t is opaque; it is printed as unsigned long (its type on
         * Linux) with a matching %lu — the old %ld did not match. */
        printf("thread %lu start working....\n", (unsigned long)pthread_self());
        pthread_mutex_lock(&pool->mutexBusy);
        pool->busyNum++;
        pthread_mutex_unlock(&pool->mutexBusy);

        task.function(task.arg);
        // task.arg is not freed here: ownership stays with the caller of threadPoolAdd.

        printf("thread %lu end working....\n", (unsigned long)pthread_self());
        pthread_mutex_lock(&pool->mutexBusy);
        pool->busyNum--;
        pthread_mutex_unlock(&pool->mutexBusy);
    }
    return NULL;
}

/* Manager thread: every 3 seconds grow or shrink the worker set.
 * Adds up to NUMBER workers when more tasks are queued than threads are
 * alive, and asks up to NUMBER idle workers to exit when fewer than half of
 * the live threads are busy. Stops once shutdown is set; shutdown is now
 * checked after the sleep, so no extra round runs after destruction begins. */
void *manager(void *arg)
{
    ThreadPool *pool = (ThreadPool *)arg;

    for (;;)
    {
        sleep(3);

        // Snapshot queue length and live thread count (and the shutdown flag)
        pthread_mutex_lock(&pool->mutexPool);
        if (pool->shutdown)
        {
            pthread_mutex_unlock(&pool->mutexPool);
            break;
        }
        int queueSize = pool->queueSize;
        int liveNum = pool->liveNum;
        pthread_mutex_unlock(&pool->mutexPool);

        // Number of busy threads
        pthread_mutex_lock(&pool->mutexBusy);
        int busyNum = pool->busyNum;
        pthread_mutex_unlock(&pool->mutexBusy);

        // Grow: more tasks than live threads, and still below maxNum
        if (queueSize > liveNum && liveNum < pool->maxNum)
        {
            pthread_mutex_lock(&pool->mutexPool);
            int counter = 0;
            for (int i = 0; i < pool->maxNum && counter < NUMBER && pool->liveNum < pool->maxNum; ++i)
            {
                if (pool->threadIDS[i] == 0)
                {
                    if (pthread_create(&pool->threadIDS[i], NULL, worker, pool) != 0)
                    {
                        pool->threadIDS[i] = 0; /* slot stays free; try next round */
                        break;
                    }
                    counter++;
                    pool->liveNum++;
                }
            }
            pthread_mutex_unlock(&pool->mutexPool);
        }

        // Shrink: busy threads * 2 < live threads, and still above minNum
        if (busyNum * 2 < liveNum && liveNum > pool->minNum)
        {
            /* Workers read and decrement exitNum under mutexPool, so it must
             * be written under the same lock (it used to take mutexBusy). */
            pthread_mutex_lock(&pool->mutexPool);
            pool->exitNum = NUMBER;
            pthread_mutex_unlock(&pool->mutexPool);

            // Wake idle workers so they can exit
            for (int i = 0; i < NUMBER; ++i)
            {
                pthread_cond_signal(&pool->notEmpty);
            }
        }
    }

    return NULL;
}

/* Terminate the calling worker: clear its slot in threadIDS and exit.
 * The slot table is scanned under mutexPool (the manager writes it under
 * that lock) and compared with pthread_equal. Callers must not hold
 * mutexPool. While the pool is not shutting down, nobody will join this
 * thread, so it detaches itself to free its resources; during shutdown
 * (pool->shutdown set) it is left joinable. */
void threadExit(ThreadPool *pool)
{
    pthread_t tid = pthread_self();

    pthread_mutex_lock(&pool->mutexPool);
    for (int i = 0; i < pool->maxNum; ++i)
    {
        /* 0 marks a free slot; only compare real thread IDs. */
        if (pool->threadIDS[i] != 0 && pthread_equal(pool->threadIDS[i], tid))
        {
            pool->threadIDS[i] = 0;
            printf("threadExit() called,%lu exiting...\n", (unsigned long)tid);
            break;
        }
    }
    if (!pool->shutdown)
        pthread_detach(tid);
    pthread_mutex_unlock(&pool->mutexPool);

    pthread_exit(NULL);
}

/* Queue func(arg) for execution on a worker thread.
 * Blocks while the queue is full; silently drops the task if the pool is
 * shutting down. Only the arg pointer is stored, not a copy of what it
 * points to. */
void threadPoolAdd(ThreadPool *pool, void (*func)(void *), void *arg)
{
    pthread_mutex_lock(&pool->mutexPool);

    /* Producer waits for a free slot unless the pool is going away. */
    while (!pool->shutdown && pool->queueSize == pool->queueCapacity)
        pthread_cond_wait(&pool->notFull, &pool->mutexPool);

    if (!pool->shutdown)
    {
        Task *slot = &pool->taskQ[pool->queueRear];

        slot->function = func;
        slot->arg = arg;
        pool->queueRear = (pool->queueRear + 1) % pool->queueCapacity;
        pool->queueSize += 1;

        pthread_cond_signal(&pool->notEmpty);
    }

    pthread_mutex_unlock(&pool->mutexPool);
}

/* Number of workers currently running a task (read under mutexBusy). */
int threadPoolBusyNum(ThreadPool *pool)
{
    int busy;

    pthread_mutex_lock(&pool->mutexBusy);
    busy = pool->busyNum;
    pthread_mutex_unlock(&pool->mutexBusy);

    return busy;
}

/* Number of live worker threads (read under mutexPool). */
int threadPoolAliveNum(ThreadPool *pool)
{
    int alive;

    pthread_mutex_lock(&pool->mutexPool);
    alive = pool->liveNum;
    pthread_mutex_unlock(&pool->mutexPool);

    return alive;
}

/* Shut the pool down and free it.
 * Sets shutdown and waits for the manager. It then wakes every idle worker
 * (and any producer blocked on a full queue) and joins the workers. Only
 * after that are the queue, the ID table and the synchronisation objects
 * released. Previously everything was freed while workers could still be
 * running, and no value was returned.
 * NOTE: a worker exits only after its current task returns, so this call
 * blocks for as long as any running task does.
 * Returns 0 on success, -1 if pool is NULL or the worker list could not be
 * allocated (the pool is then left allocated rather than freed under
 * running threads). */
int threadPoolDestory(ThreadPool *pool)
{
    pthread_t *workers;
    int nworkers = 0;

    if (pool == NULL)
    {
        return -1;
    }

    pthread_mutex_lock(&pool->mutexPool);
    pool->shutdown = 1;
    pthread_mutex_unlock(&pool->mutexPool);

    // The manager notices shutdown after its current sleep
    pthread_join(pool->managerID, NULL);

    workers = malloc(sizeof(*workers) * (size_t)pool->maxNum);

    pthread_mutex_lock(&pool->mutexPool);
    if (workers != NULL)
    {
        // Snapshot the live workers; the manager can no longer add any
        for (int i = 0; i < pool->maxNum; ++i)
        {
            if (pool->threadIDS[i] != 0)
                workers[nworkers++] = pool->threadIDS[i];
        }
    }
    /* Broadcast under the lock so no idle worker can miss the wake-up. */
    pthread_cond_broadcast(&pool->notEmpty);
    pthread_cond_broadcast(&pool->notFull);
    pthread_mutex_unlock(&pool->mutexPool);

    if (workers == NULL)
    {
        return -1;
    }

    for (int i = 0; i < nworkers; ++i)
    {
        pthread_join(workers[i], NULL);
    }
    free(workers);

    // No thread references the pool any more: release everything
    free(pool->taskQ);
    free(pool->threadIDS);

    pthread_mutex_destroy(&pool->mutexPool);
    pthread_mutex_destroy(&pool->mutexBusy);
    pthread_cond_destroy(&pool->notEmpty);
    pthread_cond_destroy(&pool->notFull);
    free(pool);

    return 0;
}
#ifndef THREAD_POOL_H
#define THREAD_POOL_H
/* Project-specific include guard. Names starting with an underscore and an
 * upper-case letter are reserved. _PTHREAD_H_ in particular is the guard
 * of <pthread.h> itself on some C libraries (e.g. FreeBSD). chat_s.h
 * includes <pthread.h> before this header, so on such a system this
 * header's contents would silently vanish. */

#include <pthread.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

// A queued unit of work: function(arg) runs on a worker thread.
typedef struct Task
{
    void (*function)(void *arg);
    void *arg;
} Task;

// Thread pool. Workers access the task queue, liveNum, exitNum and
// shutdown under mutexPool; busyNum is protected by mutexBusy.
struct ThreadPool
{
    // Task queue (ring buffer)
    Task *taskQ;
    int queueCapacity; // capacity
    int queueSize;     // number of queued tasks
    int queueFront;    // head index: tasks are taken from here
    int queueRear;     // tail index: tasks are added here

    pthread_t managerID;  // manager thread ID
    pthread_t *threadIDS; // worker thread IDs (0 marks a free slot)
    int minNum;           // minimum number of workers
    int maxNum;           // maximum number of workers
    int busyNum;          // workers currently running a task
    int liveNum;          // live workers
    int exitNum;          // workers asked to exit

    pthread_mutex_t mutexPool; // locks the pool state
    pthread_mutex_t mutexBusy; // locks busyNum

    pthread_cond_t notFull;  // signalled when the queue has a free slot
    pthread_cond_t notEmpty; // signalled when the queue has a task

    int shutdown; // 1 while the pool is being destroyed, 0 otherwise
};

typedef struct ThreadPool ThreadPool;
// Create and initialise a thread pool.
ThreadPool *threadPoolCreate(int min, int max, int queueSize);

// Destroy a thread pool.
int threadPoolDestory(ThreadPool *pool);

// Queue a task on the pool.
void threadPoolAdd(ThreadPool *pool, void (*func)(void *), void *arg);

// Number of workers currently running a task.
int threadPoolBusyNum(ThreadPool *pool);

// Number of live workers.
int threadPoolAliveNum(ThreadPool *pool);

/* ****************************************************************************************************************** */
void *worker(void *arg);
void *manager(void *arg);
void threadExit(ThreadPool *pool);



#endif

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值