blockqueue.h
/*
 * Bounded double-ended queue whose state is protected by a single mutex.
 * Two condition variables are kept: one notified for consumers and one
 * notified for producers.
 */
template<class T>
class BlockDeque {
public:
/* MaxCapacity is the maximum number of stored items (defaults to 1000). */
explicit BlockDeque(size_t MaxCapacity = 1000);
~BlockDeque();
/* Removes every queued item. */
void clear();
bool empty();
bool full();
/* Discards queued items, marks the queue closed and notifies all waiters. */
void Close();
size_t size();
size_t capacity();
/* Return a copy of the first / last element. */
T front();
T back();
/* Insert an item; both wait while the queue is at capacity. */
void push_back(const T &item);
void push_front(const T &item);
/* Take the front item into `item`; returns false when nothing was taken. */
bool pop(T &item);
/* Same as pop(item) but gives up after `timeout` seconds. */
bool pop(T &item, int timeout);
/* Notifies one consumer. */
void flush();
private:
std::deque<T> deq_;
size_t capacity_;
std::mutex mtx_;
bool isClose_;
std::condition_variable condConsumer_;
std::condition_variable condProducer_;
};
/* Builds an empty, open queue limited to MaxCapacity items; the limit must be positive. */
template<class T>
BlockDeque<T>::BlockDeque(size_t MaxCapacity)
    : capacity_(MaxCapacity), isClose_(false) {
    assert(MaxCapacity > 0);
}
/* Destruction closes the queue, which drops queued items and notifies all waiters. */
template<class T>
BlockDeque<T>::~BlockDeque() {
    this->Close();
}
/*
 * Empties the queue and flags it as closed while holding the mutex, then
 * notifies every producer and consumer after the mutex has been released.
 */
template<class T>
void BlockDeque<T>::Close() {
    std::unique_lock<std::mutex> guard(mtx_);
    deq_.clear();
    isClose_ = true;
    guard.unlock();

    condProducer_.notify_all();
    condConsumer_.notify_all();
}
/* Notifies a single thread waiting on the consumer condition variable. */
template<class T>
void BlockDeque<T>::flush() {
    std::condition_variable &consumerSignal = condConsumer_;
    consumerSignal.notify_one();
}
/*
 * Removes every queued item.
 *
 * After the queue has been emptied there is room again, so producers that
 * are waiting for space are notified; without the notification they would
 * stay blocked until some consumer happened to pop.
 */
template<class T>
void BlockDeque<T>::clear() {
    {
        std::lock_guard<std::mutex> locker(mtx_);
        deq_.clear();
    }
    condProducer_.notify_all();
}
/* Returns a copy of the first element; the mutex is held while copying. */
template<class T>
T BlockDeque<T>::front() {
    std::unique_lock<std::mutex> guard(mtx_);
    return deq_.front();
}
/* Returns a copy of the last element; the mutex is held while copying. */
template<class T>
T BlockDeque<T>::back() {
    std::unique_lock<std::mutex> guard(mtx_);
    return deq_.back();
}
/* Number of items currently queued, read under the mutex. */
template<class T>
size_t BlockDeque<T>::size() {
    std::unique_lock<std::mutex> guard(mtx_);
    const size_t count = deq_.size();
    return count;
}
/* Maximum number of items the queue accepts, read under the mutex. */
template<class T>
size_t BlockDeque<T>::capacity() {
    std::unique_lock<std::mutex> guard(mtx_);
    const size_t limit = capacity_;
    return limit;
}
/*
 * Appends `item` at the tail. Waits on the producer condition variable for
 * as long as the queue holds capacity_ or more items, then notifies one
 * consumer.
 */
template<class T>
void BlockDeque<T>::push_back(const T &item) {
    std::unique_lock<std::mutex> guard(mtx_);
    condProducer_.wait(guard, [this] { return deq_.size() < capacity_; });
    deq_.push_back(item);
    condConsumer_.notify_one();
}
/*
 * Inserts `item` at the head. Waits on the producer condition variable for
 * as long as the queue holds capacity_ or more items, then notifies one
 * consumer.
 */
template<class T>
void BlockDeque<T>::push_front(const T &item) {
    std::unique_lock<std::mutex> guard(mtx_);
    condProducer_.wait(guard, [this] { return deq_.size() < capacity_; });
    deq_.push_front(item);
    condConsumer_.notify_one();
}
/* True when no item is queued; evaluated under the mutex. */
template<class T>
bool BlockDeque<T>::empty() {
    std::unique_lock<std::mutex> guard(mtx_);
    return deq_.size() == 0;
}
/* True when the number of queued items has reached capacity_; evaluated under the mutex. */
template<class T>
bool BlockDeque<T>::full() {
    std::unique_lock<std::mutex> guard(mtx_);
    const bool hasRoom = deq_.size() < capacity_;
    return !hasRoom;
}
/*
 * Removes the front item into `item`.
 *
 * Waits until an item is available or the queue has been closed. The closed
 * flag is part of the wait predicate, so a call made after Close() returns
 * immediately instead of waiting for a notification that will never come.
 *
 * @return true when an item was taken, false when the queue is closed and empty.
 */
template<class T>
bool BlockDeque<T>::pop(T &item) {
    std::unique_lock<std::mutex> locker(mtx_);
    condConsumer_.wait(locker, [this] { return isClose_ || !deq_.empty(); });
    if(deq_.empty()) {
        return false;   // only reachable when the queue was closed
    }
    item = deq_.front();
    deq_.pop_front();
    condProducer_.notify_one();
    return true;
}
/*
 * Removes the front item into `item`, waiting at most `timeout` seconds.
 *
 * The predicate form of wait_for is used so that (a) a queue that is already
 * closed is detected before blocking and (b) a spurious wake-up does not
 * restart the full timeout.
 *
 * @return true when an item was taken; false on timeout or when the queue is
 *         closed and empty.
 */
template<class T>
bool BlockDeque<T>::pop(T &item, int timeout) {
    std::unique_lock<std::mutex> locker(mtx_);
    const bool ready = condConsumer_.wait_for(
        locker, std::chrono::seconds(timeout),
        [this] { return isClose_ || !deq_.empty(); });
    if(!ready || deq_.empty()) {
        return false;
    }
    item = deq_.front();
    deq_.pop_front();
    condProducer_.notify_one();
    return true;
}
buffer.h
/*
 * Growable byte buffer backed by std::vector<char> with a read cursor and a
 * write cursor. Bytes in [readPos_, writePos_) are readable, bytes from
 * writePos_ to the end of the vector are writable.
 */
