代码主要逻辑:
- 发送数据:
push
函数根据窗口大小和待发送数据的情况,发送数据段(包括处理初始的 SYN、payload 和 FIN)。 - 接收 ACK:
receive
函数处理从接收端接收到的 ACK,更新窗口大小、确认号等,并释放已确认的数据段。 - 重传逻辑:
tick
函数处理定时器,检查是否需要重传未确认段,并调整重传超时时间。 - 处理特殊情况:包括零窗口探测、处理初始 SYN、处理 FIN 段、调整重传超时等。
总的来说,这段代码实现了一个 TCP 协议发送端的核心功能,负责管理数据发送、重传、窗口大小调整等操作,确保数据能够可靠地传输到接收端。
tcp_sender.hh :
#pragma once
#include "byte_stream.hh"
#include "tcp_receiver_message.hh"
#include "tcp_sender_message.hh"
#include <cstdint>
#include <functional>
#include <list>
#include <memory>
#include <optional>
#include <queue>
#include <map>
class TCPSender
{
public:
/* 用给定的默认重传超时时间和可能的初始序列号构造TCP发送者 */
TCPSender( ByteStream&& input, Wrap32 isn, uint64_t initial_RTO_ms )
: input_( std::move( input ) ), isn_( isn ), initial_RTO_ms_( initial_RTO_ms ),raw_RTO_ms(initial_RTO_ms)
,currentSeqNum_(isn),last_Ack_Seq(isn)
,window_size_(2)
,unAckedSegments()
{}
/* 生成一个空的TCPSenderMessage */
TCPSenderMessage make_empty_message() const;
/* 接收并处理来自对端接收者的TCPReceiverMessage */
void receive( const TCPReceiverMessage& msg );
/* 定义`transmit`函数的类型,该函数用于push和tick方法发送消息 */
using TransmitFunction = std::function<void( const TCPSenderMessage& )>;
/* 从输出流中推送字节 */
void push( const TransmitFunction& transmit );
/* 自上次调用tick()方法以来,时间已经过去了指定的毫秒数 */
void tick( uint64_t ms_since_last_tick, const TransmitFunction& transmit );
// 访问器
uint64_t sequence_numbers_in_flight() const; // 当前有多少序列号未确认?
uint64_t consecutive_retransmissions() const; // 发生了多少次连续的重传?
Writer& writer() { return input_.writer(); }
const Writer& writer() const { return input_.writer(); }
// 只读访问输入流的读取器(外部不能读取)
const Reader& reader() const { return input_.reader(); }
private:
// push:
// “窗口探测”
void handleWindowProbe(const TransmitFunction& transmit);
// 处理SYN头
bool handleInitialSYN(TCPSenderMessage& message);
// 处理分段payload
void handlePayload(TCPSenderMessage& message);
// 处理分段序列号
void handleSqeno(TCPSenderMessage& message);
// 处理分段FIN
bool handleFIN(TCPSenderMessage& message);
// 重新设置RTO
void resetRTO();
// Receive:
// 处理返回的ACK
void processACK(const TCPReceiverMessage& msg);
// 处理已经ack数据分段
void handle_ack();
// Tick:
// TCP 使用指数退避策略来调整重传超时时间
void handle_RTO();
// 构造函数中初始化的变量
ByteStream input_; // 输入流
Wrap32 isn_; // 初始序列号
uint64_t initial_RTO_ms_; // 重传超时时间(毫秒)
uint64_t raw_RTO_ms; // 初始重传超时时间(毫秒
Wrap32 currentSeqNum_; // 当前发送数据分段的序列号
std::optional<Wrap32> last_Ack_Seq; // 上一个发送数据分段的序列号
uint16_t window_size_; // 当前接收方的窗口大小
// 记录未确认的分段
std::map<uint64_t,TCPSenderMessage> unAckedSegments;
bool is_SYN_ACK = false; // 记录窗口是否確定
uint64_t unAckedSegmentsNums = 0; // 当前待确认的字节数
uint64_t checkout = 0; // 当前已经ack的绝对序列号
uint64_t push_checkout = 0; // 当前已经push的绝对序列号
uint64_t since_last_send = 0; // 记录上次send的时间
bool is_RTO_double = false; // 记录非零窗口是否需要退避RTO(RTO增加)
bool isSYNSent_= false; // 判断是否发送过SYN
bool isFINSent_= false; // 判断是否发送过FIN
};
void print(TCPSenderMessage message);
tcp_sender.cc :
#include "tcp_sender.hh"
#include "tcp_config.hh"
#include<iostream>
using namespace std;
uint64_t TCPSender::sequence_numbers_in_flight() const
{
return unAckedSegmentsNums;
}
uint64_t TCPSender::consecutive_retransmissions() const
{
uint64_t exponent = 0;
uint64_t number = initial_RTO_ms_ / raw_RTO_ms;
while (number > 1) {
number /= 2;
exponent++;
}
return exponent + is_RTO_double;
}
TCPSenderMessage TCPSender::make_empty_message() const
{
TCPSenderMessage message;
message.FIN = false;
message.RST = input_.has_error();
message.SYN = false;
message.payload = "";
message.seqno = currentSeqNum_ ;
return message;
}
void TCPSender::push( const TransmitFunction& transmit )
{
uint64_t windowSize = window_size_==0?1:window_size_;
// 若ByteStream更新新字节,构造发送信息
uint64_t bytes_to_send = input_.reader().bytes_buffered(); // 总共需要发送的字节长度
uint64_t payload_len = min({input_.reader().bytes_buffered()
, static_cast<uint64_t>(TCPConfig::MAX_PAYLOAD_SIZE)
, static_cast<uint64_t>(windowSize) - sequence_numbers_in_flight()});
TCPSenderMessage message =make_empty_message();
// 处理SYN头
if(!handleInitialSYN(message)){
return;
}
do{
if(static_cast<uint64_t>(windowSize) - sequence_numbers_in_flight() == 0){
return ;
}
if(message.RST){
transmit(message);
return;
}
// 处理payload
handlePayload(message);
// 处理FIN
if(handleFIN(message)){
return ;
}
// 处理序列号
handleSqeno(message);
unAckedSegments[push_checkout] = message;
// 增加未确认的分段数量
push_checkout += payload_len;
// 更新分段信息
bytes_to_send -= payload_len;
// 更新
is_RTO_double = 0;
// 发送信息
transmit(message);
}while(bytes_to_send >0);
}
void TCPSender::receive( const TCPReceiverMessage& msg )
{
if (!msg.ackno.has_value()) {
if(!msg.window_size){
input_.set_error();
}
return;
}
if(msg.RST){
input_.set_error();
}
// 无效ACK
if(msg.ackno > currentSeqNum_ ){
return;
}
// 重复ACK
if(last_Ack_Seq.has_value()){
if(last_Ack_Seq >= msg.ackno && window_size_ == msg.window_size){
return;
}
}
is_SYN_ACK = true;
last_Ack_Seq = msg.ackno.value();
// 更新当前确定的接收绝对序列号
checkout = msg.ackno.value().unwrap(isn_,checkout);
// 释放已经缓冲区已经ack的数据
handle_ack();
// 更新窗口大小
window_size_ = msg.window_size;
// 更新待确认的数据个数
unAckedSegmentsNums = currentSeqNum_.distance(msg.ackno.value());
// 重置RTO翻倍数据
resetRTO();
return;
}
void TCPSender::tick(uint64_t ms_since_last_tick, const TransmitFunction& transmit)
{
since_last_send += ms_since_last_tick;
// 不会因为连续的零窗口确认而使 RTO 退避(不增加 RTO)。
// 这种行为是为了维持连接和测试窗口是否已重新打开,而不是因为网络拥堵。
if(window_size_ != 0 ){
handle_RTO();
}
// 检查是否达到初始 RTO
if (since_last_send >= initial_RTO_ms_) {
since_last_send = 0;
is_RTO_double = true;
// 遍历未确认段,并传输每个段
if (!unAckedSegments.empty()) {
transmit(unAckedSegments.begin()->second);
}
}
}
// 进行“窗口探测”
void TCPSender::handleWindowProbe(const TransmitFunction& transmit) {
if (!window_size_) {
TCPSenderMessage message = make_empty_message();
// 处理payload
handlePayload(message);
// 处理FIN
if(handleFIN(message)){
return ;
}
// 处理序列号
handleSqeno(message);
print(message);
transmit(message);
}
}
// 处理SYN头
bool TCPSender::handleInitialSYN(TCPSenderMessage& message) {
// 流中无字节,且未结束传输
if (isSYNSent_ && !input_.reader().bytes_buffered() && !input_.writer().is_closed()) {
return false;
}
//window_size还未设置,SYN已经设置
if(!is_SYN_ACK && isSYNSent_){
return false;
}
if (!isSYNSent_) {
message.SYN = true;
isSYNSent_ = true;
}
return true;
}
// 处理分段FIN
bool TCPSender::handleFIN(TCPSenderMessage& message){
if(isFINSent_){
return true;
}
if (input_.writer().is_closed() &&
!input_.reader().bytes_buffered() &&
(window_size_ == 0 ? 1 : window_size_) - message.sequence_length() > 0){
message.FIN = true;
isFINSent_ = true;
}
return false;
}
// 处理分段payload
void TCPSender::handlePayload(TCPSenderMessage& message){
uint64_t payload_len = min({input_.reader().bytes_buffered()
, static_cast<uint64_t>(TCPConfig::MAX_PAYLOAD_SIZE)
, static_cast<uint64_t>(window_size_ == 0?1:window_size_) - sequence_numbers_in_flight()});
// 处理分段的payload
message.payload = std::string(input_.reader().peek().substr(0, payload_len));
input_.reader().pop(payload_len);
}
// 处理分段序列号
void TCPSender::handleSqeno(TCPSenderMessage& message){
// 修改当前分段序列号
message.seqno = currentSeqNum_;
// 分段的序列号
currentSeqNum_ = currentSeqNum_ + message.sequence_length();
// 待确认的序列号数量
unAckedSegmentsNums += message.sequence_length();
}
// 重新设置RTO
void TCPSender::resetRTO() {
is_RTO_double = false;
initial_RTO_ms_ = raw_RTO_ms;
since_last_send = 0;
}
// 处理返回的ACK
void TCPSender::processACK(const TCPReceiverMessage& msg) {
if (!msg.ackno.has_value()) {
return;
}
if (msg.RST) {
input_.set_error();
}
// 无效ACK
if (msg.ackno > currentSeqNum_) {
return;
}
// 重复ACK
if (last_Ack_Seq.has_value()) {
if (last_Ack_Seq >= msg.ackno && window_size_ == msg.window_size) {
return;
}
}
}
// 处理已经ack数据分段
void TCPSender::handle_ack() {
auto it = unAckedSegments.begin();
while (it != unAckedSegments.end()) {
uint64_t seq_no = it->first;
uint64_t end_seq_no = seq_no + it->second.sequence_length();
if (end_seq_no < checkout) {
it = unAckedSegments.erase(it);
} else {
++it;
}
}
}
// 调整重传超时时间
void TCPSender::handle_RTO(){
if(is_RTO_double){
initial_RTO_ms_ *=2;
is_RTO_double = false;
}
}
void print(TCPSenderMessage message){
std::cout << "Current Sequence Number: " << message.seqno.getuint32_t() << std::endl;
std::cout << "SYN: " << (message.SYN ? "true" : "false") << std::endl;
std::cout << "payload: " << message.payload << std::endl;
std::cout << "FIN: " << (message.FIN ? "true" : "false") << std::endl;
std::cout << "RST: " << (message.RST ? "true" : "false") << std::endl;
std::cout << "sequence_length: " << message.sequence_length() << std::endl;
}