tcp 解决short write问题

short write只存在非阻塞模式

什么是缓冲区?

参考动画图解 socket 缓冲区的那些事儿
在这里插入图片描述

什么是情况下会产生short write?

当发送缓冲区的剩余空间大小不足以容纳发送数据大小的时候,此时只会发送部分数据,并且生成错误码EAGAIN,此时就产生了short write现象,此时剩余的数据应该等缓冲区空间足够之后再次发送。

send/write:返回已经发送的字节数,errno为EAGAIN。

在这里插入图片描述

如何解决short write?(针对EPOLL模型 LT模式)

方法一: 将socket设置为阻塞模式。
方法二:维护自己的发送缓冲区,通过EPOLLONESHOTEPOLLOUT事件调用send/write发送数据。

思想:

  1. 封装do_send函数,内部实现一个环形缓冲区,当缓冲区空间不足容纳发送数据的时候返回false。当空间足够的时候返回true,并且注册EPOLLONESHOT|EPOLLOUT事件,NOTE:这里可能在任意线程发送,所以此时需要注意处理事件EPOLLOUT的线程安全性。
  2. 处理EPOLLOUT事件(需要线程同步),从环形缓冲区获取数据并通过send/wtire发送,如果发送成功则释放环形缓冲区空间,如果失败并且错误码为EAGAINEINTER则重试。当环形缓冲区数据为空的时候移除EPOLLOUT事件(避免cpu空转导致cpu浪费)。处理事件之后记得重置EPOLLONESHOT并且恢复EPOLLIN,因为在do_send的时候暂时取消了EPOLLIN

SocketContext.h

#pragma once
#include<mutex>
#include<string.h>
#include<sys/types.h>
#include<sys/socket.h>
#include<sys/epoll.h>

#define MAX_CIRCLE_BUF (10240)//10k
class CircleBuffer{
public:
    CircleBuffer(){
        m_data_size = 0;
        m_data_start_index = 0;
    }

    bool append(char *buffer,int size)
    {
        std::lock_guard<std::mutex> lock(m_mtx);
        if(size + m_data_size > MAX_CIRCLE_BUF)
        {
            return false;
        }

        int data_end_index = (m_data_start_index + m_data_size)%MAX_CIRCLE_BUF;

        if(data_end_index + size <= MAX_CIRCLE_BUF)
        {
            memcpy(m_buf + data_end_index,buffer,size);
            m_data_size += size;
            return true;
        }
        else
        {
            memcpy(m_buf + data_end_index,buffer,MAX_CIRCLE_BUF - data_end_index);
            int rest_len = size - (MAX_CIRCLE_BUF - data_end_index);
            memcpy(m_buf,buffer + (MAX_CIRCLE_BUF - data_end_index),rest_len);
            m_data_size += size;
            return true;
        }
    }

    void *get_buffer(int &size)
    {
        std::lock_guard<std::mutex> lock(m_mtx);
        int tid =::syscall(SYS_gettid);
        if(m_data_start_index + m_data_size <= MAX_CIRCLE_BUF)
        {
            size = m_data_size;
            return m_buf + m_data_start_index;
        }
        else
        {
            size = MAX_CIRCLE_BUF - m_data_start_index;
            return m_buf + m_data_start_index;
        }
        
    }

    void free_buf(int size)
    {
        std::lock_guard<std::mutex> lock(m_mtx);
        if (m_data_size < size)
        {
            throw std::runtime_error("circular buffer error");
        }
        
        m_data_size -= size;
        m_data_start_index = (m_data_start_index + size)%MAX_CIRCLE_BUF;
    }

    bool is_empty()
    {
        std::lock_guard<std::mutex> lock(m_mtx);
        return m_data_size <= 0;
    }
private:
    std::mutex m_mtx;
    char m_buf[MAX_CIRCLE_BUF];
    int m_data_size;
    int m_data_start_index;
};

class SocketContext{
public:
    SocketContext(int pollid,int sockid,bool is_listen_socket=false)
    {
        m_is_listen_socket = is_listen_socket;
        m_epoll_fd = pollid;
        m_sock_fd = sockid;
        addfd();
    }

    virtual ~SocketContext()
    {
        removefd();
        close(m_sock_fd);
        m_epoll_fd = -1;
        m_sock_fd = -1;
    }
    
    bool async_send(char* data,int size)
    {
        bool result = m_buf.append((char*)data,size);
        //可能在任意线程,发送的时候注册EPOLLOUT取消EPOLLIN,所以需要保证EPOLLOUT处理事件的线程安全
        resetOneshot(false,true);
        return result;
    } 

    std::shared_ptr<SocketContext> do_accept()
    {
        struct sockaddr_in client_addr = {0};
        socklen_t addr_len = sizeof(client_addr);
        int client_sock = accept(m_sock_fd,(struct sockaddr *)&client_addr,&addr_len);
        resetOneshot();       
        if(client_sock == -1)
        {
            return nullptr;
        }
        
        return std::shared_ptr<SocketContext>(new SocketContext(m_epoll_fd,client_sock));
    }   