class Buffer {
public:
Buffer(int initBuffSize = 1024);
~Buffer() = default;
/* Sizes of the writable tail, the readable middle and the already-consumed head. */
size_t WritableBytes() const;
size_t ReadableBytes() const ;
size_t PrependableBytes() const;
/* Pointer to the first readable byte. */
const char* Peek() const;
/* Guarantees at least `len` writable bytes. */
void EnsureWriteable(size_t len);
/* Advances the write cursor by `len`. */
void HasWritten(size_t len);
/* Advance the read cursor / reset the buffer. */
void Retrieve(size_t len);
void RetrieveUntil(const char* end);
void RetrieveAll() ;
std::string RetrieveAllToStr();
/* Pointer to the first writable byte. */
const char* BeginWriteConst() const;
char* BeginWrite();
/* Copy bytes to the writable end, growing the buffer when needed. */
void Append(const std::string& str);
void Append(const char* str, size_t len);
void Append(const void* data, size_t len);
void Append(const Buffer& buff);
/* Read from / write to a file descriptor; errno is stored through `Errno` on failure. */
ssize_t ReadFd(int fd, int* Errno);
ssize_t WriteFd(int fd, int* Errno);
private:
char* BeginPtr_();
const char* BeginPtr_() const;
void MakeSpace_(size_t len);
std::vector<char> buffer_;
std::atomic<std::size_t> readPos_;
std::atomic<std::size_t> writePos_;
};
buffer.cpp
#include "buffer.h"
/* Allocates initBuffSize bytes of storage and places both cursors at offset 0. */
Buffer::Buffer(int initBuffSize)
    : buffer_(initBuffSize),
      readPos_(0),
      writePos_(0) {
}
/* Number of bytes between the read cursor and the write cursor. */
size_t Buffer::ReadableBytes() const {
    return writePos_.load() - readPos_.load();
}
/* Number of bytes between the write cursor and the end of the storage. */
size_t Buffer::WritableBytes() const {
    const size_t total = buffer_.size();
    return total - writePos_.load();
}
/* Number of bytes in front of the read cursor (the current read offset). */
size_t Buffer::PrependableBytes() const {
    return readPos_.load();
}
/* Address of the first readable byte. */
const char* Buffer::Peek() const {
    const char* base = BeginPtr_();
    return base + readPos_.load();
}
/* Marks `len` readable bytes as consumed; `len` must not exceed ReadableBytes(). */
void Buffer::Retrieve(size_t len) {
    assert(len <= ReadableBytes());
    readPos_.fetch_add(len);
}
/* Consumes everything from the read cursor up to (not including) `end`. */
void Buffer::RetrieveUntil(const char* end) {
    const char* head = Peek();
    assert(head <= end);
    Retrieve(end - head);
}
/* Zero-fills the whole storage and rewinds both cursors to offset 0. */
void Buffer::RetrieveAll() {
    const size_t total = buffer_.size();
    bzero(&buffer_[0], total);
    readPos_.store(0);
    writePos_.store(0);
}
/* Copies all readable bytes into a string, then resets the buffer. */
std::string Buffer::RetrieveAllToStr() {
    const size_t pending = ReadableBytes();
    std::string drained(Peek(), pending);
    RetrieveAll();
    return drained;
}
/* Read-only address of the first writable byte. */
const char* Buffer::BeginWriteConst() const {
    const char* base = BeginPtr_();
    return base + writePos_.load();
}
/* Mutable address of the first writable byte. */
char* Buffer::BeginWrite() {
    char* base = BeginPtr_();
    return base + writePos_.load();
}
/* Advances the write cursor by `len` bytes. */
void Buffer::HasWritten(size_t len) {
    writePos_.fetch_add(len);
}
/* Appends the characters of `str`. */
void Buffer::Append(const std::string& str) {
    const char* bytes = str.data();
    Append(bytes, str.length());
}
/* Appends `len` raw bytes starting at `data`, which must not be null. */
void Buffer::Append(const void* data, size_t len) {
    assert(data);
    const char* bytes = static_cast<const char*>(data);
    Append(bytes, len);
}
/*
 * Appends `len` characters starting at `str` (must not be null): makes room,
 * copies the bytes to the write cursor and advances the cursor.
 */
void Buffer::Append(const char* str, size_t len) {
    assert(str);
    EnsureWriteable(len);
    char* dest = BeginWrite();
    std::copy(str, str + len, dest);
    HasWritten(len);
}
/* Appends the readable bytes of another buffer. */
void Buffer::Append(const Buffer& buff) {
    const char* src = buff.Peek();
    Append(src, buff.ReadableBytes());
}
/* Makes sure at least `len` bytes can be written, calling MakeSpace_ when the tail is too small. */
void Buffer::EnsureWriteable(size_t len) {
    const bool needsRoom = WritableBytes() < len;
    if(needsRoom) {
        MakeSpace_(len);
    }
    assert(WritableBytes() >= len);
}
ssize_t Buffer::ReadFd(int fd, int* saveErrno) {
char buff[65535];
struct iovec iov[2];
const size_t writable = WritableBytes();
/* 分散读, 保证数据全部读完 */
iov[0].iov_base = BeginPtr_() + writePos_;
iov[0].iov_len = writable;
iov[1].iov_base = buff;
iov[1].iov_len = sizeof(buff);
const ssize_t len = readv(fd, iov, 2);
if(len < 0) {
*saveErrno = errno;
}
else if(static_cast<size_t>(len) <= writable) {
writePos_ += len;
}
else {
writePos_ = buffer_.size();
Append(buff, len - writable);
}
return len;
}
/*
 * Writes all readable bytes to `fd` with one write() call and consumes the
 * bytes that were accepted.
 *
 * @param saveErrno  receives errno when write() fails
 * @return           the value returned by write()
 */
ssize_t Buffer::WriteFd(int fd, int* saveErrno) {
    const size_t pending = ReadableBytes();
    const ssize_t sent = write(fd, Peek(), pending);
    if(sent < 0) {
        *saveErrno = errno;
    } else {
        readPos_ += sent;
    }
    return sent;
}
/* Mutable address of the first byte of the underlying storage. */
char* Buffer::BeginPtr_() {
    return buffer_.data();
}
/* Read-only address of the first byte of the underlying storage. */
const char* Buffer::BeginPtr_() const {
    return buffer_.data();
}
/*
 * Creates room for `len` more bytes. When the writable tail plus the consumed
 * head are still too small the storage is resized; otherwise the readable
 * bytes are moved to the front so the consumed head becomes writable.
 */
void Buffer::MakeSpace_(size_t len) {
    const size_t reclaimable = WritableBytes() + PrependableBytes();
    if(reclaimable < len) {
        buffer_.resize(writePos_ + len + 1);
        return;
    }
    const size_t pending = ReadableBytes();
    char* base = BeginPtr_();
    std::copy(base + readPos_, base + writePos_, base);
    readPos_ = 0;
    writePos_ = pending;
    assert(pending == ReadableBytes());
}
heaptimer.h
/* Callback invoked when a timer fires, and the clock types used by the timer. */
typedef std::function<void()> TimeoutCallBack;
typedef std::chrono::high_resolution_clock Clock;
typedef std::chrono::milliseconds MS;
typedef Clock::time_point TimeStamp;

/*
 * One timer entry: an integer id, the absolute expiry time and the callback.
 * Kept as a plain aggregate so it can be brace-initialised.
 */
struct TimerNode {
    int id;
    TimeStamp expires;
    TimeoutCallBack cb;

    /* Orders nodes by expiry time. Const-qualified so const nodes compare too. */
    bool operator<(const TimerNode& t) const {
        return expires < t.expires;
    }
};
/*
 * Timer container: TimerNode entries are stored in a vector arranged as a
 * binary heap ordered by expiry, and ref_ maps a timer id to its current
 * index in that vector.
 */
class HeapTimer {
public:
HeapTimer() { heap_.reserve(64); }
~HeapTimer() { clear(); }
/* Re-arms the timer `id` to expire newExpires milliseconds from now. */
void adjust(int id, int newExpires);
/* Adds a timer, or re-arms and replaces the callback of an existing id. */
void add(int id, int timeOut, const TimeoutCallBack& cb);
/* Runs the callback of timer `id` and removes it. */
void doWork(int id);
void clear();
/* Runs and removes every timer that has expired. */
void tick();
/* Removes the heap's first entry. */
void pop();
/* Calls tick(), then reports the milliseconds until the next expiry (-1 if none). */
int GetNextTick();
private:
void del_(size_t i);
void siftup_(size_t i);
bool siftdown_(size_t index, size_t n);
void SwapNode_(size_t i, size_t j);
std::vector<TimerNode> heap_;
std::unordered_map<int, size_t> ref_;
};
heaptimer.cpp
#include "heaptimer.h"
/*
 * Moves the node at index `i` towards the root until its parent compares
 * less than it.
 *
 * The loop stops at the root (i == 0). The parent index is only computed for
 * i > 0 because `(i - 1) / 2` on an unsigned zero wraps to a huge value and
 * would index far outside the vector.
 */
void HeapTimer::siftup_(size_t i) {
    assert(i < heap_.size());
    while(i > 0) {
        const size_t parent = (i - 1) / 2;
        if(heap_[parent] < heap_[i]) { break; }
        SwapNode_(i, parent);
        i = parent;
    }
}
void HeapTimer::SwapNode_(size_t i, size_t j) {
assert(i >= 0 && i < heap_.size());
assert(j >= 0 && j < heap_.size());
std::swap(heap_[i], heap_[j]);
ref_[heap_[i].id] = i;
ref_[heap_[j].id] = j;
}
/*
 * Moves the node at `index` towards the leaves, considering only the first
 * `n` entries of the heap.
 *
 * @return true when the node ended up at a larger index than it started.
 */
