Tcpserver.hpp
#include "Socket.hpp"
#include <functional>
// Signature of the per-connection request handler used by Tcpserver.
// Arguments: the connection's accumulated input buffer (passed by reference so
// the handler can consume the bytes it processed) and an out-flag the handler
// clears to make the server drop the connection. Returns the bytes to send back
// (an empty string means nothing to send).
using func_t = std::function<std::string(std::string &, bool *)>;

// Forward declaration: ThreadData below keeps a back-pointer to the server.
class Tcpserver;
class ThreadData
{
public:
ThreadData(Tcpserver *tcp_this, Net_work::Socket *sockp)
: _this(tcp_this), _sockp(sockp)
{
}
public:
Tcpserver *_this;
Net_work::Socket *_sockp;
};
class Tcpserver
{
public:
Tcpserver(uint16_t port, func_t handler_request) : _port(port),
_listensocket(new Net_work::Tcpsocket()), _handler_request(handler_request)
{
_listensocket->BuidListenSocketMethod(_port, 3);
}
static void *ThreadRun(void *argc)
{
pthread_detach(pthread_self());
ThreadData *data = static_cast<ThreadData *>(argc);
string inbuffer;
while (true)
{
if (!data->_sockp->Recv(&inbuffer, 1024))
break;
bool ok = true;
string send_message = data->_this->_handler_request(inbuffer, &ok);
if (ok)
{
if (!send_message.empty())
data->_sockp->Send(send_message);
}
else
{
break;
}
}
data->_sockp->CloseSocket();
delete data->_sockp;
delete data;
return nullptr;
}
void Loop()
{
while (true)
{
string peerip;
uint16_t peerport;
Net_work::Socket *newscocket = _listensocket->AcceptConnection(&peerip, &peerport);
if (newscocket == nullptr)
continue;
pthread_t tid;
ThreadData *td = new ThreadData(this, newscocket);
pthread_create(&tid, nullptr, ThreadRun, td);
}
}
~Tcpserver()
{
delete _listensocket;
}
private:
int _port;
Net_work::Socket *_listensocket;
public:
func_t _handler_request;
};
Socket.hpp
#pragma once
#include <iostream>
#include <string>
#include <iostream>
#include <string>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <unistd.h>
#include <string.h>
#include <arpa/inet.h>
using namespace std;
namespace Net_work
{
    // Descriptor value meaning "no socket". It used to be 1, which is stdout's
    // descriptor, so an unopened Tcpsocket could write to stdout. -1 is never
    // a valid descriptor.
    const static int defaultsocket = -1;

    // Process exit codes used by the *Ordie helpers below.
    enum
    {
        SocketError = 1,
        BindError,
        ListenError,
    };

    // Abstract socket interface. The Build*Method helpers combine the primitive
    // steps for the common server and client setups.
    class Socket
    {
    public:
        virtual ~Socket() {}
        virtual void CreatSocketOrdie() = 0;
        virtual void BindSocketOrdie(uint16_t port) = 0;
        virtual void ListenSocketOrdie(int backlog) = 0;
        // Returns a new heap-allocated socket for the accepted peer (the caller
        // owns it), or nullptr if accept fails.
        virtual Socket *AcceptConnection(string *peerip, uint16_t *peerport) = 0;
        virtual bool ConnectServer(string &peerip, uint16_t &peerport) = 0;
        virtual int Getsickfd() = 0;
        virtual void Setsockfd(int socketfd) = 0;
        virtual void CloseSocket() = 0;
        // Appends up to `size` received bytes to *buffer; returns false on EOF or error.
        virtual bool Recv(string *buffer, int size) = 0;
        // Sends the whole buffer; returns false on error.
        virtual bool Send(string &buffer) = 0;

    public:
        // socket() + bind(port) + listen(backlog); exits the process on failure.
        void BuidListenSocketMethod(uint16_t port, int backlog)
        {
            CreatSocketOrdie();
            BindSocketOrdie(port);
            ListenSocketOrdie(backlog);
        }
        // socket() + connect(); returns whether the connection was established.
        bool BuildConnectSocketMethod(string &peerip, uint16_t &peerport)
        {
            CreatSocketOrdie();
            return ConnectServer(peerip, peerport);
        }
        // Adopts an already-open descriptor.
        void BuildnomalSocketMethod(int socketfd)
        {
            Setsockfd(socketfd);
        }
    };

    // IPv4 TCP implementation of Socket.
    class Tcpsocket : public Socket
    {
    private:
        int _sockfd;

    public:
        Tcpsocket(int tcpsocket = defaultsocket) : _sockfd(tcpsocket) {}
        // Does not close the descriptor; call CloseSocket() first.
        ~Tcpsocket() {}

        void CreatSocketOrdie() override
        {
            _sockfd = ::socket(AF_INET, SOCK_STREAM, 0);
            if (_sockfd < 0)
                exit(SocketError);
        }

        void BindSocketOrdie(uint16_t port) override
        {
            // Let a restarted server rebind immediately instead of failing with
            // EADDRINUSE while old connections sit in TIME_WAIT.
            int reuse = 1;
            ::setsockopt(_sockfd, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse));

            struct sockaddr_in local;
            memset(&local, 0, sizeof(local));
            local.sin_family = AF_INET;
            local.sin_addr.s_addr = htonl(INADDR_ANY);
            local.sin_port = htons(port);
            if (::bind(_sockfd, reinterpret_cast<sockaddr *>(&local), sizeof(local)) < 0)
                exit(BindError);
        }

        void ListenSocketOrdie(int backlog) override
        {
            if (::listen(_sockfd, backlog) < 0)
                exit(ListenError);
        }

        Socket *AcceptConnection(string *peerip, uint16_t *peerport) override
        {
            struct sockaddr_in peer;
            memset(&peer, 0, sizeof(peer));
            socklen_t len = sizeof(peer);
            int newsockfd = ::accept(_sockfd, reinterpret_cast<sockaddr *>(&peer), &len);
            if (newsockfd < 0)
                return nullptr;
            if (peerip != nullptr)
            {
                // inet_ntop writes into our own buffer. inet_ntoa returns a
                // shared static buffer and is not thread-safe.
                char ip[INET_ADDRSTRLEN] = {0};
                if (::inet_ntop(AF_INET, &peer.sin_addr, ip, sizeof(ip)) != nullptr)
                    *peerip = ip;
                else
                    peerip->clear();
            }
            if (peerport != nullptr)
                *peerport = ntohs(peer.sin_port);
            return new Tcpsocket(newsockfd);
        }

        bool ConnectServer(string &peerip, uint16_t &peerport) override
        {
            struct sockaddr_in server;
            memset(&server, 0, sizeof(server));
            server.sin_family = AF_INET;
            server.sin_port = htons(peerport);
            // inet_pton reports malformed addresses. inet_addr silently maps
            // them to INADDR_NONE, which is 255.255.255.255.
            if (::inet_pton(AF_INET, peerip.c_str(), &server.sin_addr) != 1)
                return false;
            return ::connect(_sockfd, reinterpret_cast<sockaddr *>(&server), sizeof(server)) == 0;
        }

        int Getsickfd() override
        {
            return _sockfd;
        }

        void Setsockfd(int socketfd) override
        {
            _sockfd = socketfd;
        }

        void CloseSocket() override
        {
            if (_sockfd >= 0)
            {
                ::close(_sockfd);
                _sockfd = defaultsocket; // makes repeated CloseSocket() calls harmless
            }
        }

        bool Recv(string *buffer, int size) override
        {
            if (buffer == nullptr || size <= 0)
                return false;
            // A heap-backed buffer replaces the non-standard VLA.
            string chunk(static_cast<size_t>(size), '\0');
            ssize_t n = ::recv(_sockfd, &chunk[0], chunk.size(), 0);
            if (n <= 0)
                return false; // 0: peer closed the connection; < 0: error
            // append(ptr, len) keeps embedded '\0' bytes. The old C-string
            // append silently truncated at the first one.
            buffer->append(chunk.data(), static_cast<size_t>(n));
            return true;
        }

        bool Send(string &buffer) override
        {
#ifdef MSG_NOSIGNAL
            const int flags = MSG_NOSIGNAL; // get EPIPE instead of a process-killing SIGPIPE
#else
            const int flags = 0;
#endif
            const char *data = buffer.data();
            size_t remaining = buffer.size();
            // send() may write fewer bytes than requested; loop until all are out.
            while (remaining > 0)
            {
                ssize_t n = ::send(_sockfd, data, remaining, flags);
                if (n <= 0)
                    return false;
                data += n;
                remaining -= static_cast<size_t>(n);
            }
            return true;
        }
    };
}
Protocol.hpp
#pragma once
#include <iostream>
#include <memory>
using namespace std;
// Wire framing for the calculator protocol (serialization is done by
// Request/Response; these helpers only add and strip the frame).
// A frame is "<len>\n<payload>\n", where <len> is the payload size in decimal.
// The length prefix lets a payload contain '\n' itself; the trailing '\n'
// only makes frames easier to read when printed.
// Example: payload "1 + 2" is framed as "5\n1 + 2\n".
const string ProtSep = " ";  // separator between fields inside a payload
const string LineSep = "\n"; // terminates the length prefix and the payload

// Wraps `message` in a frame (see above) and returns the framed string.
// `inline` because this lives in a header: a plain definition would violate
// the ODR once two translation units include it.
inline string Encode(const string &message)
{
    string package = std::to_string(message.size());
    package += LineSep;
    package += message;
    package += LineSep;
    return package;
}

// Extracts the first complete frame from the front of a TCP byte stream.
// TCP does not preserve message boundaries, so the buffer may hold a partial
// frame, exactly one frame, or one frame followed by (part of) the next.
// package: accumulated stream data. On success the consumed frame is erased
//          from its front, and any following bytes are left in place.
// message: receives the payload of the extracted frame.
// Returns false, without modifying either argument, when the buffer does not
// yet hold a complete frame. It also returns false when the length prefix is
// empty or not purely decimal. Such a buffer can never complete, so callers
// reading from untrusted peers should cap its size.
inline bool Decode(string &package, string *message)
{
    if (message == nullptr)
        return false;
    const size_t pos = package.find(LineSep);
    if (pos == string::npos || pos == 0)
        return false; // length prefix not fully received yet, or empty
    // Parse the length by hand. stoi() would throw on garbage or overflow (this
    // is untrusted network input, and an escaping exception kills the serving
    // thread), and it would also accept negative values.
    size_t messagelen = 0;
    for (size_t i = 0; i < pos; ++i)
    {
        const char c = package[i];
        if (c < '0' || c > '9')
            return false;
        messagelen = messagelen * 10 + static_cast<size_t>(c - '0');
        // A frame longer than what is buffered cannot be complete yet. Bailing
        // out here also keeps the accumulation from overflowing.
        if (messagelen > package.size())
            return false;
    }
    const size_t total = pos + LineSep.size() + messagelen + LineSep.size();
    if (package.size() < total)
        return false; // payload (or its terminator) not fully received yet
    *message = package.substr(pos + LineSep.size(), messagelen);
    package.erase(0, total);
    return true;
}
// Calculator request: two integer operands and a single-character operator.
// Payload form (before framing with Encode): "x<ProtSep>op<ProtSep>y", e.g. "1 + 2".
class Request
{
public:
    Request() {}
    // The initializers follow declaration order (x, y, oper), which is the
    // order C++ actually initializes members in (avoids -Wreorder).
    Request(int data_x, int data_y, char oper) : _data_x(data_x), _data_y(data_y), _oper(oper)
    {
    }
    // Prints "x y op" to stdout.
    void debug()
    {
        cout << _data_x << " " << _data_y << " " << _oper << endl;
    }
    // Increments both operands.
    void add()
    {
        _data_x++;
        _data_y++;
    }
    // Writes the payload "x op y" into *out. The length header is added
    // separately by Encode().
    bool Serialize(string *out)
    {
        *out = to_string(_data_x) + ProtSep + _oper + ProtSep + to_string(_data_y);
        return true;
    }
    // Parses a payload of the form "x op y".
    // Returns false, leaving this object unchanged, if the payload is
    // malformed. That includes fewer than two separators, an operator that is
    // not exactly one character, and operands that stoi() cannot convert.
    // The payload comes from the network, so a conversion failure must not
    // escape as an exception.
    bool Deserialize(string &in)
    {
        const size_t left = in.find(ProtSep);
        if (left == string::npos)
            return false;
        const size_t right = in.rfind(ProtSep);
        // Two distinct separators are required. With one (e.g. "1 2") the
        // operator length below would underflow.
        if (right == string::npos || right < left + ProtSep.size())
            return false;
        const string oper = in.substr(left + ProtSep.size(), right - left - ProtSep.size());
        if (oper.size() != 1)
            return false;
        int x = 0;
        int y = 0;
        try
        {
            x = std::stoi(in.substr(0, left));
            y = std::stoi(in.substr(right + ProtSep.size()));
        }
        catch (const std::exception &)
        {
            return false; // std::invalid_argument / std::out_of_range
        }
        _data_x = x;
        _data_y = y;
        _oper = oper[0];
        return true;
    }
    int GetX() const { return _data_x; }
    int GetY() const { return _data_y; }
    char GetOper() const { return _oper; }

private:
    // The full frame is "len\nx op y\n". The length header is what makes
    // framing robust even if a payload contains '\n'.
    int _data_x = 0;
    int _data_y = 0;
    char _oper = 0;
};
// Calculator response: the computed result and a status code.
// Payload form (before framing with Encode): "result<ProtSep>code".
class Response
{
public:
    Response()
    {
    }
    Response(int result, int code) : _result(result), _code(code) {}
    // Writes the payload "result code" into *out.
    bool Serialize(string *out)
    {
        *out = to_string(_result) + ProtSep + to_string(_code);
        return true;
    }
    // Parses a payload of the form "result code".
    // Returns false, leaving this object unchanged, if there is no separator
    // or a field is not a number stoi() accepts. The data comes from the peer,
    // so a conversion failure must not escape as an exception.
    bool Deserialize(string &in)
    {
        const size_t pos = in.find(ProtSep);
        if (pos == string::npos)
            return false;
        int result = 0;
        int code = 0;
        try
        {
            result = stoi(in.substr(0, pos));
            code = stoi(in.substr(pos + ProtSep.size()));
        }
        catch (const std::exception &)
        {
            return false; // std::invalid_argument / std::out_of_range
        }
        _result = result;
        _code = code;
        return true;
    }
    void SetResult(int res) { _result = res; }
    void SetCode(int code) { _code = code; }
    int GetResult() const { return _result; }
    int GetCode() const { return _code; }

private:
    int _result = 0;
    int _code = 0;
};
class Factory
{
public:
shared_ptr<Request> BuildRequest()
{
shared_ptr<Request> req = make_shared<Request>();
return req;
}
shared_ptr<Request> BuildRequest(int x, int y, int op)
{
shared_ptr<Request> req = make_shared<Request>(x, y, op);
return req;
}
shared_ptr<Response> BuildResponse()
{
shared_ptr<Response> resp = make_shared<Response>();
return resp;
}
shared_ptr<Response> BuildResponse(int result, int code)
{
shared_ptr<Response> resp = make_shared<Response>(result, code);
return resp;
}
};
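// TODO: I have not understood make_shared's implementation from the library source yet; look up an explanation later.
// Difference between make_shared and constructing a shared_ptr directly:
// make_shared performs a single allocation that holds both the object (e.g. an int) and the control block with the reference counts. This avoids the extra allocation (and, before C++17, an exception-safety pitfall) of shared_ptr<T>(new T).
// shared_ptr<T>(new T) needs two allocations: one for the managed object (done by new) and one for the control block (the reference counts).
// A smart pointer is implemented as a class template and is used much like vector or list.
// The shared_ptr object itself just holds two pointers (to the object and to the control block). When it goes out of scope, its destructor decrements the count and frees the heap memory once the count reaches zero.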
// template<typename T>
// class shared_ptr {
// public:
// // constructor
// shared_ptr(T* ptr = nullptr) : m_ptr(ptr), m_refCount(new int(1)) {}
// // copy constructor
// shared_ptr(const shared_ptr& other) : m_ptr(other.m_ptr), m_refCount(other.m_refCount) {
// // increase the reference count
// (*m_refCount)++;
// }
// // destructor
// ~shared_ptr() {
// // decrease the reference count
// (*m_refCount)--;
// // if the reference count is zero, delete the pointer
// if (*m_refCount == 0) {
// delete m_ptr;
// delete m_refCount;
// }
// }
// // overload operator=()
// shared_ptr& operator=(const shared_ptr& other) {
// // check self-assignment
// if (this != &other) {
// // decrease the reference count for the old pointer
// (*m_refCount)--;
// // if the reference count is zero, delete the pointer
// if (*m_refCount == 0) {
// delete m_ptr;
// delete m_refCount;
// }
// // copy the data and reference pointer and increase the reference count
// m_ptr = other.m_ptr;
// m_refCount = other.m_refCount;
// // increase the reference count
// (*m_refCount)++;
// }
// return *this;
// }
// private:
// T* m_ptr; // points to the actual data
// int* m_refCount; // reference count
// };
Calculate.hpp
#include "Protocol.hpp"
// Status codes stored in Response's code field by Calculate::Cal.
enum
{
    Success = 0,    // the result field is valid
    Divzeroerr = 1, // '/' with a zero divisor
    Modzeroerr = 2, // '%' with a zero divisor
    Unknowcode = 3  // unsupported operator
};
class Calculate
{
private:
Factory factory;
public:
Calculate(){};
std::shared_ptr<Response> Cal(std::shared_ptr<Request> req)
{
shared_ptr<Response> resp = factory.BuildResponse();
resp->SetCode(Success);
switch (req->GetOper())
{
case '+':
resp->SetResult(req->GetX() + req->GetY());
break;
case '-':
resp->SetResult(req->GetX() - req->GetY());
break;
case '*':
resp->SetResult(req->GetX() * req->GetY());
break;
case '/':
{
if (req->GetY() == 0)
resp->SetCode(Divzeroerr);
else
resp->SetResult(req->GetX() / req->GetY());
}
break;
case '%':
if (req->GetY() == 0)
resp->SetCode(Modzeroerr);
else
resp->SetResult(req->GetX() % req->GetY());
break;
default:
resp->SetCode(Unknowcode);
break;
}
return resp;
}
~Calculate(){};
};
Tcpservermain.cc
#include "Protocol.hpp"
#include "Tcpserver.hpp"
#include "Socket.hpp"
#include "Calculate.hpp"
// Request handler passed to Tcpserver.
// Decodes every complete frame currently in inbuffer, evaluates each request
// and returns the concatenated, framed responses. Incomplete trailing bytes
// stay in inbuffer for the next call.
// If a payload cannot be deserialized, it sets *error_code to false and
// returns an empty string, which makes the server drop the connection.
string handle(string &inbuffer, bool *error_code)
{
    *error_code = true;
    Calculate calculate;
    unique_ptr<Factory> factory = make_unique<Factory>();
    auto req = factory->BuildRequest();
    auto resp = factory->BuildResponse();
    string reply;
    for (string payload; Decode(inbuffer, &payload);)
    {
        if (!req->Deserialize(payload))
        {
            *error_code = false;
            return string();
        }
        resp = calculate.Cal(req);
        string body;
        resp->Serialize(&body);
        reply += Encode(body);
    }
    return reply;
}
// Calculator server entry point. Usage: <prog> port
int main(int argc, char *argv[])
{
    if (argc != 2)
    {
        cout << "Usage : " << argv[0] << " port" << endl;
        return 1; // argv[1] does not exist; the old code went on to read it
    }
    // stoi throws on non-numeric text, and values above 65535 would silently
    // wrap when narrowed to uint16_t, so validate before use.
    int port = -1;
    try
    {
        port = stoi(argv[1]);
    }
    catch (const std::exception &)
    {
        port = -1;
    }
    if (port <= 0 || port > 65535)
    {
        cout << "invalid port: " << argv[1] << endl;
        return 1;
    }
    uint16_t localport = static_cast<uint16_t>(port);
    Tcpserver server(localport, handle);
    server.Loop();
    return 0;
}
TcpClientmain.cc
#include "Socket.hpp"
#include "Protocol.hpp"
#include <unistd.h>
#include <stdlib.h>
#include <time.h>
#include "Socket.hpp"
int main(int argc, char *argv[])
{
srand(time(NULL) ^ getpid());
if (argc != 3)
{
cout << "Usage: " << argv[0] << " ip" << " port" << endl;
}
Net_work::Socket *client = new Net_work::Tcpsocket();
string serverip = argv[1];
uint16_t serverport = stoi(argv[2]);
bool n = client->BuildConnectSocketMethod(serverip, serverport);
if (n == false)
cout << "connect is error" << endl;
else
{
cout << "connect success " << endl;
cout << "serverip:" << serverip << " serverport" << serverport << endl;
}
unique_ptr<Factory> factory = make_unique<Factory>();
const string opers = "+-*/^!=";
while (true)
{
int x = rand() % 100;
int y = rand() % 100;
char opar = opers[rand() % (opers.size())];
shared_ptr<Request> req = factory->BuildRequest(x, y, opar);
string message;
req->Serialize(&message);
string send_message = Encode(message);
client->Send(send_message);
cout<<send_message;
while (true)
{
string resp_messsage;
client->Recv(&resp_messsage, 1024);
string re_mess;
if (!Decode(resp_messsage, &re_mess))
continue;
auto resp = factory->BuildResponse();
resp->Deserialize(re_mess);
cout << resp->GetResult() << " " << resp->GetCode() << endl;
break;
}
sleep(1);
}
client->CloseSocket();
return 0;
}