    int do_recv()
    {
        char szBuffer[1024] = "";
        int count = read(m_sock_fd,szBuffer,1024);
        if(count <= 0)
        {
            return count;
        }
        resetOneshot();

        printf("socket[%d] recv data:%s\r\n",m_sock_fd,szBuffer);
        
        //这部分为测试代码            
        std::string strData = "hellow word abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVW0123456789 qqqqqqq";
       
        async_send((char *)strData.data(),strData.size());
        return count;
    }
    
    int do_send()
    {
        //keep thread safe
        std::unique_lock<std::mutex> lk(m_send_mtx,std::try_to_lock);
        if(!lk.owns_lock())
        {
            //do nothing
            return 0;
        }

        int data_size = 0;
        void *pdata = m_buf.get_buffer(data_size);
        if(data_size > 0)
        {
               
            int send_size = send(m_sock_fd,pdata,data_size,0);
            if(send_size > 0)
            {
                m_buf.free_buf(send_size);
            }
            else
            {
                printf("send failed:%d\r\n",errno);
            }
            //重新恢复EPOLLIN事件
            resetOneshot(true,true);
            return  send_size;
        }
        else
        {
        	//重新恢复EPOLLIN事件
            resetOneshot(true,false);
            return  0;
        }
    }

    int epoll_fd()
    {
        return m_epoll_fd;
    }

    int sock_fd()
    {
        return m_sock_fd;
    }
protected:
    void setnonblocking()
    {
        int flag = fcntl(m_sock_fd, F_GETFL);
        flag |= O_NONBLOCK;
        fcntl(m_sock_fd, F_SETFL, flag);
    }
    void addfd()
    {
        setnonblocking(); 
        struct epoll_event event;
        event.data.ptr = this;
        if(m_is_listen_socket)
        {
            event.events = EPOLLIN|EPOLLONESHOT;
        }
        else
        {
            event.events = EPOLLIN|EPOLLONESHOT|EPOLLOUT;
        }
        
        if(-1 == epoll_ctl(m_epoll_fd, EPOLL_CTL_ADD,m_sock_fd,&event))
        {
            printf("addfd error:%s\r\n", strerror(errno)); 
        }
    }

    void removefd()
    {
        if(-1 == epoll_ctl(m_epoll_fd,EPOLL_CTL_DEL,m_sock_fd,NULL))
        {
            printf("removefd error:%s\r\n", strerror(errno)); 
        }
    }

    void resetOneshot(bool in_event = true,bool out_event = true)
    {
        int tid =::syscall(SYS_gettid);
        struct epoll_event event;
        event.data.ptr = this;
        event.events = EPOLLONESHOT;
        if(m_is_listen_socket)
        {
            if(in_event)
            {
                event.events |= EPOLLIN;
            }
            
        }
        else
        {
             if(in_event)
            {
                event.events |= EPOLLIN;
            }

            if(out_event)
            {
                event.events |= EPOLLOUT;
            }
        }
        if(-1 == epoll_ctl(m_epoll_fd, EPOLL_CTL_MOD,m_sock_fd,&event))
        {
            printf("modifyfd error:%s\r\n", strerror(errno)); 
        }
    }
private:
    bool m_is_listen_socket;
    int m_sock_fd;
    int m_epoll_fd;
    std::mutex m_send_mtx;
    CircleBuffer m_buf;
};

main.cpp

#include<stdio.h>
#include<iostream>
#include<sys/socket.h>
#include<fcntl.h>
#include<unistd.h>
#include<sys/epoll.h>
#include<netinet/in.h>
#include<arpa/inet.h>
#include<errno.h>
#include<string.h>
#include<mutex>
#include<thread>
#include <sys/syscall.h>
#include <signal.h>
#include<map>
#include"SockContext.h"


std::map<int,std::shared_ptr<SocketContext> > g_cli_list;
int handle_accepter_event(int thread_id,int epoll_id,int listen_sock)
{
    const int MAX_EVENTS = 10;
    struct epoll_event events[MAX_EVENTS];
    while (1)
    {
        int ret = epoll_wait(epoll_id,events,MAX_EVENTS,1000);
        if(ret == -1)
        {
            if(errno == EINTR)
            {
                continue;
            }
            printf("[%d]epoll_wait error:%s\r\n",thread_id, strerror(errno)); 
            return -1;
        }

        for (size_t i = 0; i < ret; i++)
        {
            SocketContext *pctx = (SocketContext *)events[i].data.ptr;
            //listen_sock
            if (pctx->sock_fd()== listen_sock) 
            {   
                auto client = pctx->do_accept();
                if(client!= nullptr)
                {
                    g_cli_list[client->sock_fd()] = client;
                    printf("socket[%d] connected\r\n",client->sock_fd()); 
                }
                else
                {
                    printf("socket connected error\r\n"); 
                }
            }
            else if(events[i].events & EPOLLIN)
            {
                int nread_count = pctx->do_recv();
                if(nread_count <= 0)
                {
                    printf("[%d]socket[%d] disconnected\r\n",thread_id, pctx->sock_fd());
                    g_cli_list.erase(pctx->sock_fd());
                }
            }
            else if(events[i].events & EPOLLOUT)
            {
                pctx->do_send();
            }
        }
        
    }
}