bool HeapTimer::siftdown_(size_t index, size_t n) {
    assert(index < heap_.size());
    assert(n <= heap_.size());
    size_t cur = index;
    for(size_t child = cur * 2 + 1; child < n; child = cur * 2 + 1) {
        const size_t right = child + 1;
        if(right < n && heap_[right] < heap_[child]) {
            child = right;              // continue with the smaller child
        }
        if(heap_[cur] < heap_[child]) {
            break;
        }
        SwapNode_(cur, child);
        cur = child;
    }
    return cur > index;
}
void HeapTimer::add(int id, int timeout, const TimeoutCallBack& cb) {
assert(id >= 0);
size_t i;
if(ref_.count(id) == 0) {
/* 新节点:堆尾插入,调整堆 */
i = heap_.size();
ref_[id] = i;
heap_.push_back({id, Clock::now() + MS(timeout), cb});
siftup_(i);
}
else {
/* 已有结点:调整堆 */
i = ref_[id];
heap_[i].expires = Clock::now() + MS(timeout);
heap_[i].cb = cb;
if(!siftdown_(i, heap_.size())) {
siftup_(i);
}
}
}
void HeapTimer::doWork(int id) {
/* 删除指定id结点,并触发回调函数 */
if(heap_.empty() || ref_.count(id) == 0) {
return;
}
size_t i = ref_[id];
TimerNode node = heap_[i];
node.cb();
del_(i);
}
/*
 * Removes the node at `index`: it is swapped with the last node, heap order
 * is restored for the node that took its place, and the last slot is dropped
 * from both heap_ and ref_.
 */
void HeapTimer::del_(size_t index) {
    assert(!heap_.empty() && index < heap_.size());
    const size_t last = heap_.size() - 1;
    assert(index <= last);
    if(index < last) {
        SwapNode_(index, last);
        const bool movedDown = siftdown_(index, last);
        if(!movedDown) {
            siftup_(index);
        }
    }
    /* the node to delete now sits at the back */
    ref_.erase(heap_.back().id);
    heap_.pop_back();
}
/*
 * Re-arms timer `id` to expire `timeout` milliseconds from now.
 * Only a downward sift is performed afterwards.
 * NOTE(review): that restores heap order when the expiry moves later; an
 * earlier expiry would also need an upward sift - confirm callers only extend.
 */
void HeapTimer::adjust(int id, int timeout) {
    assert(!heap_.empty() && ref_.count(id) > 0);
    const size_t slot = ref_[id];
    heap_[slot].expires = Clock::now() + MS(timeout);
    siftdown_(slot, heap_.size());
}
void HeapTimer::tick() {
/* 清除超时结点 */
if(heap_.empty()) {
return;
}
while(!heap_.empty()) {
TimerNode node = heap_.front();
if(std::chrono::duration_cast<MS>(node.expires - Clock::now()).count() > 0) {
break;
}
node.cb();
pop();
}
}
/* Removes the first heap entry; the heap must not be empty. */
void HeapTimer::pop() {
    assert(!heap_.empty());
    const size_t root = 0;
    del_(root);
}
/* Drops every timer together with the id-to-index map. */
void HeapTimer::clear() {
    heap_.clear();
    ref_.clear();
}
/*
 * Fires expired timers, then reports how long until the next one expires.
 *
 * The remaining time is kept in a signed value: a timer that expires between
 * tick() and the measurement yields a negative duration, which is clamped to
 * 0 instead of being returned as a negative number.
 *
 * @return milliseconds until the next expiry (>= 0), or -1 when no timer is set.
 */
int HeapTimer::GetNextTick() {
    tick();
    if(heap_.empty()) {
        return -1;
    }
    const auto remainingMs =
        std::chrono::duration_cast<MS>(heap_.front().expires - Clock::now()).count();
    return remainingMs > 0 ? static_cast<int>(remainingMs) : 0;
}
httpconn.h
/*
 * State of one client connection: the socket, the peer address, read/write
 * buffers, the parsed request, the response being sent and the iovec pair
 * handed to writev().
 */
class HttpConn {
public:
HttpConn();
~HttpConn();
/* Binds this object to an accepted socket and its peer address. */
void init(int sockFd, const sockaddr_in& addr);
/* Socket I/O; errno is stored through saveErrno on failure. */
ssize_t read(int* saveErrno);
ssize_t write(int* saveErrno);
void Close();
int GetFd() const;
int GetPort() const;
const char* GetIP() const;
sockaddr_in GetAddr() const;
/* Parses the buffered request and prepares the response; false when there is nothing to parse. */
bool process();
/* Bytes still described by the two iovec entries. */
int ToWriteBytes() {
return iov_[0].iov_len + iov_[1].iov_len;
}
bool IsKeepAlive() const {
return request_.IsKeepAlive();
}
/* Shared by all connections. */
static bool isET;
static const char* srcDir;
static std::atomic<int> userCount;
private:
int fd_;
struct sockaddr_in addr_;
bool isClose_;
int iovCnt_;
struct iovec iov_[2];
Buffer readBuff_; // read buffer
Buffer writeBuff_; // write buffer
HttpRequest request_;
HttpResponse response_;
};
httpconn.cpp
#include "httpconn.h"
using namespace std;
/* Out-of-class definitions of HttpConn's static data members. */
const char* HttpConn::srcDir;
std::atomic<int> HttpConn::userCount;
bool HttpConn::isET;
/*
 * Creates an unbound, closed connection object.
 *
 * The iovec bookkeeping is zeroed as well: ToWriteBytes() and write() read
 * iov_ / iovCnt_, so they must hold defined values before process() has
 * filled them in.
 */
HttpConn::HttpConn() {
    fd_ = -1;
    addr_ = { 0 };
    isClose_ = true;
    iovCnt_ = 0;
    for(auto& entry : iov_) {
        entry.iov_base = nullptr;
        entry.iov_len = 0;
    }
}
/* Destruction closes the connection. */
HttpConn::~HttpConn() {
    this->Close();
}
/*
 * Attaches this object to socket `fd` with peer address `addr`: bumps the
 * global user counter, resets both buffers and marks the connection open.
 */
void HttpConn::init(int fd, const sockaddr_in& addr) {
    assert(fd > 0);
    userCount++;
    fd_ = fd;
    addr_ = addr;
    readBuff_.RetrieveAll();
    writeBuff_.RetrieveAll();
    isClose_ = false;
    LOG_INFO("Client[%d](%s:%d) in, userCount:%d", fd_, GetIP(), GetPort(), (int)userCount);
}
/*
 * Releases the response's file mapping; when the connection is still open it
 * is additionally marked closed, the user counter is decremented and the
 * socket is closed.
 */
void HttpConn::Close() {
    response_.UnmapFile();
    if(isClose_) {
        return;
    }
    isClose_ = true;
    userCount--;
    close(fd_);
    LOG_INFO("Client[%d](%s:%d) quit, UserCount:%d", fd_, GetIP(), GetPort(), (int)userCount);
}
int HttpConn::GetFd() const {
return fd_;
};
struct sockaddr_in HttpConn::GetAddr() const {
return addr_;
}
/* Peer IPv4 address in dotted-decimal text, as produced by inet_ntoa(). */
const char* HttpConn::GetIP() const {
    const struct in_addr& peer = addr_.sin_addr;
    return inet_ntoa(peer);
}
/*
 * Peer port number in host byte order.
 * sin_port is stored in network byte order, so it is converted with ntohs()
 * before being returned; returning it raw yields a byte-swapped port on
 * little-endian machines.
 */
int HttpConn::GetPort() const {
    return ntohs(addr_.sin_port);
}
/*
 * Reads from the socket into the read buffer. One read is always attempted;
 * while isET is set, reading continues until a call returns <= 0.
 *
 * @return the result of the last ReadFd() call.
 */
ssize_t HttpConn::read(int* saveErrno) {
    ssize_t got = readBuff_.ReadFd(fd_, saveErrno);
    while(got > 0 && isET) {
        got = readBuff_.ReadFd(fd_, saveErrno);
    }
    return got;
}
/*
 * Sends the pending response with writev() and advances the two iovec
 * entries by the number of bytes written. iov_[0] describes the write
 * buffer, iov_[1] the second region set up by process().
 * The loop repeats while isET is set or more than 10240 bytes remain.
 * Returns the result of the last writev() call; errno is stored through
 * saveErrno when that result is <= 0.
 */
ssize_t HttpConn::write(int* saveErrno) {
ssize_t len = -1;
do {
len = writev(fd_, iov_, iovCnt_);
if(len <= 0) {
*saveErrno = errno;
break;
}
if(iov_[0].iov_len + iov_[1].iov_len == 0) { break; } /* transfer finished */
else if(static_cast<size_t>(len) > iov_[0].iov_len) {
/* the first region is fully sent; the excess came out of the second region */
iov_[1].iov_base = (uint8_t*) iov_[1].iov_base + (len - iov_[0].iov_len);
iov_[1].iov_len -= (len - iov_[0].iov_len);
if(iov_[0].iov_len) {
writeBuff_.RetrieveAll();
iov_[0].iov_len = 0;
}
}
else {
/* only (part of) the first region was sent */
iov_[0].iov_base = (uint8_t*)iov_[0].iov_base + len;
iov_[0].iov_len -= len;
writeBuff_.Retrieve(len);
}
} while(isET || ToWriteBytes() > 10240);
return len;
}
/*
 * Parses the request held in the read buffer and prepares the response.
 *
 * On return iov_[0] describes the response bytes in the write buffer and
 * iov_[1] describes the response file when there is one. iov_[1] is reset to
 * an empty region first, so ToWriteBytes() never counts a stale or
 * uninitialised length when the response carries no file.
 *
 * @return false when the read buffer is empty, true once a response
 *         (success or error) has been prepared.
 */
