早在19年5月就在某站上看到sylar的视频了,一直认为这是一个非常不错的视频。
由于本人一直是自学编程,基础不扎实,也没有任何人的督促,没能坚持下去。
每每想起倍感惋惜,遂提笔再续前缘。
为了能更好的看懂sylar,本套笔记会分两步走,每个系统都会分为两篇博客。
分别是【知识储备篇】和【代码分析篇】
(ps:纯粹做笔记的形式给自己记录下,欢迎大家评论,不足之处请多多赐教)
QQ交流群:957100923
B站视频:https://b23.tv/YusP39I
Socket模块-代码分析篇
一、说在前面
由前文 【知识储备篇】 可以知道对于原生 Socket的封装的必要性。那么这一篇就来说一下如何进行原生 Socket 的封装。聊聊他的思路和具体代码实现。
二、Socket 构造需要什么
创建一个socket离不开这三个点:
1.协议簇 : UDP、TCP
2.互联网协议 : IPV4、IPV6
3.socket的类型 : 流套接字(SOCK_STREAM)、数据包套接字(SOCK_DGRAM)、原始套接字(SOCK_RAW)
socket.h
/**
* @brief Socket构造函数
* @param[in] family 协议簇
* @param[in] type 类型
* @param[in] protocol 协议
*/
Socket(int family, int type, int protocol = 0);
socket.cc
Socket::Socket(int family, int type, int protocol)
: m_sock(-1)
, m_family(family)
, m_type(type)
, m_protocol(protocol)
, m_isConnected(false) {
}
三、对Socket创建而言我们要如何封装?
我们在封装原生 Socket 之前一定要先明确有哪些类别的Socket我们需要封装进来
这里我列了一下,分别是以下几个的组合:
方法名称 | 协议簇【网络传输协议】 | 互联网协议 | 备注 |
---|---|---|---|
CreateUDPSocket | UPD | IPV4 | 创建一个IPV4的UDP socket |
CreateTCPSocket | TCP | IPV4 | 创建一个IPV4的TCP socket |
CreateUDPSocket6 | UDP | IPV6 | 创建一个IPV6的UDP socket |
CreateTCPSocket6 | TCP | IPV6 | 创建一个IPV6的TCP socket |
CreateUnixUDPSocket | UDP | Unix | 创建一个Unix本机的UDP socket |
CreateUnixTCPSocket | TCP | Unix | 创建一个Unix本机的TCP socket |
以上就是创建对应网络协议和互联网协议的6种类型的Socket。
由于在 【Address模块篇】 中已经封装了Address,而socket的创建也可以通过Address,所以我们补充两个通过 address 创建socket的方法。
方法名称 | 协议簇【网络传输协议】 | 地址 | 备注 |
---|---|---|---|
CreateTCP | UPD | sylar::Address::ptr | 基于address创建一个UDP socket |
CreateUDP | TCP | sylar::Address::ptr | 基于address创建一个TCP socket |
以下是对应代码:
socket.h
/**
* @brief 创建TCP Socket(满足地址类型)
* @param[in] address 地址
*/
static Socket::ptr CreateTCP(sylar::Address::ptr address);
/**
* @brief 创建UDP Socket(满足地址类型)
* @param[in] address 地址
*/
static Socket::ptr CreateUDP(sylar::Address::ptr address);
/**
* @brief 创建IPv4的TCP Socket
*/
static Socket::ptr CreateTCPSocket();
/**
* @brief 创建IPv4的UDP Socket
*/
static Socket::ptr CreateUDPSocket();
/**
* @brief 创建IPv6的TCP Socket
*/
static Socket::ptr CreateTCPSocket6();
/**
* @brief 创建IPv6的UDP Socket
*/
static Socket::ptr CreateUDPSocket6();
/**
* @brief 创建Unix的TCP Socket
*/
static Socket::ptr CreateUnixTCPSocket();
/**
* @brief 创建Unix的UDP Socket
*/
static Socket::ptr CreateUnixUDPSocket();
socket.cc
Socket::ptr Socket::CreateTCP(sylar::Address::ptr address) {
Socket::ptr sock(new Socket(address->getFamily(), TCP, 0));
return sock;
}
Socket::ptr Socket::CreateUDP(sylar::Address::ptr address) {
Socket::ptr sock(new Socket(address->getFamily(), UDP, 0));
sock->newSock();
sock->m_isConnected = true;
return sock;
}
Socket::ptr Socket::CreateTCPSocket() {
Socket::ptr sock(new Socket(IPv4, TCP, 0));
return sock;
}
Socket::ptr Socket::CreateUDPSocket() {
Socket::ptr sock(new Socket(IPv4, UDP, 0));
sock->newSock();
sock->m_isConnected = true;
return sock;
}
Socket::ptr Socket::CreateTCPSocket6() {
Socket::ptr sock(new Socket(IPv6, TCP, 0));
return sock;
}
Socket::ptr Socket::CreateUDPSocket6() {
Socket::ptr sock(new Socket(IPv6, UDP, 0));
sock->newSock();
sock->m_isConnected = true;
return sock;
}
Socket::ptr Socket::CreateUnixTCPSocket() {
Socket::ptr sock(new Socket(UNIX, TCP, 0));
return sock;
}
Socket::ptr Socket::CreateUnixUDPSocket() {
Socket::ptr sock(new Socket(UNIX, UDP, 0));
return sock;
}
四、对于服务端而言我们要封装什么方法?
1.bind方法的封装:
作为服务端,我们创建一个socket后需要绑定一个地址来进行监听,那么我们需要封装绑定地址的方法。
socket.h
/**
* @brief 绑定地址
* @param[in] addr 地址
* @return 是否绑定成功
*/
virtual bool bind(const Address::ptr addr);
socket.cc
bool Socket::bind(const Address::ptr addr) {
m_localAddress = addr;
if (!isValid()) {
newSock();
if (SYLAR_UNLIKELY(!isValid())) {
return false;
}
}
if (SYLAR_UNLIKELY(addr->getFamily() != m_family)) {
SYLAR_LOG_ERROR(g_logger) << "bind sock.family("
<< m_family << ") addr.family(" << addr->getFamily()
<< ") not equal, addr=" << addr->toString();
return false;
}
UnixAddress::ptr uaddr = std::dynamic_pointer_cast<UnixAddress>(addr);
if (uaddr) {
Socket::ptr sock = Socket::CreateUnixTCPSocket();
if (sock->connect(uaddr)) {
return false;
} else {
sylar::FSUtil::Unlink(uaddr->getPath(), true);
}
}
if (::bind(m_sock, addr->getAddr(), addr->getAddrLen())) {
SYLAR_LOG_ERROR(g_logger) << "bind error errrno=" << errno
<< " errstr=" << strerror(errno);
return false;
}
getLocalAddress();
return true;
}
2.listen方法的封装:
将socket放入监听队列中
socket.h
/**
* @brief 监听socket
* @param[in] backlog 未完成连接队列的最大长度
* @result 返回监听是否成功
* @pre 必须先 bind 成功
*/
virtual bool listen(int backlog = SOMAXCONN);
socket.cc
bool Socket::listen(int backlog) {
if (!isValid()) {
SYLAR_LOG_ERROR(g_logger) << "listen error sock=-1";
return false;
}
if (::listen(m_sock, backlog)) {
SYLAR_LOG_ERROR(g_logger) << "listen error errno=" << errno
<< " errstr=" << strerror(errno);
return false;
}
return true;
}
3.accept方法的封装:
监听是否有客户端链接请求
socket.h
/**
* @brief 接收connect链接
* @return 成功返回新连接的socket,失败返回nullptr
* @pre Socket必须 bind , listen 成功
*/
virtual Socket::ptr accept();
socket.cc
Socket::ptr Socket::accept() {
Socket::ptr sock(new Socket(m_family, m_type, m_protocol));
int newsock = ::accept(m_sock, nullptr, nullptr);
if (newsock == -1) {
SYLAR_LOG_ERROR(g_logger) << "accept(" << m_sock << ") errno="
<< errno << " errstr=" << strerror(errno);
return nullptr;
}
if (sock->init(newsock)) {
return sock;
}
return nullptr;
}
五、对于客户端而言我们要封装什么?
1.connect方法的封装:
客户端需要主动向服务端发起链接请求(此方法会对应服务端的accept方法)
socket.h
/**
* @brief 连接地址
* @param[in] addr 目标地址
* @param[in] timeout_ms 超时时间(毫秒)
*/
virtual bool connect(const Address::ptr addr, uint64_t timeout_ms = -1);
socket.cc
bool Socket::connect(const Address::ptr addr, uint64_t timeout_ms) {
m_remoteAddress = addr;
if (!isValid()) {
newSock();
if (SYLAR_UNLIKELY(!isValid())) {
return false;
}
}
if (SYLAR_UNLIKELY(addr->getFamily() != m_family)) {
SYLAR_LOG_ERROR(g_logger) << "connect sock.family("
<< m_family << ") addr.family(" << addr->getFamily()
<< ") not equal, addr=" << addr->toString();
return false;
}
if (timeout_ms == (uint64_t)-1) {
if (::connect(m_sock, addr->getAddr(), addr->getAddrLen())) {
SYLAR_LOG_ERROR(g_logger) << "sock=" << m_sock << " connect(" << addr->toString()
<< ") error errno=" << errno << " errstr=" << strerror(errno);
close();
return false;
}
} else {
if (::connect_with_timeout(m_sock, addr->getAddr(), addr->getAddrLen(), timeout_ms)) {
SYLAR_LOG_ERROR(g_logger) << "sock=" << m_sock << " connect(" << addr->toString()
<< ") timeout=" << timeout_ms << " error errno="
<< errno << " errstr=" << strerror(errno);
close();
return false;
}
}
m_isConnected = true;
getRemoteAddress();
getLocalAddress();
return true;
}
六、客户端与服务端 Socket 都需要考虑的封装有哪些?
1.send方法:
由于情况比较多,我们可以参考以下表格
方法 | 数据类型 | 是否指定地址 |
---|---|---|
virtual int send(const void *buffer, size_t length, int flags = 0); | 普通数据 | 否 |
virtual int send(const iovec *buffers, size_t length, int flags = 0); | 普通数据 | 是 |
virtual int sendTo(const void *buffer, size_t length, const Address::ptr to, int flags = 0); | IO数据 | 否 |
virtual int sendTo(const iovec *buffers, size_t length, const Address::ptr to, int flags = 0); | IO数据 | 是 |
socket.h
/**
* @brief 发送数据
* @param[in] buffer 待发送数据的内存
* @param[in] length 待发送数据的长度
* @param[in] flags 标志字
* @return
* @retval >0 发送成功对应大小的数据
* @retval =0 socket被关闭
* @retval <0 socket出错
*/
virtual int send(const void *buffer, size_t length, int flags = 0);
/**
* @brief 发送数据
* @param[in] buffers 待发送数据的内存(iovec数组)
* @param[in] length 待发送数据的长度(iovec长度)
* @param[in] flags 标志字
* @return
* @retval >0 发送成功对应大小的数据
* @retval =0 socket被关闭
* @retval <0 socket出错
*/
virtual int send(const iovec *buffers, size_t length, int flags = 0);
/**
* @brief 发送数据
* @param[in] buffer 待发送数据的内存
* @param[in] length 待发送数据的长度
* @param[in] to 发送的目标地址
* @param[in] flags 标志字
* @return
* @retval >0 发送成功对应大小的数据
* @retval =0 socket被关闭
* @retval <0 socket出错
*/
virtual int sendTo(const void *buffer, size_t length, const Address::ptr to, int flags = 0);
/**
* @brief 发送数据
* @param[in] buffers 待发送数据的内存(iovec数组)
* @param[in] length 待发送数据的长度(iovec长度)
* @param[in] to 发送的目标地址
* @param[in] flags 标志字
* @return
* @retval >0 发送成功对应大小的数据
* @retval =0 socket被关闭
* @retval <0 socket出错
*/
virtual int sendTo(const iovec *buffers, size_t length, const Address::ptr to, int flags = 0);
socket.cc
int Socket::send(const void *buffer, size_t length, int flags) {
if (isConnected()) {
return ::send(m_sock, buffer, length, flags);
}
return -1;
}
int Socket::send(const iovec *buffers, size_t length, int flags) {
if (isConnected()) {
msghdr msg;
memset(&msg, 0, sizeof(msg));
msg.msg_iov = (iovec *)buffers;
msg.msg_iovlen = length;
return ::sendmsg(m_sock, &msg, flags);
}
return -1;
}
int Socket::sendTo(const void *buffer, size_t length, const Address::ptr to, int flags) {
if (isConnected()) {
return ::sendto(m_sock, buffer, length, flags, to->getAddr(), to->getAddrLen());
}
return -1;
}
int Socket::sendTo(const iovec *buffers, size_t length, const Address::ptr to, int flags) {
if (isConnected()) {
msghdr msg;
memset(&msg, 0, sizeof(msg));
msg.msg_iov = (iovec *)buffers;
msg.msg_iovlen = length;
msg.msg_name = to->getAddr();
msg.msg_namelen = to->getAddrLen();
return ::sendmsg(m_sock, &msg, flags);
}
return -1;
}
2.recv方法:
对应send方法,recv也有四个
方法 | 数据类型 | 是否指定地址 |
---|---|---|
virtual int recv(void *buffer, size_t length, int flags = 0); | 普通数据 | 否 |
virtual int recv(iovec *buffers, size_t length, int flags = 0); | 普通数据 | 是 |
virtual int recvFrom(void *buffer, size_t length, Address::ptr from, int flags = 0); | IO数据 | 否 |
virtual int recvFrom(iovec *buffers, size_t length, Address::ptr from, int flags = 0); | IO数据 | 是 |
socket.h
/**
*@brief 接受数据
* @param[out] buffer 接收数据的内存
* @param[in] length 接收数据的内存大小
* @param[in] flags 标志字
* @return
* @retval >0 接收到对应大小的数据
* @retval =0 socket被关闭
* @retval <0 socket出错
*/
virtual int recv(void *buffer, size_t length, int flags = 0);
/**
* @brief 接受数据
* @param[out] buffers 接收数据的内存(iovec数组)
* @param[in] length 接收数据的内存大小(iovec数组长度)
* @param[in] flags 标志字
* @return
* @retval >0 接收到对应大小的数据
* @retval =0 socket被关闭
* @retval <0 socket出错
*/
virtual int recv(iovec *buffers, size_t length, int flags = 0);
/**
* @brief 接受数据
* @param[out] buffer 接收数据的内存
* @param[in] length 接收数据的内存大小
* @param[out] from 发送端地址
* @param[in] flags 标志字
* @return
* @retval >0 接收到对应大小的数据
* @retval =0 socket被关闭
* @retval <0 socket出错
*/
virtual int recvFrom(void *buffer, size_t length, Address::ptr from, int flags = 0);
/**
* @brief 接受数据
* @param[out] buffers 接收数据的内存(iovec数组)
* @param[in] length 接收数据的内存大小(iovec数组长度)
* @param[out] from 发送端地址
* @param[in] flags 标志字
* @return
* @retval >0 接收到对应大小的数据
* @retval =0 socket被关闭
* @retval <0 socket出错
*/
virtual int recvFrom(iovec *buffers, size_t length, Address::ptr from, int flags = 0);
socket.cc
int Socket::recv(void *buffer, size_t length, int flags) {
if (isConnected()) {
return ::recv(m_sock, buffer, length, flags);
}
return -1;
}
int Socket::recv(iovec *buffers, size_t length, int flags) {
if (isConnected()) {
msghdr msg;
memset(&msg, 0, sizeof(msg));
msg.msg_iov = (iovec *)buffers;
msg.msg_iovlen = length;
return ::recvmsg(m_sock, &msg, flags);
}
return -1;
}
int Socket::recvFrom(void *buffer, size_t length, Address::ptr from, int flags) {
if (isConnected()) {
socklen_t len = from->getAddrLen();
return ::recvfrom(m_sock, buffer, length, flags, from->getAddr(), &len);
}
return -1;
}
int Socket::recvFrom(iovec *buffers, size_t length, Address::ptr from, int flags) {
if (isConnected()) {
msghdr msg;
memset(&msg, 0, sizeof(msg));
msg.msg_iov = (iovec *)buffers;
msg.msg_iovlen = length;
msg.msg_name = from->getAddr();
msg.msg_namelen = from->getAddrLen();
return ::recvmsg(m_sock, &msg, flags);
}
return -1;
}
3.close方法:
这个方法是必然需要的
socket.h
/**
* @brief 关闭socket
*/
virtual bool close();
socket.cc
bool Socket::close() {
if (!m_isConnected && m_sock == -1) {
return true;
}
m_isConnected = false;
if (m_sock != -1) {
::close(m_sock);
m_sock = -1;
}
return false;
}
至此,最主要的方法已经全部列出了,其他就是对于Socket封装的补充方法和具体实现时提炼的方法。
那么这些方法在下面的全部代码中会有列出。
七、完整代码
socket.h
#ifndef __SYLAR_SOCKET_H__
#define __SYLAR_SOCKET_H__
#include <memory>
#include <netinet/tcp.h>
#include <sys/types.h>
#include <sys/socket.h>
#include "address.h"
#include "noncopyable.h"
namespace sylar {
/**
* @brief Socket封装类
*/
class Socket : public std::enable_shared_from_this<Socket>, Noncopyable {
public:
typedef std::shared_ptr<Socket> ptr;
typedef std::weak_ptr<Socket> weak_ptr;
/**
* @brief Socket类型
*/
enum Type {
/// TCP类型
TCP = SOCK_STREAM,
/// UDP类型
UDP = SOCK_DGRAM
};
/**
* @brief Socket协议簇
*/
enum Family {
/// IPv4 socket
IPv4 = AF_INET,
/// IPv6 socket
IPv6 = AF_INET6,
/// Unix socket
UNIX = AF_UNIX,
};
/**
* @brief 创建TCP Socket(满足地址类型)
* @param[in] address 地址
*/
static Socket::ptr CreateTCP(sylar::Address::ptr address);
/**
* @brief 创建UDP Socket(满足地址类型)
* @param[in] address 地址
*/
static Socket::ptr CreateUDP(sylar::Address::ptr address);
/**
* @brief 创建IPv4的TCP Socket
*/
static Socket::ptr CreateTCPSocket();
/**
* @brief 创建IPv4的UDP Socket
*/
static Socket::ptr CreateUDPSocket();
/**
* @brief 创建IPv6的TCP Socket
*/
static Socket::ptr CreateTCPSocket6();
/**
* @brief 创建IPv6的UDP Socket
*/
static Socket::ptr CreateUDPSocket6();
/**
* @brief 创建Unix的TCP Socket
*/
static Socket::ptr CreateUnixTCPSocket();
/**
* @brief 创建Unix的UDP Socket
*/
static Socket::ptr CreateUnixUDPSocket();
/**
* @brief Socket构造函数
* @param[in] family 协议簇
* @param[in] type 类型
* @param[in] protocol 协议
*/
Socket(int family, int type, int protocol = 0);
/**
* @brief 析构函数
*/
virtual ~Socket();
/**
* @brief 获取发送超时时间(毫秒)
*/
int64_t getSendTimeout();
/**
* @brief 设置发送超时时间(毫秒)
*/
void setSendTimeout(int64_t v);
/**
* @brief 获取接受超时时间(毫秒)
*/
int64_t getRecvTimeout();
/**
* @brief 设置接受超时时间(毫秒)
*/
void setRecvTimeout(int64_t v);
/**
* @brief 获取sockopt @see getsockopt
*/
bool getOption(int level, int option, void *result, socklen_t *len);
/**
* @brief 获取sockopt模板 @see getsockopt
*/
template <class T>
bool getOption(int level, int option, T &result) {
socklen_t length = sizeof(T);
return getOption(level, option, &result, &length);
}
/**
* @brief 设置sockopt @see setsockopt
*/
bool setOption(int level, int option, const void *result, socklen_t len);
/**
* @brief 设置sockopt模板 @see setsockopt
*/
template <class T>
bool setOption(int level, int option, const T &value) {
return setOption(level, option, &value, sizeof(T));
}
/**
* @brief 接收connect链接
* @return 成功返回新连接的socket,失败返回nullptr
* @pre Socket必须 bind , listen 成功
*/
virtual Socket::ptr accept();
/**
* @brief 绑定地址
* @param[in] addr 地址
* @return 是否绑定成功
*/
virtual bool bind(const Address::ptr addr);
/**
* @brief 连接地址
* @param[in] addr 目标地址
* @param[in] timeout_ms 超时时间(毫秒)
*/
virtual bool connect(const Address::ptr addr, uint64_t timeout_ms = -1);
virtual bool reconnect(uint64_t timeout_ms = -1);
/**
* @brief 监听socket
* @param[in] backlog 未完成连接队列的最大长度
* @result 返回监听是否成功
* @pre 必须先 bind 成功
*/
virtual bool listen(int backlog = SOMAXCONN);
/**
* @brief 关闭socket
*/
virtual bool close();
/**
* @brief 发送数据
* @param[in] buffer 待发送数据的内存
* @param[in] length 待发送数据的长度
* @param[in] flags 标志字
* @return
* @retval >0 发送成功对应大小的数据
* @retval =0 socket被关闭
* @retval <0 socket出错
*/
virtual int send(const void *buffer, size_t length, int flags = 0);
/**
* @brief 发送数据
* @param[in] buffers 待发送数据的内存(iovec数组)
* @param[in] length 待发送数据的长度(iovec长度)
* @param[in] flags 标志字
* @return
* @retval >0 发送成功对应大小的数据
* @retval =0 socket被关闭
* @retval <0 socket出错
*/
virtual int send(const iovec *buffers, size_t length, int flags = 0);
/**
* @brief 发送数据
* @param[in] buffer 待发送数据的内存
* @param[in] length 待发送数据的长度
* @param[in] to 发送的目标地址
* @param[in] flags 标志字
* @return
* @retval >0 发送成功对应大小的数据
* @retval =0 socket被关闭
* @retval <0 socket出错
*/
virtual int sendTo(const void *buffer, size_t length, const Address::ptr to, int flags = 0);
/**
* @brief 发送数据
* @param[in] buffers 待发送数据的内存(iovec数组)
* @param[in] length 待发送数据的长度(iovec长度)
* @param[in] to 发送的目标地址
* @param[in] flags 标志字
* @return
* @retval >0 发送成功对应大小的数据
* @retval =0 socket被关闭
* @retval <0 socket出错
*/
virtual int sendTo(const iovec *buffers, size_t length, const Address::ptr to, int flags = 0);
/**
* @brief 接受数据
* @param[out] buffer 接收数据的内存
* @param[in] length 接收数据的内存大小
* @param[in] flags 标志字
* @return
* @retval >0 接收到对应大小的数据
* @retval =0 socket被关闭
* @retval <0 socket出错
*/
virtual int recv(void *buffer, size_t length, int flags = 0);
/**
* @brief 接受数据
* @param[out] buffers 接收数据的内存(iovec数组)
* @param[in] length 接收数据的内存大小(iovec数组长度)
* @param[in] flags 标志字
* @return
* @retval >0 接收到对应大小的数据
* @retval =0 socket被关闭
* @retval <0 socket出错
*/
virtual int recv(iovec *buffers, size_t length, int flags = 0);
/**
* @brief 接受数据
* @param[out] buffer 接收数据的内存
* @param[in] length 接收数据的内存大小
* @param[out] from 发送端地址
* @param[in] flags 标志字
* @return
* @retval >0 接收到对应大小的数据
* @retval =0 socket被关闭
* @retval <0 socket出错
*/
virtual int recvFrom(void *buffer, size_t length, Address::ptr from, int flags = 0);
/**
* @brief 接受数据
* @param[out] buffers 接收数据的内存(iovec数组)
* @param[in] length 接收数据的内存大小(iovec数组长度)
* @param[out] from 发送端地址
* @param[in] flags 标志字
* @return
* @retval >0 接收到对应大小的数据
* @retval =0 socket被关闭
* @retval <0 socket出错
*/
virtual int recvFrom(iovec *buffers, size_t length, Address::ptr from, int flags = 0);
/**
* @brief 获取远端地址
*/
Address::ptr getRemoteAddress();
/**
* @brief 获取本地地址
*/
Address::ptr getLocalAddress();
/**
* @brief 获取协议簇
*/
int getFamily() const { return m_family; }
/**
* @brief 获取类型
*/
int getType() const { return m_type; }
/**
* @brief 获取协议
*/
int getProtocol() const { return m_protocol; }
/**
* @brief 返回是否连接
*/
bool isConnected() const { return m_isConnected; }
/**
* @brief 是否有效(m_sock != -1)
*/
bool isValid() const;
/**
* @brief 返回Socket错误
*/
int getError();
/**
* @brief 输出信息到流中
*/
virtual std::ostream &dump(std::ostream &os) const;
virtual std::string toString() const;
/**
* @brief 返回socket句柄
*/
int getSocket() const { return m_sock; }
/**
* @brief 取消读
*/
bool cancelRead();
/**
* @brief 取消写
*/
bool cancelWrite();
/**
* @brief 取消accept
*/
bool cancelAccept();
/**
* @brief 取消所有事件
*/
bool cancelAll();
protected:
/**
* @brief 初始化socket
*/
void initSock();
/**
* @brief 创建socket
*/
void newSock();
/**
* @brief 初始化sock
*/
virtual bool init(int sock);
protected:
/// socket句柄
int m_sock;
/// 协议簇
int m_family;
/// 类型
int m_type;
/// 协议
int m_protocol;
/// 是否连接
bool m_isConnected;
/// 本地地址
Address::ptr m_localAddress;
/// 远端地址
Address::ptr m_remoteAddress;
};
/**
* @brief 流式输出socket
* @param[in, out] os 输出流
* @param[in] sock Socket类
*/
std::ostream &operator<<(std::ostream &os, const Socket &sock);
} // namespace sylar
#endif
socket.cc
#include "socket.h"
#include "iomanager.h"
#include "fd_manager.h"
#include "log.h"
#include "macro.h"
#include "hook.h"
#include <limits.h>
namespace sylar {
static sylar::Logger::ptr g_logger = SYLAR_LOG_NAME("system");
Socket::ptr Socket::CreateTCP(sylar::Address::ptr address) {
Socket::ptr sock(new Socket(address->getFamily(), TCP, 0));
return sock;
}
Socket::ptr Socket::CreateUDP(sylar::Address::ptr address) {
Socket::ptr sock(new Socket(address->getFamily(), UDP, 0));
sock->newSock();
sock->m_isConnected = true;
return sock;
}
Socket::ptr Socket::CreateTCPSocket() {
Socket::ptr sock(new Socket(IPv4, TCP, 0));
return sock;
}
Socket::ptr Socket::CreateUDPSocket() {
Socket::ptr sock(new Socket(IPv4, UDP, 0));
sock->newSock();
sock->m_isConnected = true;
return sock;
}
Socket::ptr Socket::CreateTCPSocket6() {
Socket::ptr sock(new Socket(IPv6, TCP, 0));
return sock;
}
Socket::ptr Socket::CreateUDPSocket6() {
Socket::ptr sock(new Socket(IPv6, UDP, 0));
sock->newSock();
sock->m_isConnected = true;
return sock;
}
Socket::ptr Socket::CreateUnixTCPSocket() {
Socket::ptr sock(new Socket(UNIX, TCP, 0));
return sock;
}
Socket::ptr Socket::CreateUnixUDPSocket() {
Socket::ptr sock(new Socket(UNIX, UDP, 0));
return sock;
}
Socket::Socket(int family, int type, int protocol)
: m_sock(-1)
, m_family(family)
, m_type(type)
, m_protocol(protocol)
, m_isConnected(false) {
}
Socket::~Socket() {
close();
}
int64_t Socket::getSendTimeout() {
FdCtx::ptr ctx = FdMgr::GetInstance()->get(m_sock);
if (ctx) {
return ctx->getTimeout(SO_SNDTIMEO);
}
return -1;
}
void Socket::setSendTimeout(int64_t v) {
struct timeval tv {
int(v / 1000), int(v % 1000 * 1000)
};
setOption(SOL_SOCKET, SO_SNDTIMEO, tv);
}
int64_t Socket::getRecvTimeout() {
FdCtx::ptr ctx = FdMgr::GetInstance()->get(m_sock);
if (ctx) {
return ctx->getTimeout(SO_RCVTIMEO);
}
return -1;
}
void Socket::setRecvTimeout(int64_t v) {
struct timeval tv {
int(v / 1000), int(v % 1000 * 1000)
};
setOption(SOL_SOCKET, SO_RCVTIMEO, tv);
}
bool Socket::getOption(int level, int option, void *result, socklen_t *len) {
int rt = getsockopt(m_sock, level, option, result, (socklen_t *)len);
if (rt) {
SYLAR_LOG_DEBUG(g_logger) << "getOption sock=" << m_sock
<< " level=" << level << " option=" << option
<< " errno=" << errno << " errstr=" << strerror(errno);
return false;
}
return true;
}
bool Socket::setOption(int level, int option, const void *result, socklen_t len) {
if (setsockopt(m_sock, level, option, result, (socklen_t)len)) {
SYLAR_LOG_DEBUG(g_logger) << "setOption sock=" << m_sock
<< " level=" << level << " option=" << option
<< " errno=" << errno << " errstr=" << strerror(errno);
return false;
}
return true;
}
Socket::ptr Socket::accept() {
Socket::ptr sock(new Socket(m_family, m_type, m_protocol));
int newsock = ::accept(m_sock, nullptr, nullptr);
if (newsock == -1) {
SYLAR_LOG_ERROR(g_logger) << "accept(" << m_sock << ") errno="
<< errno << " errstr=" << strerror(errno);
return nullptr;
}
if (sock->init(newsock)) {
return sock;
}
return nullptr;
}
bool Socket::init(int sock) {
FdCtx::ptr ctx = FdMgr::GetInstance()->get(sock);
if (ctx && ctx->isSocket() && !ctx->isClose()) {
m_sock = sock;
m_isConnected = true;
initSock();
getLocalAddress();
getRemoteAddress();
return true;
}
return false;
}
bool Socket::bind(const Address::ptr addr) {
m_localAddress = addr;
if (!isValid()) {
newSock();
if (SYLAR_UNLIKELY(!isValid())) {
return false;
}
}
if (SYLAR_UNLIKELY(addr->getFamily() != m_family)) {
SYLAR_LOG_ERROR(g_logger) << "bind sock.family("
<< m_family << ") addr.family(" << addr->getFamily()
<< ") not equal, addr=" << addr->toString();
return false;
}
UnixAddress::ptr uaddr = std::dynamic_pointer_cast<UnixAddress>(addr);
if (uaddr) {
Socket::ptr sock = Socket::CreateUnixTCPSocket();
if (sock->connect(uaddr)) {
return false;
} else {
sylar::FSUtil::Unlink(uaddr->getPath(), true);
}
}
if (::bind(m_sock, addr->getAddr(), addr->getAddrLen())) {
SYLAR_LOG_ERROR(g_logger) << "bind error errrno=" << errno
<< " errstr=" << strerror(errno);
return false;
}
getLocalAddress();
return true;
}
bool Socket::reconnect(uint64_t timeout_ms) {
if (!m_remoteAddress) {
SYLAR_LOG_ERROR(g_logger) << "reconnect m_remoteAddress is null";
return false;
}
m_localAddress.reset();
return connect(m_remoteAddress, timeout_ms);
}
bool Socket::connect(const Address::ptr addr, uint64_t timeout_ms) {
m_remoteAddress = addr;
if (!isValid()) {
newSock();
if (SYLAR_UNLIKELY(!isValid())) {
return false;
}
}
if (SYLAR_UNLIKELY(addr->getFamily() != m_family)) {
SYLAR_LOG_ERROR(g_logger) << "connect sock.family("
<< m_family << ") addr.family(" << addr->getFamily()
<< ") not equal, addr=" << addr->toString();
return false;
}
if (timeout_ms == (uint64_t)-1) {
if (::connect(m_sock, addr->getAddr(), addr->getAddrLen())) {
SYLAR_LOG_ERROR(g_logger) << "sock=" << m_sock << " connect(" << addr->toString()
<< ") error errno=" << errno << " errstr=" << strerror(errno);
close();
return false;
}
} else {
if (::connect_with_timeout(m_sock, addr->getAddr(), addr->getAddrLen(), timeout_ms)) {
SYLAR_LOG_ERROR(g_logger) << "sock=" << m_sock << " connect(" << addr->toString()
<< ") timeout=" << timeout_ms << " error errno="
<< errno << " errstr=" << strerror(errno);
close();
return false;
}
}
m_isConnected = true;
getRemoteAddress();
getLocalAddress();
return true;
}
bool Socket::listen(int backlog) {
if (!isValid()) {
SYLAR_LOG_ERROR(g_logger) << "listen error sock=-1";
return false;
}
if (::listen(m_sock, backlog)) {
SYLAR_LOG_ERROR(g_logger) << "listen error errno=" << errno
<< " errstr=" << strerror(errno);
return false;
}
return true;
}
bool Socket::close() {
if (!m_isConnected && m_sock == -1) {
return true;
}
m_isConnected = false;
if (m_sock != -1) {
::close(m_sock);
m_sock = -1;
}
return false;
}
int Socket::send(const void *buffer, size_t length, int flags) {
if (isConnected()) {
return ::send(m_sock, buffer, length, flags);
}
return -1;
}
int Socket::send(const iovec *buffers, size_t length, int flags) {
if (isConnected()) {
msghdr msg;
memset(&msg, 0, sizeof(msg));
msg.msg_iov = (iovec *)buffers;
msg.msg_iovlen = length;
return ::sendmsg(m_sock, &msg, flags);
}
return -1;
}
int Socket::sendTo(const void *buffer, size_t length, const Address::ptr to, int flags) {
if (isConnected()) {
return ::sendto(m_sock, buffer, length, flags, to->getAddr(), to->getAddrLen());
}
return -1;
}
int Socket::sendTo(const iovec *buffers, size_t length, const Address::ptr to, int flags) {
if (isConnected()) {
msghdr msg;
memset(&msg, 0, sizeof(msg));
msg.msg_iov = (iovec *)buffers;
msg.msg_iovlen = length;
msg.msg_name = to->getAddr();
msg.msg_namelen = to->getAddrLen();
return ::sendmsg(m_sock, &msg, flags);
}
return -1;
}
int Socket::recv(void *buffer, size_t length, int flags) {
if (isConnected()) {
return ::recv(m_sock, buffer, length, flags);
}
return -1;
}
int Socket::recv(iovec *buffers, size_t length, int flags) {
if (isConnected()) {
msghdr msg;
memset(&msg, 0, sizeof(msg));
msg.msg_iov = (iovec *)buffers;
msg.msg_iovlen = length;
return ::recvmsg(m_sock, &msg, flags);
}
return -1;
}
int Socket::recvFrom(void *buffer, size_t length, Address::ptr from, int flags) {
if (isConnected()) {
socklen_t len = from->getAddrLen();
return ::recvfrom(m_sock, buffer, length, flags, from->getAddr(), &len);
}
return -1;
}
int Socket::recvFrom(iovec *buffers, size_t length, Address::ptr from, int flags) {
if (isConnected()) {
msghdr msg;
memset(&msg, 0, sizeof(msg));
msg.msg_iov = (iovec *)buffers;
msg.msg_iovlen = length;
msg.msg_name = from->getAddr();
msg.msg_namelen = from->getAddrLen();
return ::recvmsg(m_sock, &msg, flags);
}
return -1;
}
Address::ptr Socket::getRemoteAddress() {
if (m_remoteAddress) {
return m_remoteAddress;
}
Address::ptr result;
switch (m_family) {
case AF_INET:
result.reset(new IPv4Address());
break;
case AF_INET6:
result.reset(new IPv6Address());
break;
case AF_UNIX:
result.reset(new UnixAddress());
break;
default:
result.reset(new UnknownAddress(m_family));
break;
}
socklen_t addrlen = result->getAddrLen();
if (getpeername(m_sock, result->getAddr(), &addrlen)) {
SYLAR_LOG_ERROR(g_logger) << "getpeername error sock=" << m_sock
<< " errno=" << errno << " errstr=" << strerror(errno);
return Address::ptr(new UnknownAddress(m_family));
}
if (m_family == AF_UNIX) {
UnixAddress::ptr addr = std::dynamic_pointer_cast<UnixAddress>(result);
addr->setAddrLen(addrlen);
}
m_remoteAddress = result;
return m_remoteAddress;
}
Address::ptr Socket::getLocalAddress() {
if (m_localAddress) {
return m_localAddress;
}
Address::ptr result;
switch (m_family) {
case AF_INET:
result.reset(new IPv4Address());
break;
case AF_INET6:
result.reset(new IPv6Address());
break;
case AF_UNIX:
result.reset(new UnixAddress());
break;
default:
result.reset(new UnknownAddress(m_family));
break;
}
socklen_t addrlen = result->getAddrLen();
if (getsockname(m_sock, result->getAddr(), &addrlen)) {
SYLAR_LOG_ERROR(g_logger) << "getsockname error sock=" << m_sock
<< " errno=" << errno << " errstr=" << strerror(errno);
return Address::ptr(new UnknownAddress(m_family));
}
if (m_family == AF_UNIX) {
UnixAddress::ptr addr = std::dynamic_pointer_cast<UnixAddress>(result);
addr->setAddrLen(addrlen);
}
m_localAddress = result;
return m_localAddress;
}
bool Socket::isValid() const {
return m_sock != -1;
}
int Socket::getError() {
int error = 0;
socklen_t len = sizeof(error);
if (!getOption(SOL_SOCKET, SO_ERROR, &error, &len)) {
error = errno;
}
return error;
}
std::ostream &Socket::dump(std::ostream &os) const {
os << "[Socket sock=" << m_sock
<< " is_connected=" << m_isConnected
<< " family=" << m_family
<< " type=" << m_type
<< " protocol=" << m_protocol;
if (m_localAddress) {
os << " local_address=" << m_localAddress->toString();
}
if (m_remoteAddress) {
os << " remote_address=" << m_remoteAddress->toString();
}
os << "]";
return os;
}
std::string Socket::toString() const {
std::stringstream ss;
dump(ss);
return ss.str();
}
bool Socket::cancelRead() {
return IOManager::GetThis()->cancelEvent(m_sock, sylar::IOManager::READ);
}
bool Socket::cancelWrite() {
return IOManager::GetThis()->cancelEvent(m_sock, sylar::IOManager::WRITE);
}
bool Socket::cancelAccept() {
return IOManager::GetThis()->cancelEvent(m_sock, sylar::IOManager::READ);
}
bool Socket::cancelAll() {
return IOManager::GetThis()->cancelAll(m_sock);
}
void Socket::initSock() {
int val = 1;
setOption(SOL_SOCKET, SO_REUSEADDR, val);
if (m_type == SOCK_STREAM) {
setOption(IPPROTO_TCP, TCP_NODELAY, val);
}
}
void Socket::newSock() {
m_sock = socket(m_family, m_type, m_protocol);
if (SYLAR_LIKELY(m_sock != -1)) {
initSock();
} else {
SYLAR_LOG_ERROR(g_logger) << "socket(" << m_family
<< ", " << m_type << ", " << m_protocol << ") errno="
<< errno << " errstr=" << strerror(errno);
}
}
std::ostream &operator<<(std::ostream &os, const Socket &sock) {
return sock.dump(os);
}
}
八、总结
socket模块并不复杂,其实只要理解的sylar的用意后,具体实现多看几遍就清楚了。
接下来的模块也都是封装为主,不像之前的协程那么复杂了。
【最后求关注、点赞、转发】
QQ交流群:957100923