void handle_signal(int signal)
{
    if(signal == SIGPIPE)
    {
        printf("recv sig pipe\r\n");
    }
}

int main()
{
    signal(SIGPIPE,handle_signal);
    int listen_sock = socket(AF_INET, SOCK_STREAM, 0);

    if(listen_sock == -1)
    {
        printf("socket error:%s\r\n", strerror(errno)); 
        return -1;
    }

    int reuse = 1;
    setsockopt(listen_sock,SOL_SOCKET,SO_REUSEADDR,&reuse,sizeof(reuse));
    struct sockaddr_in ser_addr = {0};
    ser_addr.sin_family = AF_INET;
    ser_addr.sin_port = htons(6360);
    ser_addr.sin_addr.s_addr = INADDR_ANY;
    
    if(-1 == bind(listen_sock, (struct sockaddr *)&ser_addr, sizeof(ser_addr)))
    {
        printf("bind socket error:%s\r\n", strerror(errno)); 
        return -1;
    }

    if(-1 == listen(listen_sock,5))
    {
        printf("listen socket error:%s\r\n", strerror(errno)); 
        return -1;
    }

    int epoll_id = epoll_create(5);
    if(epoll_id == -1)
    {
        printf("epoll_create error:%s\r\n", strerror(errno)); 
        return -1;
    }

    g_cli_list[listen_sock] = std::shared_ptr<SocketContext>(new SocketContext(epoll_id,listen_sock,true));
    std::thread t(handle_accepter_event,1,epoll_id,listen_sock);
    std::thread t2(handle_accepter_event,2,epoll_id,listen_sock);
    t.join();
    t2.join();
   
    
    printf("hellow word\r\n");
    return 0;
}

上述代码,可以完美的解决short write问题,因为环形缓冲区保证了发送数据的完整性。但是上述代码在极端情况下可能导致读不用。因为发送数据的时候取消EPOOLIN事件,在下一次的EPOOLOUT事件中才恢复,如果在极端情况下,可能导致写缓冲区一直不可写,从而不会触发EPOOLOUT,导致EPOOLIN事件没办法恢复。所以需要进行代码优化

发送数据的时候,如果环形缓冲区没有数据,此时直接调用系统的send进行发送。如果send只发送了部分数据,那么将剩下的数据加入缓冲区,通过触发EPOOLOUT来进行数据发送。这个过程需要保证线程安全性,防止数据混乱

bool async_send(char* data,int size)
{
		//需要加锁,防止数据错乱
        std::lock_guard<std::mutex> lock(m_async_send_mtx);
        int send_ok_bytes = 0;
     	
     	//如果缓冲区为空,则说明没有short write导致的空数据,直接进行发送,否则加入缓冲区
        if(m_buf.is_empty())
        {
            send_ok_bytes = send(m_sock_fd, data, size,MSG_DONTWAIT);
            if(send_ok_bytes == size)
            {
                return true;
            }

            (send_ok_bytes >= 0) ? send_ok_bytes : send_ok_bytes = 0;
        }
        bool result = m_buf.append((char*)data + send_ok_bytes,size - send_ok_bytes);
        resetOneshot(false,true);
        return result;
} 

这样可以很大程度缓解上述所说的极端情况,但是不能100%解决。如果想要100%解决那么就不能取消EPOOLIN,此时就需要保证EPOOLIN的线程安全性,此时可以参考EPOOLOUT的加锁方式,处理数据安全性。

int do_recv()
    {
         //keep thread safe
        std::unique_lock<std::mutex> lk(m_dorecv_mtx,std::try_to_lock);
        if(!lk.owns_lock())
        {
            //这里可能会疑问为什么不是返回0,因为read返回的0表示socket断开链接了,为了避免外部误认为链接断开
            //由于这里仅仅是demo,外部没有统计接收字节数,所以此时返回1是没有问题
            //为了更你好的使用,建议修改一下原型
            return 1;
        }

        char szBuffer[1024] = "";
        int count = read(m_sock_fd,szBuffer,1024);
        if(count <= 0)
        {
            return count;
        }
        resetOneshot();

        printf("socket[%d] recv data:%s\r\n",m_sock_fd,szBuffer);
                    
        std::string strData = "hellow word abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVW0123456789 qqqqqqq";
       
       int64_t total_send = 0;
       while (async_send((char *)strData.data(),strData.size()))
       {
           total_send += strData.size();
           printf("send total:%lld  data[%d]\r\n",total_send,strData.size());
       }
       
        
        return count;
    }

 bool async_send(char* data,int size)
    {
    	...
		//此时这里修改为true
        resetOneshot(true,true);
        return result;
    } 

此时关于short write产生的问题就可以完美解决了。但是此时还有另外一个问题,环形缓冲区的大小确立。因为在服务器上使用,如果太大当连接数上来之后内存就爆了。一般不建议太大,因为当发送数据阻塞的时候,端收到的数据也将是滞后的数据,可能意义并不是很大。一般建议平均缓存5-10个包的大小即可。但是也可以使用动态内存,事先不预分配。关于动态内存此处不在本文讨论范围。

  • 2
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值