bool HttpConn::process() {
    request_.Init();
    if(readBuff_.ReadableBytes() <= 0) {
        return false;
    }
    else if(request_.parse(readBuff_)) {
        LOG_DEBUG("%s", request_.path().c_str());
        response_.Init(srcDir, request_.path(), request_.IsKeepAlive(), 200);
    } else {
        response_.Init(srcDir, request_.path(), false, 400);
    }

    response_.MakeResponse(writeBuff_);
    /* response header */
    iov_[0].iov_base = const_cast<char*>(writeBuff_.Peek());
    iov_[0].iov_len = writeBuff_.ReadableBytes();
    iovCnt_ = 1;

    /* file: start from an empty second region, fill it only when a file exists */
    iov_[1].iov_base = nullptr;
    iov_[1].iov_len = 0;
    if(response_.FileLen() > 0 && response_.File()) {
        iov_[1].iov_base = response_.File();
        iov_[1].iov_len = response_.FileLen();
        iovCnt_ = 2;
    }
    /* FileLen() is size_t; cast so the argument matches the %d conversion */
    LOG_DEBUG("filesize:%d, %d to %d", static_cast<int>(response_.FileLen()), iovCnt_, ToWriteBytes());
    return true;
}
httprequest.h
/*
 * Parser and holder for one HTTP request: request line, headers, body and
 * the key/value pairs of a url-encoded POST body.
 */
class HttpRequest {
public:
/* Position of the parser within the request. */
enum PARSE_STATE {
REQUEST_LINE,
HEADERS,
BODY,
FINISH,
};
enum HTTP_CODE {
NO_REQUEST = 0,
GET_REQUEST,
BAD_REQUEST,
NO_RESOURSE,
FORBIDDENT_REQUEST,
FILE_REQUEST,
INTERNAL_ERROR,
CLOSED_CONNECTION,
};
HttpRequest() { Init(); }
~HttpRequest() = default;
/* Resets all parsed fields and restarts at REQUEST_LINE. */
void Init();
/* Parses the readable bytes of `buff`; false for an empty buffer or a bad request line. */
bool parse(Buffer& buff);
std::string path() const;
std::string& path();
std::string method() const;
std::string version() const;
/* Value of a POST field, or an empty string when the field is absent. */
std::string GetPost(const std::string& key) const;
std::string GetPost(const char* key) const;
bool IsKeepAlive() const;
/*
todo
void HttpConn::ParseFormData() {}
void HttpConn::ParseJson() {}
*/
private:
bool ParseRequestLine_(const std::string& line);
void ParseHeader_(const std::string& line);
void ParseBody_(const std::string& line);
void ParsePath_();
void ParsePost_();
void ParseFromUrlencoded_();
static bool UserVerify(const std::string& name, const std::string& pwd, bool isLogin);
PARSE_STATE state_;
std::string method_, path_, version_, body_;
std::unordered_map<std::string, std::string> header_;
std::unordered_map<std::string, std::string> post_;
static const std::unordered_set<std::string> DEFAULT_HTML;
static const std::unordered_map<std::string, int> DEFAULT_HTML_TAG;
static int ConverHex(char ch);
};
httprequest.cpp
#include "httprequest.h"
using namespace std;
/* Paths that ParsePath_() completes with a ".html" suffix. */
const unordered_set<string> HttpRequest::DEFAULT_HTML{
"/index", "/register", "/login",
"/welcome", "/video", "/picture", };
/* Pages whose POST is handed to UserVerify(): tag 0 = register, tag 1 = login. */
const unordered_map<string, int> HttpRequest::DEFAULT_HTML_TAG {
{"/register.html", 0}, {"/login.html", 1}, };
/* Clears every parsed field and puts the parser back at the request line. */
void HttpRequest::Init() {
    method_.clear();
    path_.clear();
    version_.clear();
    body_.clear();
    header_.clear();
    post_.clear();
    state_ = REQUEST_LINE;
}
/* True only for an HTTP 1.1 request whose "Connection" header equals "keep-alive". */
bool HttpRequest::IsKeepAlive() const {
    const auto connection = header_.find("Connection");
    if(connection == header_.end()) {
        return false;
    }
    return connection->second == "keep-alive" && version_ == "1.1";
}
/*
 * Consumes the readable bytes of `buff` line by line (lines end in CRLF) and
 * feeds each line to the handler of the current state.
 * Returns false when the buffer is empty or the request line is malformed,
 * true otherwise.
 */
bool HttpRequest::parse(Buffer& buff) {
const char CRLF[] = "\r\n";
if(buff.ReadableBytes() <= 0) {
return false;
}
while(buff.ReadableBytes() && state_ != FINISH) {
/* lineEnd points at the next CRLF, or at the write position when none is left */
const char* lineEnd = search(buff.Peek(), buff.BeginWriteConst(), CRLF, CRLF + 2);
std::string line(buff.Peek(), lineEnd);
switch(state_)
{
case REQUEST_LINE:
if(!ParseRequestLine_(line)) {
return false;
}
ParsePath_();
break;
case HEADERS:
ParseHeader_(line);
/* at most a CRLF is left in the buffer: nothing follows the headers */
if(buff.ReadableBytes() <= 2) {
state_ = FINISH;
}
break;
case BODY:
ParseBody_(line);
break;
default:
break;
}
/* no CRLF was found: the rest of the buffer has been handled as one line */
if(lineEnd == buff.BeginWrite()) { break; }
/* drop the line together with its CRLF */
buff.RetrieveUntil(lineEnd + 2);
}
LOG_DEBUG("[%s], [%s], [%s]", method_.c_str(), path_.c_str(), version_.c_str());
return true;
}
/* Maps "/" to "/index.html" and appends ".html" to the paths listed in DEFAULT_HTML. */
void HttpRequest::ParsePath_() {
    if(path_ == "/") {
        path_ = "/index.html";
        return;
    }
    if(DEFAULT_HTML.count(path_) != 0) {
        path_ += ".html";
    }
}
bool HttpRequest::ParseRequestLine_(const string& line) {
regex patten("^([^ ]*) ([^ ]*) HTTP/([^ ]*)$");
smatch subMatch;
if(regex_match(line, subMatch, patten)) {
method_ = subMatch[1];
path_ = subMatch[2];
version_ = subMatch[3];
state_ = HEADERS;
return true;
}
LOG_ERROR("RequestLine Error");
return false;
}
void HttpRequest::ParseHeader_(const string& line) {
regex patten("^([^:]*): ?(.*)$");
smatch subMatch;
if(regex_match(line, subMatch, patten)) {
header_[subMatch[1]] = subMatch[2];
}
else {
state_ = BODY;
}
}
/* Keeps `line` as the request body, processes it as POST data and finishes parsing. */
void HttpRequest::ParseBody_(const string& line) {
    body_ = line;
    ParsePost_();
    state_ = FINISH;
    LOG_DEBUG("Body:%s, len:%d", line.c_str(), line.size());
}
/*
 * Converts one hexadecimal digit to its numeric value.
 *
 * '0'..'9' map to 0..9 and 'A'..'F' / 'a'..'f' map to 10..15. Decimal digits
 * need their own branch: without it they fall through and come back as their
 * character code (e.g. '1' -> 49). Any other character is returned unchanged,
 * as before.
 */
int HttpRequest::ConverHex(char ch) {
    if(ch >= '0' && ch <= '9') return ch - '0';
    if(ch >= 'A' && ch <= 'F') return ch - 'A' + 10;
    if(ch >= 'a' && ch <= 'f') return ch - 'a' + 10;
    return ch;
}
/*
 * Handles a url-encoded POST body: decodes the form fields and, for the
 * pages listed in DEFAULT_HTML_TAG, verifies the user and redirects path_ to
 * the welcome or the error page.
 */
void HttpRequest::ParsePost_() {
    if(method_ != "POST" || header_["Content-Type"] != "application/x-www-form-urlencoded") {
        return;
    }
    ParseFromUrlencoded_();

    const auto tagged = DEFAULT_HTML_TAG.find(path_);
    if(tagged == DEFAULT_HTML_TAG.end()) {
        return;
    }
    const int tag = tagged->second;
    LOG_DEBUG("Tag:%d", tag);
    if(tag != 0 && tag != 1) {
        return;
    }
    const bool isLogin = (tag == 1);
    const bool verified = UserVerify(post_["username"], post_["password"], isLogin);
    path_ = verified ? "/welcome.html" : "/error.html";
}
/*
 * Decodes an application/x-www-form-urlencoded body into post_.
 *
 * The body is split on '&' into pairs and each pair on its first '=' into
 * key and value. Both parts are decoded: '+' becomes a space and "%XY" (two
 * hex digits) becomes the byte with that value. A '%' that is not followed by
 * two hex digits is kept literally, so no read goes past the end of the text.
 * A pair without '=' is stored with an empty value. Later occurrences of a
 * key overwrite earlier ones.
 */
void HttpRequest::ParseFromUrlencoded_() {
    if(body_.empty()) { return; }

    /* numeric value of a hex digit, or -1 when `c` is not one */
    const auto hexValue = [](char c) -> int {
        if(c >= '0' && c <= '9') return c - '0';
        if(c >= 'A' && c <= 'F') return c - 'A' + 10;
        if(c >= 'a' && c <= 'f') return c - 'a' + 10;
        return -1;
    };
    /* decodes one key or value */
    const auto decode = [&hexValue](const std::string& raw) {
        std::string out;
        out.reserve(raw.size());
        for(size_t k = 0; k < raw.size(); ++k) {
            const char c = raw[k];
            if(c == '+') {
                out.push_back(' ');
            } else if(c == '%' && k + 2 < raw.size() + 0 && hexValue(raw[k + 1]) >= 0 && hexValue(raw[k + 2]) >= 0) {
                out.push_back(static_cast<char>(hexValue(raw[k + 1]) * 16 + hexValue(raw[k + 2])));
                k += 2;
            } else {
                out.push_back(c);
            }
        }
        return out;
    };

    size_t begin = 0;
    while(begin < body_.size()) {
        size_t end = body_.find('&', begin);
        if(end == std::string::npos) { end = body_.size(); }
        if(end > begin) {
            const std::string pair = body_.substr(begin, end - begin);
            const size_t eq = pair.find('=');
            const std::string key = decode(pair.substr(0, eq));
            const std::string value = (eq == std::string::npos) ? std::string() : decode(pair.substr(eq + 1));
            post_[key] = value;
            LOG_DEBUG("%s = %s", key.c_str(), value.c_str());
        }
        begin = end + 1;
    }
}
bool HttpRequest::UserVerify(const string &name, const string &pwd, bool isLogin) {
if(name == "" || pwd == "") { return false; }
LOG_INFO("Verify name:%s pwd:%s", name.c_str(), pwd.c_str());
MYSQL* sql;
SqlConnRAII(&sql, SqlConnPool::Instance());
assert(sql);
bool flag = false;
unsigned int j = 0;
char order[256] = { 0 };
MYSQL_FIELD *fields = nullptr;
MYSQL_RES *res = nullptr;
if(!isLogin) { flag = true; }
/* 查询用户及密码 */
snprintf(order, 256, "SELECT username, password FROM user WHERE username='%s' LIMIT 1", name.c_str());
LOG_DEBUG("%s", order);
if(mysql_query(sql, order)) {
mysql_free_result(res);
return false;
}
res = mysql_store_result(sql);
j = mysql_num_fields(res);
fields = mysql_fetch_fields(res);
while(MYSQL_ROW row = mysql_fetch_row(res)) {
LOG_DEBUG("MYSQL ROW: %s %s", row[0], row[1]);
string password(row[1]);
/* 注册行为 且 用户名未被使用*/
if(isLogin) {
if(pwd == password) { flag = true; }
else {
flag = false;
LOG_DEBUG("pwd error!");
}
}
else {
flag = false;
LOG_DEBUG("user used!");
}
}
mysql_free_result(res);
/* 注册行为 且 用户名未被使用*/
if(!isLogin && flag == true) {
LOG_DEBUG("regirster!");
bzero(order, 256);
snprintf(order, 256,"INSERT INTO user(username, password) VALUES('%s','%s')", name.c_str(), pwd.c_str());
LOG_DEBUG( "%s", order);
if(mysql_query(sql, order)) {
LOG_DEBUG( "Insert error!");
flag = false;
}
flag = true;
}
SqlConnPool::Instance()->FreeConn(sql);
LOG_DEBUG( "UserVerify success!!");
return flag;
}
std::string HttpRequest::path() const{
return path_;
}
std::string& HttpRequest::path(){
return path_;
}
std::string HttpRequest::method() const {
return method_;
}
std::string HttpRequest::version() const {
return version_;
}
/* Value of POST field `key` (must be non-empty), or "" when the field is absent. */
std::string HttpRequest::GetPost(const std::string& key) const {
    assert(key != "");
    const auto field = post_.find(key);
    if(field == post_.end()) {
        return "";
    }
    return field->second;
}
/* Value of POST field `key` (must be non-null), or "" when the field is absent. */
std::string HttpRequest::GetPost(const char* key) const {
    assert(key != nullptr);
    const auto field = post_.find(key);
    if(field == post_.end()) {
        return "";
    }
    return field->second;
}
httpresponse.h
/*
 * Builds an HTTP response: status line and headers are appended to a Buffer,
 * the requested file is memory-mapped and exposed through File()/FileLen().
 */
class HttpResponse {
public:
HttpResponse();
~HttpResponse();
/* Prepares a new response for `path` below `srcDir`; an existing mapping is released first. */
void Init(const std::string& srcDir, std::string& path, bool isKeepAlive = false, int code = -1);
/* Appends status line, headers and content information to `buff`. */
void MakeResponse(Buffer& buff);
void UnmapFile();
/* Start of the mapped file (nullptr when nothing is mapped) and the size recorded by stat(). */
char* File();
size_t FileLen() const;
/* Appends a small HTML error page carrying `message`. */
void ErrorContent(Buffer& buff, std::string message);
int Code() const { return code_; }
private:
void AddStateLine_(Buffer &buff);
void AddHeader_(Buffer &buff);
void AddContent_(Buffer &buff);
void ErrorHtml_();
std::string GetFileType_();
int code_;
bool isKeepAlive_;
std::string path_;
std::string srcDir_;
char* mmFile_;
struct stat mmFileStat_;
/* Lookup tables: file suffix -> MIME type, status code -> reason phrase / error page. */
static const std::unordered_map<std::string, std::string> SUFFIX_TYPE;
static const std::unordered_map<int, std::string> CODE_STATUS;
static const std::unordered_map<int, std::string> CODE_PATH;
};
httpresponse.cpp
#include "httpresponse.h"
using namespace std;
/* File suffix -> Content-type value; suffixes not listed here are served as text/plain. */
const unordered_map<string, string> HttpResponse::SUFFIX_TYPE = {
{ ".html", "text/html" },
{ ".xml", "text/xml" },
{ ".xhtml", "application/xhtml+xml" },
{ ".txt", "text/plain" },
};
/* Status code -> reason phrase used in the status line. */
const unordered_map<int, string> HttpResponse::CODE_STATUS = {
{ 200, "OK" },
{ 400, "Bad Request" },
{ 403, "Forbidden" },
{ 404, "Not Found" },
};
/* Status code -> page served instead of the requested file. */
const unordered_map<int, string> HttpResponse::CODE_PATH = {
{ 400, "/400.html" },
{ 403, "/403.html" },
{ 404, "/404.html" },
};
/* Starts with no status code, empty paths, keep-alive off and no file mapped. */
HttpResponse::HttpResponse() {
    code_ = -1;
    isKeepAlive_ = false;
    srcDir_ = "";
    path_ = srcDir_;
    mmFile_ = nullptr;
    mmFileStat_ = { 0 };
}
/* Releases the file mapping, if any. */
HttpResponse::~HttpResponse() {
    this->UnmapFile();
}
/*
 * Resets the object for a new response: releases a previous mapping, then
 * stores the resource directory, path, keep-alive flag and initial code.
 */
void HttpResponse::Init(const string& srcDir, string& path, bool isKeepAlive, int code){
    assert(srcDir != "");
    if(mmFile_) {
        UnmapFile();
    }
    srcDir_ = srcDir;
    path_ = path;
    isKeepAlive_ = isKeepAlive;
    code_ = code;
    mmFile_ = nullptr;
    mmFileStat_ = { 0 };
}
/*
 * Decides the status code from the requested file (404 when stat() fails or
 * the path is a directory, 403 when it is not readable by others, 200 when
 * no code was set), then writes status line, headers and content to `buff`.
 */
void HttpResponse::MakeResponse(Buffer& buff) {
    /* inspect the requested resource */
    const std::string target = srcDir_ + path_;
    const bool missing = stat(target.data(), &mmFileStat_) < 0 || S_ISDIR(mmFileStat_.st_mode);
    if(missing) {
        code_ = 404;
    }
    else if(!(mmFileStat_.st_mode & S_IROTH)) {
        code_ = 403;
    }
    else if(code_ == -1) {
        code_ = 200;
    }

    ErrorHtml_();
    AddStateLine_(buff);
    AddHeader_(buff);
    AddContent_(buff);
}
char* HttpResponse::File() {
return mmFile_;
}
size_t HttpResponse::FileLen() const {
return mmFileStat_.st_size;
}
/* For codes listed in CODE_PATH, switches path_ to the error page and stats that page. */
void HttpResponse::ErrorHtml_() {
    const auto page = CODE_PATH.find(code_);
    if(page == CODE_PATH.end()) {
        return;
    }
    path_ = page->second;
    stat((srcDir_ + path_).data(), &mmFileStat_);
}
/* Appends "HTTP/1.1 <code> <reason>\r\n"; a code without a known reason is replaced by 400. */
void HttpResponse::AddStateLine_(Buffer& buff) {
    auto entry = CODE_STATUS.find(code_);
    if(entry == CODE_STATUS.end()) {
        code_ = 400;
        entry = CODE_STATUS.find(400);
    }
    const std::string status = entry->second;
    buff.Append("HTTP/1.1 " + to_string(code_) + " " + status + "\r\n");
}
/* Appends the Connection header (plus keep-alive parameters when enabled) and the Content-type header. */
void HttpResponse::AddHeader_(Buffer& buff) {
    buff.Append("Connection: ");
    if(!isKeepAlive_) {
        buff.Append("close\r\n");
    } else {
        buff.Append("keep-alive\r\n");
        buff.Append("keep-alive: max=6, timeout=120\r\n");
    }
    const std::string mimeType = GetFileType_();
    buff.Append("Content-type: " + mimeType + "\r\n");
}
/*
 * Maps the response file into memory and appends the Content-length header.
 *
 * When the file cannot be opened or mapped an HTML error body is written to
 * `buff` instead. The mmap() result is compared with MAP_FAILED (the value
 * mmap returns on error) rather than dereferenced, and the descriptor is
 * closed on every path because the mapping does not need it any more.
 */
void HttpResponse::AddContent_(Buffer& buff) {
    const std::string filePath = srcDir_ + path_;
    const int srcFd = open(filePath.data(), O_RDONLY);
    if(srcFd < 0) {
        ErrorContent(buff, "File NotFound!");
        return;
    }

    /* Map the file into memory to speed up access;
       MAP_PRIVATE creates a private copy-on-write mapping. */
    LOG_DEBUG("file path %s", filePath.data());
    void* mapped = mmap(nullptr, mmFileStat_.st_size, PROT_READ, MAP_PRIVATE, srcFd, 0);
    close(srcFd);
    if(mapped == MAP_FAILED) {
        ErrorContent(buff, "File NotFound!");
        return;
    }
    mmFile_ = static_cast<char*>(mapped);
    buff.Append("Content-length: " + std::to_string(mmFileStat_.st_size) + "\r\n\r\n");
}
/* Unmaps the response file when one is mapped and clears the pointer. */
void HttpResponse::UnmapFile() {
    if(mmFile_ == nullptr) {
        return;
    }
    munmap(mmFile_, mmFileStat_.st_size);
    mmFile_ = nullptr;
}
/* Content type for path_, chosen by the text from its last '.'; "text/plain" when unknown. */
string HttpResponse::GetFileType_() {
    const string::size_type dot = path_.find_last_of('.');
    if(dot != string::npos) {
        const auto known = SUFFIX_TYPE.find(path_.substr(dot));
        if(known != SUFFIX_TYPE.end()) {
            return known->second;
        }
    }
    return "text/plain";
}
/*
 * Appends a Content-length header followed by a small HTML page that shows
 * the status code, its reason phrase ("Bad Request" for unknown codes) and
 * `message`.
 */
void HttpResponse::ErrorContent(Buffer& buff, string message)
{
    const auto known = CODE_STATUS.find(code_);
    const string status = (known != CODE_STATUS.end()) ? known->second : string("Bad Request");

    string body = "<html><title>Error</title>";
    body += "<body bgcolor=\"ffffff\">";
    body += to_string(code_) + " : " + status + "\n";
    body += "<p>" + message + "</p>";
    body += "<hr><em>TinyWebServer</em></body></html>";

    buff.Append("Content-length: " + to_string(body.size()) + "\r\n\r\n");
    buff.Append(body);
}
threadpool.h
/*
 * Fixed-size pool of detached worker threads that execute queued tasks.
 *
 * The shared state (mutex, condition variable, task queue, closed flag) lives
 * in a Pool object owned through a shared_ptr; every worker holds its own
 * reference, so the state stays alive until the last worker has left its loop.
 */
class ThreadPool {
public:
    /*
     * Starts `threadCount` workers (must be > 0). Because the argument has a
     * default, this is also the default constructor; no separate defaulted
     * default constructor is declared, as that would make `ThreadPool p;`
     * ambiguous.
     */
    explicit ThreadPool(size_t threadCount = 8): pool_(std::make_shared<Pool>()) {
        assert(threadCount > 0);
        for(size_t i = 0; i < threadCount; i++) {
            std::thread([pool = pool_] {
                std::unique_lock<std::mutex> locker(pool->mtx);
                while(true) {
                    if(!pool->tasks.empty()) {
                        auto task = std::move(pool->tasks.front());
                        pool->tasks.pop();
                        /* run the task without holding the lock */
                        locker.unlock();
                        task();
                        locker.lock();
                    }
                    else if(pool->isClosed) break;
                    else pool->cond.wait(locker);
                }
            }).detach();
        }
    }

    ThreadPool(ThreadPool&&) = default;

    /* Marks the pool closed and wakes every worker; workers exit once the queue is empty. */
    ~ThreadPool() {
        if(static_cast<bool>(pool_)) {
            {
                std::lock_guard<std::mutex> locker(pool_->mtx);
                pool_->isClosed = true;
            }
            pool_->cond.notify_all();
        }
    }

    /* Queues `task` (any callable taking no arguments) and wakes one worker. */
    template<class F>
    void AddTask(F&& task) {
        {
            std::lock_guard<std::mutex> locker(pool_->mtx);
            pool_->tasks.emplace(std::forward<F>(task));
        }
        pool_->cond.notify_one();
    }

private:
    struct Pool {
        std::mutex mtx;
        std::condition_variable cond;
        bool isClosed = false;
        std::queue<std::function<void()>> tasks;
    };
    std::shared_ptr<Pool> pool_;
};
sqlconnRAII.h
/*
 * Scope guard for a pooled MySQL connection: the constructor takes a
 * connection from `connpool` and stores it in `*sql`; the destructor gives a
 * non-null connection back to the same pool.
 *
 * The guard must be a named object that lives as long as the connection is
 * used. Copying is disabled because two guards for the same connection would
 * return it to the pool twice.
 */
class SqlConnRAII {
public:
    SqlConnRAII(MYSQL** sql, SqlConnPool *connpool) {
        assert(sql);
        assert(connpool);
        *sql = connpool->GetConn();
        sql_ = *sql;
        connpool_ = connpool;
    }

    ~SqlConnRAII() {
        if(sql_) { connpool_->FreeConn(sql_); }
    }

    SqlConnRAII(const SqlConnRAII&) = delete;
    SqlConnRAII& operator=(const SqlConnRAII&) = delete;

private:
    MYSQL *sql_;
    SqlConnPool* connpool_;
};
sqlconnpool.h
/*
 * Process-wide pool of MySQL connections. Idle connections are kept in a
 * queue protected by a mutex; a semaphore is posted and waited on alongside
 * the queue operations.
 */
class SqlConnPool {
public:
/* Access to the single pool instance. */
static SqlConnPool *Instance();
/* Takes an idle connection; nullptr when the queue is empty. */
MYSQL *GetConn();
/* Puts a connection back into the queue. */
void FreeConn(MYSQL * conn);
int GetFreeConnCount();
/* Opens connSize connections to the given server and database. */
void Init(const char* host, int port,
const char* user,const char* pwd,
const char* dbName, int connSize);
/* Closes every idle connection. */
void ClosePool();
private:
SqlConnPool();
~SqlConnPool();
int MAX_CONN_;
int useCount_;
int freeCount_;
std::queue<MYSQL *> connQue_;
std::mutex mtx_;
sem_t semId_;
};
sqlconnpool.cpp
#include "sqlconnpool.h"
using namespace std;
/* Starts with both usage counters at zero. */
SqlConnPool::SqlConnPool() {
    freeCount_ = 0;
    useCount_ = 0;
}
/* Returns the address of the function-local static pool object. */
SqlConnPool* SqlConnPool::Instance() {
    static SqlConnPool instance;
    SqlConnPool* const self = &instance;
    return self;
}
void SqlConnPool::Init(const char* host, int port,
const char* user,const char* pwd, const char* dbName,
int connSize = 10) {
assert(connSize > 0);
for (int i = 0; i < connSize; i++) {
MYSQL *sql = nullptr;
sql = mysql_init(sql);
if (!sql) {
LOG_ERROR("MySql init error!");
assert(sql);
}
sql = mysql_real_connect(sql, host,
user, pwd,
dbName, port, nullptr, 0);
if (!sql) {
LOG_ERROR("MySql Connect error!");
}
connQue_.push(sql);
}
MAX_CONN_ = connSize;
sem_init(&semId_, 0, MAX_CONN_);
}
/*
 * Takes one idle connection out of the pool.
 *
 * The emptiness check, the semaphore decrement and the pop all happen under
 * the mutex, so the queue is never inspected while another thread modifies
 * it. Every push into the queue is paired with a sem_post under the same
 * mutex, so the semaphore count is at least the queue size and sem_wait does
 * not block here.
 *
 * @return a connection, or nullptr (after a warning) when none is idle.
 */
MYSQL* SqlConnPool::GetConn() {
    lock_guard<mutex> locker(mtx_);
    if(connQue_.empty()){
        LOG_WARN("SqlConnPool busy!");
        return nullptr;
    }
    sem_wait(&semId_);
    MYSQL* sql = connQue_.front();
    connQue_.pop();
    return sql;
}
/* Returns `sql` (must be non-null) to the idle queue and posts the semaphore, both under the mutex. */
void SqlConnPool::FreeConn(MYSQL* sql) {
    assert(sql);
    {
        lock_guard<mutex> guard(mtx_);
        connQue_.push(sql);
        sem_post(&semId_);
    }
}
void SqlConnPool::ClosePool() {
lock_guard<mutex> locker(mtx_);
while(!connQue_.empty()) {
auto item = connQue_.front();
connQue_.pop();
mysql_close(item);
}
mysql_library_end();
}
int SqlConnPool::GetFreeConnCount() {
lock_guard<mutex> locker(mtx_);
return connQue_.size();
}
/* Closes the pool on destruction. */
SqlConnPool::~SqlConnPool() {
    this->ClosePool();
}
webserver.h
/*
 * The server: owns the listening socket, the epoll wrapper, the timer heap,
 * the worker thread pool and the table of client connections keyed by fd.
 */
class WebServer {
public:
WebServer(
int port, int trigMode, int timeoutMS, bool OptLinger,
int sqlPort, const char* sqlUser, const char* sqlPwd,
const char* dbName, int connPoolNum, int threadNum,
bool openLog, int logLevel, int logQueSize);
~WebServer();
/* Runs the event loop until isClose_ is set. */
void Start();
private:
bool InitSocket_();
void InitEventMode_(int trigMode);
void AddClient_(int fd, sockaddr_in addr);
/* Handlers called from the event loop. */
void DealListen_();
void DealWrite_(HttpConn* client);
void DealRead_(HttpConn* client);
void SendError_(int fd, const char*info);
void ExtentTime_(HttpConn* client);
void CloseConn_(HttpConn* client);
/* Work items executed on the thread pool. */
void OnRead_(HttpConn* client);
void OnWrite_(HttpConn* client);
void OnProcess(HttpConn* client);
static const int MAX_FD = 65536;
static int SetFdNonblock(int fd);
int port_;
bool openLinger_;
int timeoutMS_; /* milliseconds */
bool isClose_;
int listenFd_;
char* srcDir_;
uint32_t listenEvent_;
uint32_t connEvent_;
std::unique_ptr<HeapTimer> timer_;
std::unique_ptr<ThreadPool> threadpool_;
std::unique_ptr<Epoller> epoller_;
std::unordered_map<int, HttpConn> users_;
};
webserver.cpp
#include "webserver.h"
using namespace std;
/*
 * Builds the server: resolves the resource directory, initialises the SQL
 * connection pool, the epoll event modes, the listening socket and
 * (optionally) the log.
 *
 * The resource directory is "<cwd>/resources/". The working directory is
 * obtained with getcwd(nullptr, 0), which allocates a buffer of the needed
 * size, and the buffer is then grown by exactly the suffix length, so
 * appending the suffix cannot write past the allocation however long the
 * working directory is. listenFd_ starts at -1 so the destructor never
 * closes an arbitrary descriptor when InitSocket_() fails before creating
 * the socket.
 */
WebServer::WebServer(
            int port, int trigMode, int timeoutMS, bool OptLinger,
            int sqlPort, const char* sqlUser, const  char* sqlPwd,
            const char* dbName, int connPoolNum, int threadNum,
            bool openLog, int logLevel, int logQueSize):
            port_(port), openLinger_(OptLinger), timeoutMS_(timeoutMS), isClose_(false),
            listenFd_(-1), srcDir_(nullptr),
            timer_(new HeapTimer()), threadpool_(new ThreadPool(threadNum)), epoller_(new Epoller())
    {
    static const char kResourceDir[] = "/resources/";
    char* cwd = getcwd(nullptr, 0);
    assert(cwd);
    const size_t cwdLen = strlen(cwd);
    /* sizeof(kResourceDir) includes the terminating NUL */
    srcDir_ = static_cast<char*>(realloc(cwd, cwdLen + sizeof(kResourceDir)));
    assert(srcDir_);
    memcpy(srcDir_ + cwdLen, kResourceDir, sizeof(kResourceDir));

    HttpConn::userCount = 0;
    HttpConn::srcDir = srcDir_;
    SqlConnPool::Instance()->Init("localhost", sqlPort, sqlUser, sqlPwd, dbName, connPoolNum);

    InitEventMode_(trigMode);
    if(!InitSocket_()) { isClose_ = true;}

    if(openLog) {
        Log::Instance()->init(logLevel, "./log", ".log", logQueSize);
        if(isClose_) { LOG_ERROR("========== Server init error!=========="); }
        else {
            LOG_INFO("========== Server init ==========");
            LOG_INFO("Port:%d, OpenLinger: %s", port_, OptLinger? "true":"false");
            LOG_INFO("Listen Mode: %s, OpenConn Mode: %s",
                            (listenEvent_ & EPOLLET ? "ET": "LT"),
                            (connEvent_ & EPOLLET ? "ET": "LT"));
            LOG_INFO("LogSys level: %d", logLevel);
            LOG_INFO("srcDir: %s", HttpConn::srcDir);
            LOG_INFO("SqlConnPool num: %d, ThreadPool num: %d", connPoolNum, threadNum);
        }
    }
}
/* Closes the listening socket, frees the resource path and shuts the SQL pool down. */
WebServer::~WebServer() {
    close(listenFd_);
    isClose_ = true;
    free(srcDir_);

    SqlConnPool* const sqlPool = SqlConnPool::Instance();
    sqlPool->ClosePool();
}
/*
 * Chooses the epoll flags for the listening socket and for client sockets.
 * trigMode 0: neither is edge-triggered; 1: clients only; 2: listener only;
 * 3 or any other value: both.
 */
void WebServer::InitEventMode_(int trigMode) {
    listenEvent_ = EPOLLRDHUP;
    connEvent_ = EPOLLONESHOT | EPOLLRDHUP;

    const bool listenEdge = (trigMode != 0 && trigMode != 1);
    const bool connEdge = (trigMode != 0 && trigMode != 2);
    if(listenEdge) {
        listenEvent_ |= EPOLLET;
    }
    if(connEdge) {
        connEvent_ |= EPOLLET;
    }
    HttpConn::isET = (connEvent_ & EPOLLET);
}
void WebServer::Start() {
int timeMS = -1; /* epoll wait timeout == -1 无事件将阻塞 */
if(!isClose_) { LOG_INFO("========== Server start =========="); }
while(!isClose_) {
if(timeoutMS_ > 0) {
timeMS = timer_->GetNextTick();
}
int eventCnt = epoller_->Wait(timeMS);
for(int i = 0; i < eventCnt; i++) {
/* 处理事件 */
int fd = epoller_->GetEventFd(i);
uint32_t events = epoller_->GetEvents(i);
if(fd == listenFd_) {
DealListen_();
}
else if(events & (EPOLLRDHUP | EPOLLHUP | EPOLLERR)) {
assert(users_.count(fd) > 0);
CloseConn_(&users_[fd]);
}
else if(events & EPOLLIN) {
assert(users_.count(fd) > 0);
DealRead_(&users_[fd]);
}
else if(events & EPOLLOUT) {
assert(users_.count(fd) > 0);
DealWrite_(&users_[fd]);
} else {
LOG_ERROR("Unexpected event");
}
}
}
}
/* Sends the text `info` to `fd` (logging a warning if send fails) and closes the descriptor. */
void WebServer::SendError_(int fd, const char*info) {
    assert(fd > 0);
    const int sent = send(fd, info, strlen(info), 0);
    if(sent < 0) {
        LOG_WARN("send error to client[%d] error!", fd);
    }
    close(fd);
}
/* Removes the client's descriptor from epoll and closes the connection. */
void WebServer::CloseConn_(HttpConn* client) {
    assert(client);
    const int fd = client->GetFd();
    LOG_INFO("Client[%d] quit!", fd);
    epoller_->DelFd(fd);
    client->Close();
}
void WebServer::AddClient_(int fd, sockaddr_in addr) {
assert(fd > 0);
users_[fd].init(fd, addr);
if(timeoutMS_ > 0) {
timer_->add(fd, timeoutMS_, std::bind(&WebServer::CloseConn_, this, &users_[fd]));
}
epoller_->AddFd(fd, EPOLLIN | connEvent_);
SetFdNonblock(fd);
LOG_INFO("Client[%d] in!", users_[fd].GetFd());
}
/*
 * Accepts pending connections. One connection is accepted per call unless
 * the listening socket is edge-triggered, in which case accepting continues
 * until accept() stops returning a descriptor. When the user count has
 * reached MAX_FD the new socket is told "Server busy!" and closed.
 */
void WebServer::DealListen_() {
    struct sockaddr_in peer;
    socklen_t peerLen = sizeof(peer);
    for(;;) {
        const int fd = accept(listenFd_, (struct sockaddr *)&peer, &peerLen);
        if(fd <= 0) {
            return;
        }
        if(HttpConn::userCount >= MAX_FD) {
            SendError_(fd, "Server busy!");
            LOG_WARN("Clients is full!");
            return;
        }
        AddClient_(fd, peer);
        if(!(listenEvent_ & EPOLLET)) {
            break;
        }
    }
}
/* Refreshes the client's timer and queues the read handler on the thread pool. */
void WebServer::DealRead_(HttpConn* client) {
    assert(client);
    ExtentTime_(client);
    threadpool_->AddTask([this, client] { OnRead_(client); });
}
/* Refreshes the client's timer and queues the write handler on the thread pool. */
void WebServer::DealWrite_(HttpConn* client) {
    assert(client);
    ExtentTime_(client);
    threadpool_->AddTask([this, client] { OnWrite_(client); });
}
/* Re-arms the client's timer to the full timeout; does nothing when timeouts are disabled. */
void WebServer::ExtentTime_(HttpConn* client) {
    assert(client);
    if(timeoutMS_ <= 0) {
        return;
    }
    timer_->adjust(client->GetFd(), timeoutMS_);
}
/*
 * Thread-pool task: reads from the client. A result <= 0 with an errno other
 * than EAGAIN closes the connection; otherwise the request is processed.
 */
void WebServer::OnRead_(HttpConn* client) {
    assert(client);
    int readErrno = 0;
    const int got = client->read(&readErrno);
    if(got <= 0 && readErrno != EAGAIN) {
        CloseConn_(client);
        return;
    }
    OnProcess(client);
}
/* Processes the buffered request and re-arms the descriptor: for writing when a response is ready, for reading otherwise. */
void WebServer::OnProcess(HttpConn* client) {
    const bool responseReady = client->process();
    const uint32_t interest = responseReady ? EPOLLOUT : EPOLLIN;
    epoller_->ModFd(client->GetFd(), connEvent_ | interest);
}
/*
 * Thread-pool task: writes the pending response.
 *  - everything sent and the request asked for keep-alive: process again;
 *  - not everything sent, write failed with EAGAIN: wait for writability;
 *  - anything else: close the connection.
 */
void WebServer::OnWrite_(HttpConn* client) {
    assert(client);
    int writeErrno = 0;
    const int sent = client->write(&writeErrno);
    if(client->ToWriteBytes() == 0) {
        /* transfer complete */
        if(client->IsKeepAlive()) {
            OnProcess(client);
            return;
        }
    }
    else if(sent < 0 && writeErrno == EAGAIN) {
        /* continue the transfer later */
        epoller_->ModFd(client->GetFd(), connEvent_ | EPOLLOUT);
        return;
    }
    CloseConn_(client);
}
/*
 * Creates the listening socket: validates the port (1024..65535), creates a
 * TCP socket, applies SO_LINGER and SO_REUSEADDR, binds to INADDR_ANY:port_,
 * listens, registers the socket with epoll and makes it non-blocking.
 * Returns false (after logging, and closing the socket if it was created) as
 * soon as one step fails.
 */
bool WebServer::InitSocket_() {
int ret;
struct sockaddr_in addr;
if(port_ > 65535 || port_ < 1024) {
LOG_ERROR("Port:%d error!", port_);
return false;
}
addr.sin_family = AF_INET;
addr.sin_addr.s_addr = htonl(INADDR_ANY);
addr.sin_port = htons(port_);
struct linger optLinger = { 0 };
if(openLinger_) {
/* graceful close: linger until the remaining data is sent or the timeout expires */
optLinger.l_onoff = 1;
optLinger.l_linger = 1;
}
listenFd_ = socket(AF_INET, SOCK_STREAM, 0);
if(listenFd_ < 0) {
LOG_ERROR("Create socket error!", port_);
return false;
}
ret = setsockopt(listenFd_, SOL_SOCKET, SO_LINGER, &optLinger, sizeof(optLinger));
if(ret < 0) {
close(listenFd_);
LOG_ERROR("Init linger error!", port_);
return false;
}
int optval = 1;
/* allow the address to be reused */
/* only the last socket bound will receive data normally */
ret = setsockopt(listenFd_, SOL_SOCKET, SO_REUSEADDR, (const void*)&optval, sizeof(int));
if(ret == -1) {
LOG_ERROR("set socket setsockopt error !");
close(listenFd_);
return false;
}
ret = bind(listenFd_, (struct sockaddr *)&addr, sizeof(addr));
if(ret < 0) {
LOG_ERROR("Bind Port:%d error!", port_);
close(listenFd_);
return false;
}
ret = listen(listenFd_, 6);
if(ret < 0) {
LOG_ERROR("Listen port:%d error!", port_);
close(listenFd_);
return false;
}
/* AddFd's result is treated as a success flag: 0 means the registration failed */
ret = epoller_->AddFd(listenFd_, listenEvent_ | EPOLLIN);
if(ret == 0) {
LOG_ERROR("Add listen error!");
close(listenFd_);
return false;
}
SetFdNonblock(listenFd_);
LOG_INFO("Server port:%d", port_);
return true;
}
/*
 * Switches `fd` to non-blocking mode.
 *
 * The current file status flags are read with F_GETFL, the counterpart of
 * F_SETFL. F_GETFD returns the descriptor flags (FD_CLOEXEC) instead, and
 * feeding those to F_SETFL would discard the existing status flags.
 *
 * @return the result of fcntl(F_SETFL), or -1 when the flags cannot be read.
 */
int WebServer::SetFdNonblock(int fd) {
    assert(fd > 0);
    const int flags = fcntl(fd, F_GETFL, 0);
    if(flags < 0) {
        return -1;
    }
    return fcntl(fd, F_SETFL, flags | O_NONBLOCK);
}
main.cpp
/* Program entry: configures the server and runs its event loop. */
int main() {
    /* run as a background daemon */
    //daemon(1, 0);

    WebServer server(
        1316, 3, 60000, false,             /* port, ET mode, timeoutMs, graceful close */
        3306, "root", "root", "webserver", /* MySQL settings */
        12, 6, true, 1, 1024);             /* SQL pool size, thread count, log switch, log level, async log queue capacity */
    server.Start();
    return 0;
}