From 65c2d505caff1a7e138937006db7d559dcb7917e Mon Sep 17 00:00:00 2001 From: chenzizhan Date: Thu, 14 Sep 2023 16:42:45 +0800 Subject: wip --- src/session/tcp_reassembly.rs | 1419 +++++++------------ .../tcp_reassembly_with_status_deprecated.rs | 1482 ++++++++++++++++++++ 2 files changed, 2013 insertions(+), 888 deletions(-) create mode 100644 src/session/tcp_reassembly_with_status_deprecated.rs diff --git a/src/session/tcp_reassembly.rs b/src/session/tcp_reassembly.rs index c18b766..0ceedde 100644 --- a/src/session/tcp_reassembly.rs +++ b/src/session/tcp_reassembly.rs @@ -1,138 +1,37 @@ use std::collections::VecDeque; -use std::fmt; +use std::f32::consts::E; use std::net::{Ipv4Addr}; use std::num::Wrapping; -use std::vec::IntoIter; -use nom::Err; -use super::duration::Duration; use crate::protocol::ipv4::IPv4Header; use crate::protocol::ipv6::IPv6Header; use crate::protocol::udp::UdpHeader; use crate::protocol::ethernet::EthernetFrame; -use crate::protocol::tcp::{TcpHeader, TcpOption, self}; +use crate::protocol::tcp::{TcpHeader}; use crate::protocol::dns::DNS_MESSAGE; use crate::protocol::http::HTTP_MESSAGE; use crate::packet::packet::Encapsulation; use crate::packet::packet::Packet as RawPacket; -const DEFAULT_TIMEOUT: Duration = Duration{secs:7200, micros:0}; // 120 min timeout, currently do not support option 28 parse -// todo: const TCP_OPTION_TIMESTAMPS: u8 = 28; https://datatracker.ietf.org/doc/rfc5482/ -const DEFAULT_MAX_PACKETS: usize = 2048; - -// todo: 装的packet太多怎么办 - - -#[derive(Debug, Eq, PartialEq)] -pub enum TcpSessionErr { - PacketNotV4Tcp, - WrongFlags, // todo返回当前的会话状态,以及期望的flag, 为此: - // todo: TCP header to flag:u16 - ExpiredSession, - /// the seq number is wrong in the 2nd or 3rd handshake. - HandshakeFailed, - /// new a connection failed, because the packet is not a SYN packet, but the packet is valid - NewConnectionFailed, - /// ack number is even higher than the next expected seq number - UnexpectedAckNumber, - /// The packet itself is valid, but the window size is not enough to hold the packet. Tcp state machine will change as usual, while the packet payload is discarded. - SidePeerWindowFull, -} - -#[allow(non_snake_case)] -pub mod TcpSessionOk { // use as enum - pub const ESTABLISHED: u32 = 0x01; - pub const CLOSING: u32 = 0x02; - pub const CLOSED: u32 = 0x04; - pub const ACK_SEGMENT: u32 = 0x08; - pub const PACKET_IN_PAST: u32 = 0x10; - pub const TOO_MANY_PACKET_WARNING: u32 = 0x20; - pub const TOO_MANY_PACKETS: u32 = 0x40; - pub const OTHER: u32 = 0x0; -} - -#[derive(Debug, Eq, PartialEq, Clone)] -#[allow(dead_code)] -pub enum TcpStatus { - Closed = 0, - Listen, - SynSent, - SynRcv, - Established, - Closing, - CloseWait, - FinWait1, - FinWait2, - LastAck, - TimeWait, -} - -impl Default for TcpStatus { - fn default() -> Self { - TcpStatus::Closed - } -} - -enum TcpFlags { - FIN = 0x01, - SYN = 0x02, - RST = 0x04, - PSH = 0x08, - ACK = 0x10, - URG = 0x20, -} - -/* -------------------------------------------------------------------------- */ -/* iter */ -/* -------------------------------------------------------------------------- */ - -pub struct TcpIterator<'a> { - segments: IntoIter<(u32, &'a CopiedRawPacket)>, -} - -impl<'a> Iterator for TcpIterator<'a> { - type Item = (u32, RawPacket<'a>); - - fn next(&mut self) -> Option { - while let Some((index, packet)) = self.segments.next() { - let mut ret_encap = Vec::new(); - for encap in &packet.encapsulation { - match encap { - CopiedEncapsulation::L2_ETH(l2, seg) => { - ret_encap.push(Encapsulation::L2_ETH(l2.clone(), seg.as_slice())); - } - CopiedEncapsulation::L3_IP4(ipv4, seg) => { - ret_encap.push(Encapsulation::L3_IP4(ipv4.clone(), seg.as_slice())); - } - CopiedEncapsulation::L3_IP6(ipv6, seg) => { - ret_encap.push(Encapsulation::L3_IP6(ipv6.clone(), seg.as_slice())); - } - CopiedEncapsulation::L4_TCP(tcp, seg) => { - ret_encap.push(Encapsulation::L4_TCP(tcp.clone(), seg.as_slice())); - } - CopiedEncapsulation::L4_UDP(udp, seg) => { - ret_encap.push(Encapsulation::L4_UDP(udp.clone(), seg.as_slice())); - } - CopiedEncapsulation::L7_DNS(dns, seg) => { - ret_encap.push(Encapsulation::L7_DNS(dns.clone(), seg.as_slice())); - } - CopiedEncapsulation::L7_HTTP(http, seg) => { - ret_encap.push(Encapsulation::L7_HTTP(http.clone(), seg.as_slice())); - } - CopiedEncapsulation::Unsupported(seg) => { - ret_encap.push(Encapsulation::Unsupported(seg.as_slice())); - } - } - } - - return Some((index, RawPacket { - encapsulation: ret_encap, - orig_data: packet.orig_data.as_slice(), - orig_len: packet.orig_len, - })); - } - None - } +const DEFAULT_MAX_PACKETS: usize = 128; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) enum TcpSegmentDescription { + // has packet + Normal, + TooManyPacket, + FinTrigger, + + // no packet + Unordered, + DuplicateSeq, + OldPacket, + NoSegment, + NotIp4Tcp, + + HandshakeFail(String), + SynAckOk, + Reopen, // todo } // since the pub encapsulation has many reference of the original packet buffer, we have to copy them first @@ -150,13 +49,46 @@ enum CopiedEncapsulation { } #[derive(Debug, Clone)] -struct CopiedRawPacket { +pub(crate) struct CopiedRawPacket { encapsulation: Vec, orig_data: Vec, orig_len: u32, } +impl CopiedRawPacket { + fn header(&self) -> TcpHeader { + for encapsulation in &self.encapsulation { + match encapsulation { + CopiedEncapsulation::L4_TCP(header, _) => return header.clone(), + _ => {} + } + } + panic!("not a tcp packet"); + } + fn payload(&self) -> &[u8] { + for encapsulation in &self.encapsulation { + match encapsulation { + CopiedEncapsulation::L4_TCP(_, payload) => return payload.as_slice(), + _ => {} + } + } + panic!("not a tcp packet"); + } + fn replace_payload(&mut self, payload: Vec) { + for encapsulation in &mut self.encapsulation { + match encapsulation { + CopiedEncapsulation::L4_TCP(_, p) => { + *p = payload; + return; + } + _ => {} + } + } + panic!("not a tcp packet"); + } +} + impl From> for CopiedEncapsulation { fn from(encap: Encapsulation<'_>) -> Self { match encap { @@ -182,7 +114,7 @@ impl From<&RawPacket<'_>> for CopiedRawPacket { } } -fn raw_packet_convert_to_my_packet(raw_packet: &RawPacket<'_>) -> Result { +fn raw_packet_convert_to_my_packet(raw_packet: &RawPacket<'_>) -> Result { let mut payload = Vec::new(); let mut ipv4_header = Option::None; let mut tcp_header = Option::None; @@ -199,14 +131,17 @@ fn raw_packet_convert_to_my_packet(raw_packet: &RawPacket<'_>) -> Result, src_ip: Ipv4Addr, dst_ip: Ipv4Addr, - tcp_header : TcpHeader, + src_port: u16, + dst_port: u16, + seq_num: u32, + ack_num: u32, raw_packet: CopiedRawPacket, } +enum TcpFlags { + FIN = 0x01, + SYN = 0x02, + RST = 0x04, + PSH = 0x08, + ACK = 0x10, + URG = 0x20, +} + impl TcpPacket { fn get_sequence(&self) -> u32 { - self.tcp_header.seq_num + self.seq_num } fn get_acknowledgement(&self) -> u32 { - self.tcp_header.ack_num - } - fn has_flag(&self, flag: TcpFlags) -> bool { - match flag { - TcpFlags::URG => self.tcp_header.flag_urg, - TcpFlags::ACK => self.tcp_header.flag_ack, - TcpFlags::PSH => self.tcp_header.flag_psh, - TcpFlags::RST => self.tcp_header.flag_rst, - TcpFlags::SYN => self.tcp_header.flag_syn, - TcpFlags::FIN => self.tcp_header.flag_fin, - } + self.ack_num } fn payload(&self) -> &[u8] { self.payload.as_slice() } - fn get_timestamp(&self) -> Duration { // todo: 感觉这个duration没什么用,改成u32 吧,单位是秒 - let mut time_val:u32 = 0; - if let Some(options) = &self.tcp_header.options { - for option in options { - if let TcpOption::TIMESTAMPS{length:_, ts_value, ts_reply:_} = option { - time_val = *ts_value; - } - } + fn has_flag(&self, flag: TcpFlags) -> bool { + let header = self.raw_packet.header(); + match flag { + TcpFlags::URG => header.flag_urg, + TcpFlags::ACK => header.flag_ack, + TcpFlags::PSH => header.flag_psh, + TcpFlags::RST => header.flag_rst, + TcpFlags::SYN => header.flag_syn, + TcpFlags::FIN => header.flag_fin, } - Duration::new(time_val, 0) } } #[derive(Debug)] struct TcpSegment { - rel_seq: Wrapping, - rel_ack: Wrapping, + rel_seq: Wrapping, // todo: wrapping 主要是解决回绕问题https://blog.csdn.net/LU_ZHAO/article/details/105010778.当前没有实现,不过最好测一下 + +// todo: 带回绕的实际值查询、加减、设置和判断,难点是判断,来了一个新的seq number,我可能就要试一下是不是在回绕的范围内,如果把它当成回绕量,判断发现比上一个seq number 大,且大得很有限,就处理为回绕。 +// 看看其他代码怎么处理的. +// 注意输出到raw packet的时候还要再转一下。 payload: Vec, tcp_header: TcpHeader, raw_packet: CopiedRawPacket, + rel_ack: Wrapping, // todo: 干掉它 } impl TcpSegment { @@ -300,24 +241,18 @@ impl TcpSegment { } } +#[derive(Debug)] struct TcpPeer { // Initial Seq number (absolute) isn: Wrapping, // Initial Ack number (absolute) ian: Wrapping, - // Next Seq number + // Next Seq number, isn + (sum of all sent segments lengths) next_rel_seq: Wrapping, - // Last acknowledged number - last_rel_ack: Wrapping, - // Connection state - status: TcpStatus, - // The current list of segments (ordered by rel_seq) + // The current list of segments that this peer is about to sent (ordered by rel_seq) segments: VecDeque, addr: Ipv4Addr, port: u16, - - window_size: u32, - used_window_size: u32, } impl TcpPeer { @@ -336,22 +271,15 @@ impl TcpPeer { struct TcpStream { pub client: TcpPeer, pub server: TcpPeer, - pub status: TcpStatus, - // from packet.option or timeval passed by api argument. Used to check timeout. - // the free of session is NOT decided by this value. The api user should decide it. - pub last_seen_ts: Duration, + in_connection: bool, } + #[derive(Debug)] pub struct TcpConnection { stream: TcpStream, - packets_sent_by_client: Vec>, - packets_sent_by_server: Vec>, max_packets: usize, - max_warning_packets: usize, - - pub timeout: Duration, // todo: 当前使用默认值 } impl TcpPeer { @@ -360,13 +288,9 @@ impl TcpPeer { isn: Wrapping(0), ian: Wrapping(0), next_rel_seq: Wrapping(0), - last_rel_ack: Wrapping(0), - status: TcpStatus::Closed, segments: VecDeque::new(), addr: *addr, port, - window_size: 0, - used_window_size: 0, } } } @@ -374,158 +298,64 @@ impl TcpPeer { impl TcpStream { pub fn new(packet: &TcpPacket) -> Self { TcpStream { - client: TcpPeer::new(&packet.src_ip, packet.tcp_header.source_port), - server: TcpPeer::new(&packet.dst_ip, packet.tcp_header.dest_port), - status: TcpStatus::Closed, - last_seen_ts: packet.get_timestamp(), + client: TcpPeer::new(&packet.src_ip, packet.src_port), + server: TcpPeer::new(&packet.dst_ip, packet.dst_port), + in_connection: true, } } - fn handle_new_connection( - &mut self, - tcp: TcpPacket, - to_server: bool, - ) -> Result>, TcpSessionErr> { + fn handle_synsent(&mut self, tcp: TcpPacket) { let seq = Wrapping(tcp.get_sequence()); - let ack = Wrapping(tcp.get_acknowledgement()); - let (src, dst) = if to_server { - (&mut self.client, &mut self.server) - } else { - (&mut self.server, &mut self.client) - }; + self.client.isn = seq; + self.client.next_rel_seq = Wrapping(1); + self.server.ian = seq; + + if !tcp.payload().is_empty() { + println!("Data in handshake SYN"); + // https://stackoverflow.com/questions/37994131/send-tcp-syn-packet-with-payload + // it is possible to have data in SYN, just queue it(the src window size is 0 currently) + let segment = TcpSegment { + rel_seq: Wrapping(1), // just assume client has sent a ACK, and turn to ESTABLISHED. + payload: tcp.payload().to_vec(), + tcp_header: tcp.raw_packet.header(), + raw_packet: tcp.raw_packet, + rel_ack: Wrapping(1), + }; + queue_segment(&mut self.client, segment); + } + } + + fn handle_synrcv(&mut self, tcp: TcpPacket) -> Result<(), TcpSegmentDescription> { + // Server -- SYN+ACK --> Client + let (src, dst) = (&mut self.server, &mut self.client); + let seq = Wrapping(tcp.get_sequence()); + let ack = Wrapping(tcp.get_acknowledgement()); - match src.status { - // Client -- SYN --> Server - TcpStatus::Closed => { - if tcp.has_flag(TcpFlags::RST) { - // TODO check if destination.segments must be removed - // client sent a RST, this is expected - return Ok(None); - } - if !tcp.has_flag(TcpFlags::SYN) { - // not a SYN - usually happens at start of pcap if missed SYN - println!("First packet of a TCP stream is not a SYN"); - return Err(TcpSessionErr::WrongFlags); - } - if tcp.has_flag(TcpFlags::ACK) { - println!("First packet is SYN+ACK"); - return Err(TcpSessionErr::WrongFlags); - } - src.isn = seq; - src.next_rel_seq = Wrapping(1); - dst.ian = seq; - dst.window_size = cal_total_window_size(&tcp); // the server sliding window size(client receiving window size) - self.status = TcpStatus::SynSent; - src.status = TcpStatus::SynSent; - dst.status = TcpStatus::Listen; - - if !tcp.payload().is_empty() { - println!("Data in handshake SYN"); - // https://stackoverflow.com/questions/37994131/send-tcp-syn-packet-with-payload - // it is possible to have data in SYN, just queue it(the src window size is 0 currently) - let segment = TcpSegment { - rel_seq: seq - src.isn, - rel_ack: ack - dst.isn, - payload: tcp.payload().to_vec(), - tcp_header: tcp.tcp_header, - raw_packet: tcp.raw_packet, - }; - queue_segment(src, segment); - } - } - // Server -- SYN+ACK --> Client - TcpStatus::Listen => { - if !tcp.has_flag(TcpFlags::SYN) && !tcp.has_flag(TcpFlags::ACK) { - println!("Not a SYN + ACK"); - return Err(TcpSessionErr::WrongFlags); - } - // if we had data in SYN, add its length - let next_rel_seq = if dst.segments.is_empty() { - Wrapping(1) - } else { // - Wrapping(1) + Wrapping(dst.segments[0].payload.len() as u32) - }; - if ack != dst.isn + next_rel_seq { - println!("NEW/SYN-ACK: ack number is wrong"); - return Err(TcpSessionErr::HandshakeFailed); - } - src.isn = seq; - src.next_rel_seq = Wrapping(1); - dst.ian = seq; - dst.last_rel_ack = Wrapping(1); - dst.window_size = cal_total_window_size(&tcp); // the client sliding window size(server receiving window size) + if !tcp.has_flag(TcpFlags::SYN) || !tcp.has_flag(TcpFlags::ACK) { + return Err(TcpSegmentDescription::HandshakeFail("Not a SYN + ACK".to_string())); + } - src.status = TcpStatus::SynRcv; - self.status = TcpStatus::SynRcv; + // if we had data in SYN, add its length + let next_rel_seq = if dst.segments.is_empty() { + Wrapping(1) + } else { // + Wrapping(1) + Wrapping(dst.segments[0].payload.len() as u32) + }; + if ack != dst.isn + next_rel_seq { + return Err(TcpSegmentDescription::HandshakeFail("ack number is wrong".to_string())); + } - // do not push data if we had some in SYN, it will be done after handshake succeeds - } - // Client -- ACK --> Server - TcpStatus::SynSent => { - if !tcp.has_flag(TcpFlags::ACK) { - if tcp.has_flag(TcpFlags::SYN) { - // can be a SYN resend - if seq == src.isn && ack.0 == 0 { - println!("SYN resend - ignoring"); - return Ok(None); - } - // can be a disordered handshake (receive S after SA) - if seq + Wrapping(1) == dst.ian { - println!("Likely received SA before S - ignoring"); - return Ok(None); - } - } - println!("Not an ACK"); - } + src.isn = seq; + src.next_rel_seq = Wrapping(1); + dst.ian = seq; - if ack != dst.isn + Wrapping(1) { - println!("NEW/ACK: ack number is wrong"); - return Err(TcpSessionErr::HandshakeFailed); - } - src.status = TcpStatus::Established; - dst.status = TcpStatus::Established; - dst.last_rel_ack = Wrapping(1); - self.status = TcpStatus::Established; - // do we have data ? - if !tcp.payload().is_empty() { - if dst.used_window_size + tcp.payload().len() as u32 > dst.window_size { - println!("NEW/ACK: received data but window is full"); - return Err(TcpSessionErr::SidePeerWindowFull); - } - dst.used_window_size += tcp.payload().len() as u32; - - let segment = TcpSegment { - rel_seq: seq - src.isn, - rel_ack: ack - dst.isn, - payload: tcp.payload().to_vec(), // XXX data cloned here - tcp_header: tcp.tcp_header, - raw_packet: tcp.raw_packet, - }; - queue_segment(src, segment); - } - } - TcpStatus::SynRcv => { - // we received something while in SYN_RCV state - we should only have sent ACK - // this could be a SYN+ACK retransmit - if tcp.has_flag(TcpFlags::SYN) && tcp.has_flag(TcpFlags::ACK) { - // XXX compare SEQ numbers? - // ignore - return Ok(None); - } - println!("Received unexpected data in SYN_RCV state"); - return Err(TcpSessionErr::WrongFlags); - } - _ => unreachable!(), - } - Ok(None) + Ok(()) } - fn handle_established_connection( - &mut self, + fn update_after_handshake(&mut self, tcp: TcpPacket, - to_server: bool, - ) -> Result>, TcpSessionErr> { + to_server: bool) -> (Option>, TcpSegmentDescription) { let (origin, destination) = if to_server { (&mut self.client, &mut self.server) } else { @@ -534,9 +364,9 @@ impl TcpStream { let rel_seq = Wrapping(tcp.get_sequence()) - origin.isn; let rel_ack = Wrapping(tcp.get_acknowledgement()) - destination.isn; - let has_ack = tcp.has_flag(TcpFlags::ACK); // get it before borrowing tcp + let is_fin = tcp.has_flag(TcpFlags::FIN) || tcp.has_flag(TcpFlags::RST); // before borrowing tcp - println!("EST: payload len={}", tcp.payload().len()); + println!("update_after_handshake: payload len={}", tcp.payload().len()); println!( " Tcp rel seq {} ack {} next seq {}", rel_seq, @@ -544,174 +374,35 @@ impl TcpStream { origin.next_rel_seq ); - if !tcp.has_flag(TcpFlags::ACK) && tcp.get_acknowledgement() != 0 { - println!( - "Established state packet without ACK (broken TCP implementation or attack)", - ); - // ignore segment - return Err(TcpSessionErr::WrongFlags); - } - if destination.used_window_size + tcp.payload().len() as u32 > destination.window_size { - println!("EST: received data but window is full"); - return Err(TcpSessionErr::SidePeerWindowFull); - } - destination.used_window_size += tcp.payload().len() as u32; - let segment = TcpSegment { rel_seq, rel_ack, payload: tcp.payload().to_vec(), // XXX data cloned here - tcp_header: tcp.tcp_header, + tcp_header: tcp.raw_packet.header(), raw_packet: tcp.raw_packet, }; queue_segment(origin, segment); - - // if there is a ACK, check & send segments on the *other* side - let ret = if has_ack { - send_peer_segments(destination, rel_ack)? - } else { - None - }; - - println!( - " PEER EST rel next seq {} last_ack {}", - destination.next_rel_seq, - destination.last_rel_ack, - ); - - Ok(ret) - } - - fn handle_closing_connection( - &mut self, - tcp: TcpPacket, - to_server: bool, - ) -> Result>, TcpSessionErr> { - let (origin, destination) = if to_server { - (&mut self.client, &mut self.server) - } else { - (&mut self.server, &mut self.client) - }; - - let rel_seq = Wrapping(tcp.get_sequence()) - origin.isn; - let rel_ack = Wrapping(tcp.get_acknowledgement()) - destination.isn; - let has_ack = tcp.has_flag(TcpFlags::ACK); - let has_fin = tcp.has_flag(TcpFlags::FIN); - - let ret = if has_ack { - println!("ACKing segments up to {}", rel_ack); - send_peer_segments(destination, rel_ack)? - } else { - if tcp.get_acknowledgement() != 0 { - println!( - "EST/ packet without ACK (broken TCP implementation or attack)", - ); - // ignore segment - return Err(TcpSessionErr::WrongFlags); - } - None - }; - if tcp.has_flag(TcpFlags::RST) { - // if we get a RST, check the sequence number and remove matching segments - // todo: 这块逻辑对吗 - destination.segments.retain(|s| s.rel_ack != rel_seq); - println!( - "RST: {} remaining (undelivered) segments DESTINATION after removal", - destination.segments.len() - ); - origin.status = TcpStatus::Closed; // XXX except if ACK ? - return Ok(ret); - } - let mut ret = Ok(ret); - if destination.used_window_size + tcp.payload().len() as u32 > destination.window_size { - println!("EST: received data but window is full"); - ret = Err(TcpSessionErr::SidePeerWindowFull); - } else { - destination.used_window_size += tcp.payload().len() as u32; - // queue segment (even if FIN, to get correct seq numbers) - let rel_seq = Wrapping(tcp.get_sequence()) - origin.isn; - let rel_ack = Wrapping(tcp.get_acknowledgement()) - destination.isn; - let segment = TcpSegment { - rel_seq, - rel_ack, - payload: tcp.payload().to_vec(), // XXX data cloned here - tcp_header: tcp.tcp_header, - raw_packet: tcp.raw_packet, - }; - queue_segment(origin, segment); + if is_fin { + // fin packet can also have payload, so we queued it first. Refer to: + // https://stackoverflow.com/questions/8702646/can-a-tcp-packet-with-the-fin-flag-also-have-data + let sent_pkt = flush_peer_segments(origin); + return (Some(sent_pkt), TcpSegmentDescription::FinTrigger); } + // todo: closed connection restart - match origin.status { - TcpStatus::Established => { - // we know there is a FIN (tested in TcpConnection::update) - origin.status = TcpStatus::FinWait1; - destination.status = TcpStatus::CloseWait; // we are not sure it was received - } - TcpStatus::CloseWait => { - if !has_fin { - // if only an ACK, do nothing and stay in CloseWait status - if has_ack { - // println!("destination status: {:?}", destination.status); - if destination.status == TcpStatus::FinWait1 { - destination.status = TcpStatus::FinWait2; - } - } else { - println!("Origin should have sent a FIN and/or ACK"); - ret = Err(TcpSessionErr::WrongFlags); - } - } else { - origin.status = TcpStatus::LastAck; - // println!("destination status: {:?}", destination.status); - if has_ack || destination.status == TcpStatus::FinWait2 { - destination.status = TcpStatus::TimeWait; - } else { - destination.status = TcpStatus::Closing; - } - } - } - TcpStatus::TimeWait => { - // only an ACK should be sent (XXX nothing else, maybe PSH) - if has_ack { - // this is the end! - origin.status = TcpStatus::Closed; - destination.status = TcpStatus::Closed; - } - } - _ => { - println!( - "Unhandled closing transition: origin host {} status {:?}", - origin.addr, origin.status - ); - println!( - " dest host {} status {:?}", - destination.addr, destination.status - ); - } + let sent_pkt = send_peer_segments(origin); + if origin.segments.len() > DEFAULT_MAX_PACKETS { + let sent_pkt = flush_peer_segments(origin); + return (Some(sent_pkt), TcpSegmentDescription::TooManyPacket); } - println!( - "TCP connection closing, {} remaining (undelivered) segments", - origin.segments.len() - ); - // println - for (n, s) in origin.segments.iter().enumerate() { - println!( - " s[{}]: seq={} len={}", - n, - s.rel_seq.0, - s.payload.len(), - ); + if sent_pkt.is_err() { + return (None, sent_pkt.unwrap_err()); + } else { + return (Some(sent_pkt.unwrap()), TcpSegmentDescription::Normal); } - - return ret; } - - // force expiration (for ex after timeout) of this stream - fn expire(&mut self) { - self.client.status = TcpStatus::Closed; - self.server.status = TcpStatus::Closed; - } -} // TcpStream +} fn queue_segment(peer: &mut TcpPeer, segment: TcpSegment) { if segment.payload.is_empty() { @@ -729,84 +420,106 @@ fn queue_segment(peer: &mut TcpPeer, segment: TcpSegment) { peer.insert_sorted(segment); } -fn send_peer_segments(peer: &mut TcpPeer, rel_ack: Wrapping) -> Result>, TcpSessionErr> { - println!( - "Trying to send segments for {}:{} up to {} (last ack: {})", - peer.addr, - peer.port, - rel_ack, - peer.last_rel_ack - ); - if rel_ack == peer.last_rel_ack { - println!("re-acking last data, doing nothing"); - return Ok(None); - } +// let the peer send segments in its queue, update ack numbers, and pop segments that were sent +fn send_peer_segments(peer: &mut TcpPeer) -> Result, TcpSegmentDescription> { if peer.segments.is_empty() { - return Ok(None); - } - - // is ACK acceptable? - if rel_ack < peer.last_rel_ack { - println!("ACK request for already ACKed data (ack < last_ack)"); - return Err(TcpSessionErr::UnexpectedAckNumber); + println!("No segment to send"); + return Err(TcpSegmentDescription::NoSegment); } - - // check consistency of segment ACK numbers + order and/or missing fragments and/or overlap - - let mut acked = Vec::new(); - + let mut ret = Vec::new(); + let mut description = TcpSegmentDescription::Normal; while !peer.segments.is_empty() { let segment = &peer.segments[0]; - if rel_ack <= segment.rel_seq { - // if packet is in the past (strictly less), we don't care + println!("send segment, payload: {:?}", segment.payload); + + if segment.rel_seq > peer.next_rel_seq { // there is a gap + println!("Gap detected"); + description = TcpSegmentDescription::Unordered; break; } + if segment.rel_seq < peer.next_rel_seq { // caused by flush_peer_segments, or duplicate pkt, omit old segments + // todo: 感觉这里最好区别一下是不是因为flush_peer_segments导致的,虽然只影响错误码 + // 甚至会是over lap 导致的 + println!("Dropping segment"); + if (segment.rel_seq + Wrapping(segment.payload.len() as u32)) > peer.next_rel_seq { + println!("Segment overlaps next, payload before: {:?}", segment.payload); + let mut segment = peer.segments.pop_front().unwrap(); + let overlap_offset = (peer.next_rel_seq - segment.rel_seq).0; + segment.payload = segment.payload.split_off(overlap_offset as usize); + println!("Segment overlaps next, payload after: {:?}", segment.payload); + segment.rel_seq = peer.next_rel_seq; + peer.segments.push_front(segment); + assert!(!peer.segments.is_empty()); + } else if (segment.rel_seq + Wrapping(segment.payload.len() as u32)) == peer.next_rel_seq { + println!("Segment ends at next"); + peer.segments.pop_front(); + description = TcpSegmentDescription::DuplicateSeq; + } else { + peer.segments.pop_front(); + description = TcpSegmentDescription::OldPacket; + println!("Segment ends before next"); + } + continue; + } + // safety: segments is just tested above let mut segment = peer.segments.pop_front().unwrap(); - if rel_ack < segment.rel_seq + Wrapping(segment.payload.len() as u32) { - println!("ACK for part of buffer"); - // split data and insert new dummy segment - let acked_len = (rel_ack - segment.rel_seq).0 as usize; - let new_segment = segment.split_off(acked_len); - println!( - "insert new segment from {} len {}", - new_segment.rel_ack, - new_segment.payload.len() - ); - peer.insert_sorted(new_segment); - } - - handle_overlap_linux(peer, &mut segment); + remove_overlapped(peer, &mut segment); adjust_seq_numbers(peer, &segment); + println!("Sending segment, payload: {:?}", segment.payload); + segment.raw_packet.replace_payload(segment.payload.clone()); + ret.push(segment); + } + println!("ret len: {}", ret.len()); + if ret.len() == 0 { + return Err(description); + } + + Ok(ret) +} + +fn flush_peer_segments(peer: &mut TcpPeer) -> Vec { + // 最终预期: + // 1. 队列全清空 + // 2. next seq 正常调整 + + // // 之后呢? + // 标记该Session为满释放异常,并声明一个新的变量,为“上次flush的时候,最大的rel seq” + // 之后,如果有新的segment进来,那么就判断,如果rel seq < 上次flush的最大rel seq,那么直接丢弃 + // 否则,就正常处理,放入队列中。 + + let mut ret = Vec::new(); + while !peer.segments.is_empty() { + // safety: segments is just tested above + let mut segment = peer.segments.pop_front().unwrap(); - if !segment.payload.is_empty() { - acked.push(segment); + remove_overlapped(peer, &mut segment); + if peer.segments.len() == 0 { // the last one has the biggest rel seq + peer.next_rel_seq = segment.rel_seq + Wrapping(segment.payload.len() as u32); } - } - if peer.next_rel_seq != rel_ack { - // missed segments, or maybe received FIN ? - println!( - "TCP ACKed unseen segment next_seq {} != ack {} (Missed segments?)", - peer.next_rel_seq, rel_ack - ); - // TODO 这个正确吗?如果因为这个,把整个老segment 都删了肯定不对,具体怎么处理错误看需求吧 - // return Err(TcpSessionErr::UnexpectedAckNumber); + ret.push(segment); } - peer.last_rel_ack = rel_ack; - Ok(Some(acked)) + ret } -const FIRST_WINS: bool = false; - -// handle overlapping segments, using a linux-like policy -// Linux favors an original segment, EXCEPT when the subsequent begins before the original, -//or the subsequent segment begins the same and ends after the original segment. -#[allow(dead_code)] -fn handle_overlap_linux(peer: &mut TcpPeer, segment: &mut TcpSegment) { + // 情况1: [1,2,3] [4,5,6] + // [3, 4,5] + // 2: [1,2,3] + // [1,2,3] + // 3: [4,5,6] + // [4,5,6] + // 以上三种均为duplicate,直接丢弃 + // 4: [1,2,3] + // [1,2,3,4] + // 保留[1,2,3] [4] + // 5: [2,3] [6,7] + // [1,2,3,4,5,6] + // 保留[1,2,3,4,5,6] [7] +fn remove_overlapped(peer: &mut TcpPeer, segment: &mut TcpSegment) { // loop while segment has overlap while let Some(next) = peer.segments.front() { if let Some(overlap_offset) = segment.overlap_offset(next) { @@ -858,181 +571,75 @@ fn adjust_seq_numbers(origin: &mut TcpPeer, segment: &TcpSegment) { } impl TcpConnection { - pub(crate) fn try_new(packet: &RawPacket) -> Result { + pub(crate) fn try_new(packet: &RawPacket) -> Result { let simple_packet = raw_packet_convert_to_my_packet(packet)?; Self::_try_new(simple_packet) } - pub(crate) fn update(&mut self, packet: &RawPacket) -> Result { - let simple_packet = raw_packet_convert_to_my_packet(packet)?; - self._update(simple_packet) - } - - pub(crate) fn iter(&self, sent_by_client: bool) -> TcpIterator { - let target = { - if sent_by_client { - &self.packets_sent_by_client - } else { - &self.packets_sent_by_server - } - }; - - let mut ret: Vec<(u32, &CopiedRawPacket)> = Vec::new(); - for (index, packet_vec) in target.iter().enumerate() { - for packet in packet_vec.iter() { - ret.push((index as u32, &packet)); - } - } - - TcpIterator { - segments: ret.into_iter(), + pub(crate) fn update(&mut self, packet: &RawPacket) -> (Option>, TcpSegmentDescription) { + let simple_packet = raw_packet_convert_to_my_packet(packet); + if let Err(e) = simple_packet { + return (None, e); } + self._update(simple_packet.unwrap()) } - - fn _try_new(packet: TcpPacket) -> Result { + + fn _try_new(packet: TcpPacket) -> Result { let mut connection = TcpConnection { stream: TcpStream::new(&packet), - packets_sent_by_client: Vec::new(), - packets_sent_by_server: Vec::new(), - timeout: DEFAULT_TIMEOUT, max_packets: DEFAULT_MAX_PACKETS, - max_warning_packets: DEFAULT_MAX_PACKETS >> 1, }; if !packet.has_flag(TcpFlags::SYN) { - return Err(TcpSessionErr::WrongFlags); + return Err(TcpSegmentDescription::HandshakeFail("Not a SYN".to_string())); } - - connection._update(packet)?; - if connection.stream.client.status == TcpStatus::Closed || connection.stream.server.status == TcpStatus::Closed { - return Err(TcpSessionErr::NewConnectionFailed); + if packet.has_flag(TcpFlags::ACK) { + println!("First packet is SYN+ACK"); + return Err(TcpSegmentDescription::HandshakeFail("First packet is SYN+ACK".to_string())); + } + if packet.has_flag(TcpFlags::RST) { + return Err(TcpSegmentDescription::HandshakeFail("First packet is RST".to_string())); } + if packet.has_flag(TcpFlags::FIN) { + return Err(TcpSegmentDescription::HandshakeFail("First packet is FIN".to_string())); + } + + connection.stream.handle_synsent(packet); Ok(connection) } - fn _update(&mut self, tcp: TcpPacket) -> Result { + fn _update(&mut self, tcp: TcpPacket) -> (Option>, TcpSegmentDescription) { let stream = &mut self.stream; - println!("stream state: {:?}", stream.status); - - let client_status_before = stream.client.status.clone(); - let server_status_before = stream.server.status.clone(); - let mut ok_ret = TcpSessionOk::OTHER; - - // check time delay with previous packet before updating - let packet_ts = tcp.get_timestamp(); - if stream.last_seen_ts > packet_ts { - println!("packet received in past"); - ok_ret |= TcpSessionOk::PACKET_IN_PAST; - } else if packet_ts - stream.last_seen_ts > self.timeout { - println!("TCP stream received packet after timeout"); - stream.expire(); - return Err(TcpSessionErr::ExpiredSession); - } - stream.last_seen_ts = packet_ts; + // get origin and destination let to_server = tcp.dst_ip == stream.server.addr && - tcp.tcp_header.dest_port == stream.server.port; + tcp.dst_port == stream.server.port; println!("to_server: {}", to_server); - let (origin, _destination) = if to_server { - (&mut stream.client, &mut stream.server) - } else { - (&mut stream.server, &mut stream.client) - }; - println!("origin: {}:{} status {:?}", - origin.addr, - origin.port, - origin.status - ); - let sent_packet = - match origin.status { - TcpStatus::Closed | TcpStatus::Listen | TcpStatus::SynSent | TcpStatus::SynRcv => { - stream.handle_new_connection(tcp, to_server) - } - TcpStatus::Established => { - // check for close request - if tcp.has_flag(TcpFlags::FIN) || tcp.has_flag(TcpFlags::RST) { - stream.handle_closing_connection(tcp, to_server) - } else { - stream.handle_established_connection(tcp, to_server) - } - } - _ => stream.handle_closing_connection(tcp, to_server), - }?; - - if let Some(sent_packet) = sent_packet { - ok_ret = ok_ret | TcpSessionOk::ACK_SEGMENT; - let send_queue = if to_server { // ack packet, so the previous packet if from the other side. - &mut self.packets_sent_by_server - } else { - &mut self.packets_sent_by_client - }; - if send_queue.len() >= self.max_packets { - ok_ret |= TcpSessionOk::TOO_MANY_PACKETS; - } else { - if send_queue.len() >= self.max_warning_packets { - ok_ret |= TcpSessionOk::TOO_MANY_PACKET_WARNING; - } - send_queue.push(sent_packet.into_iter().map(|s| s.raw_packet).collect()); + if self.stream.in_connection { + let ret = self.stream.handle_synrcv(tcp); + if let Err(e) = ret { + return (None, e); } - } - - if client_status_before != stream.client.status || server_status_before != stream.server.status { - println!("status changed: {:?} -> {:?} / {:?} -> {:?}", - client_status_before, stream.client.status, server_status_before, stream.server.status - ); - if stream.client.status == TcpStatus::Established { - ok_ret = ok_ret | TcpSessionOk::ESTABLISHED; - } else if stream.client.status == TcpStatus::Closed || stream.server.status == TcpStatus::Closed { - ok_ret |= TcpSessionOk::CLOSED; - } else if client_status_before == TcpStatus::Established { - ok_ret = TcpSessionOk::CLOSING; - } - } - Ok(ok_ret) - } - // todo: refresh (删除所有的数据包,但是保留状态) - -} - -impl fmt::Debug for TcpPeer { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - writeln!(f, "Peer: {}:{}", self.addr, self.port)?; - writeln!(f, " status: {:?}", self.status)?; - writeln!(f, " isn: 0x{:x} ian: 0x{:x}", self.isn, self.ian)?; - writeln!(f, " next_rel_seq: {}", self.next_rel_seq)?; - writeln!(f, " last_rel_ack: {}", self.last_rel_ack)?; - writeln!(f, " #segments: {}", self.segments.len())?; - for (n, s) in self.segments.iter().enumerate() { - writeln!( - f, - " s[{}]: rel_seq={} len={}", - n, - s.rel_seq, - s.payload.len(), - )?; + self.stream.in_connection = false; + return (None, TcpSegmentDescription::SynAckOk); } - Ok(()) - } -} + let (segments, ret) = self.stream.update_after_handshake(tcp, to_server); -fn cal_total_window_size(tcp: &TcpPacket) -> u32 { - let mut scale = 0; - if let Some(options) = &tcp.tcp_header.options { - for option in options { - if let TcpOption::WSCALE{length, shift_count} = option { - scale = *shift_count; - } + if segments.is_none() { + return (None, ret); } - } - - let total_window_size = tcp.tcp_header.window as u32; - total_window_size << scale + let ret_packet = segments.unwrap().into_iter(). + map(|segment| {segment.raw_packet}).collect(); + + return (Some(ret_packet), ret); + } } + #[cfg(test)] mod tests { use std::vec; @@ -1105,238 +712,274 @@ mod tests { } } + const CLIENT: PeerInTest = PeerInTest { + addr: Ipv4Addr::new(192, 168, 1, 1), + port: 1234, + role: PeerRole::Client, + }; + const SERVER: PeerInTest = PeerInTest { + addr: Ipv4Addr::new(192, 168, 1, 2), + port: 80, + role: PeerRole::Server, + }; + #[test] - fn a_very_normal_connection() { - let client = PeerInTest { - addr: Ipv4Addr::new(192, 168, 1, 1), - port: 1234, - role: PeerRole::Client, - }; - let server = PeerInTest { - addr: Ipv4Addr::new(192, 168, 1, 2), - port: 80, - role: PeerRole::Server, - }; - let packet_handshake1 = new_raw(&client, &server, 0, 0, false, true, false, false, &[]); + fn single_segment_ping_pong() { + const INIT_SEQ:u32 = 12345; + let packet_handshake1 = new_raw(&CLIENT, &SERVER, INIT_SEQ + 0, 0, false, true, false, false, &[]); let mut connection = TcpConnection::try_new(&packet_handshake1).unwrap(); - assert!(connection.stream.client.status == TcpStatus::SynSent); - assert!(connection.stream.server.status == TcpStatus::Listen); - - let packet_handshake2 = new_raw(&server, &client, 0, 1, true, true, false, false, &[]); - let ret = connection.update(&packet_handshake2).unwrap(); - assert!(connection.stream.client.status == TcpStatus::SynSent); - assert!(connection.stream.server.status == TcpStatus::SynRcv); - assert!(ret == TcpSessionOk::OTHER); - - let packet_handshake3 = new_raw(&client, &server, 1, 1, true, false, false, false, &[]); - let ret = connection.update(&packet_handshake3).unwrap(); - assert!(connection.stream.client.status == TcpStatus::Established); - assert!(connection.stream.server.status == TcpStatus::Established); - assert!(ret == TcpSessionOk::ESTABLISHED); - - let packet_established_from_client = new_raw(&client, &server, 1, 1, true, false, false, false, &[1, 2, 3]); - let ret = connection.update(&packet_established_from_client).unwrap(); - assert!(ret == TcpSessionOk::OTHER); - assert!(connection.stream.client.segments.len() == 1); - - let packet_established_server_response = new_raw(&server, &client, 1, 4, true, false, false, false, &[]); - let ret = connection.update(&packet_established_server_response).unwrap(); - assert!(connection.stream.client.segments.len() == 0); - assert!(ret == TcpSessionOk::ACK_SEGMENT); - let seg: Vec<_> = connection.iter(true).collect(); - assert!(seg.len() == 1); - assert!(raw_packet_convert_to_my_packet(&seg[0].1).unwrap().payload == [1, 2, 3]); - - let packet_established_from_server = new_raw(&server, &client, 1, 4, true, false, false, false, &[4]); - let ret = connection.update(&packet_established_from_server).unwrap(); - assert!(connection.stream.server.segments.len() == 1); - assert!(ret == TcpSessionOk::OTHER); - - let packet_established_client_response = new_raw(&client, &server, 4, 2, true, false, false, false, &[]); - let ret = connection.update(&packet_established_client_response).unwrap(); - assert!(ret == TcpSessionOk::ACK_SEGMENT); - let seg: Vec<_> = connection.iter(false).collect(); - assert!(seg.len() == 1); - assert!(raw_packet_convert_to_my_packet(&seg[0].1).unwrap().payload == [4]); - - assert!(connection.stream.client.status == TcpStatus::Established); - assert!(connection.stream.server.status == TcpStatus::Established); - let packet_close_by_client = new_raw(&client, &server, 4, 2, true, false, false, true, &[]); - let ret = connection.update(&packet_close_by_client).unwrap(); - assert!(connection.stream.client.status == TcpStatus::FinWait1); - assert!(connection.stream.server.status == TcpStatus::CloseWait); - assert!(ret == TcpSessionOk::CLOSING); - - let packet_close_response_by_server = new_raw(&server, &client, 2, 5, true, false, false, false, &[]); - let ret = connection.update(&packet_close_response_by_server).unwrap(); - assert!(connection.stream.client.status == TcpStatus::FinWait2); - assert!(connection.stream.server.status == TcpStatus::CloseWait); - assert!(ret == TcpSessionOk::OTHER); - - let packet_close_by_server = new_raw(&server, &client, 2, 5, true, false, false, true, &[]); - let ret = connection.update(&packet_close_by_server).unwrap(); - assert!(connection.stream.client.status == TcpStatus::TimeWait); - assert!(connection.stream.server.status == TcpStatus::LastAck); - assert!(ret == TcpSessionOk::OTHER); - - let packet_close_response_by_client = new_raw(&client, &server, 5, 3, true, false, false, false, &[]); - let ret = connection.update(&packet_close_response_by_client).unwrap(); - assert!(connection.stream.client.status == TcpStatus::Closed); - assert!(connection.stream.server.status == TcpStatus::Closed); - assert!(ret == TcpSessionOk::CLOSED); + + let packet_handshake2 = new_raw(&SERVER, &CLIENT, 0, INIT_SEQ+1, true, true, false, false, &[]); + let ret = connection.update(&packet_handshake2); + assert!(ret.1 == TcpSegmentDescription::SynAckOk); + + let packet_established_from_client = new_raw(&CLIENT, &SERVER, INIT_SEQ+1, 1, true, false, false, false, &[1, 2, 3]); + let ret = connection.update(&packet_established_from_client); + println!("ret: {:?}", ret); + assert!(ret.1 == TcpSegmentDescription::Normal); + assert!(ret.0.unwrap()[0].payload() == &[1, 2, 3]); + + let packet_established_from_server = new_raw(&SERVER, &CLIENT, 1, INIT_SEQ+4, true, false, false, false, &[4, 5, 6]); + let ret = connection.update(&packet_established_from_server); + assert!(ret.1 == TcpSegmentDescription::Normal); + assert!(ret.0.unwrap()[0].payload() == &[4, 5, 6]); + + let packet_established_from_client2 = new_raw(&CLIENT, &SERVER, INIT_SEQ+4, 4, true, false, false, false, &[7, 8, 9]); + let ret = connection.update(&packet_established_from_client2); + assert!(ret.1 == TcpSegmentDescription::Normal); + assert!(ret.0.unwrap()[0].payload() == &[7, 8, 9]); } #[test] - fn several_ordered_segments_in_one_ack() { - let client = PeerInTest { - addr: Ipv4Addr::new(192, 168, 1, 1), - port: 1234, - role: PeerRole::Client, - }; - let server = PeerInTest { - addr: Ipv4Addr::new(192, 168, 1, 2), - port: 80, - role: PeerRole::Server, - }; + fn several_ordered_consecutive_segments() { + let packet_handshake1 = new_raw(&CLIENT, &SERVER, 0, 0, false, true, false, false, &[]); + let mut connection = TcpConnection::try_new(&packet_handshake1).unwrap(); + let packet_handshake2 = new_raw(&SERVER, &CLIENT, 0, 1, true, true, false, false, &[]); + connection.update(&packet_handshake2); + + let packet_established_from_client = new_raw(&CLIENT, &SERVER, 1, 1, true, false, false, false, &[1, 2, 3]); + let ret = connection.update(&packet_established_from_client); + assert!(ret.1 == TcpSegmentDescription::Normal); + assert!(ret.0.unwrap()[0].payload() == &[1, 2, 3]); + let packet_established_from_client2 = new_raw(&CLIENT, &SERVER, 4, 1, true, false, false, false, &[4, 5, 6]); + let ret = connection.update(&packet_established_from_client2); + assert!(ret.1 == TcpSegmentDescription::Normal); + assert!(ret.0.unwrap()[0].payload() == &[4, 5, 6]); + let packet_established_from_client3 = new_raw(&CLIENT, &SERVER, 7, 1, true, false, false, false, &[7, 8, 9]); + let ret = connection.update(&packet_established_from_client3); + assert!(ret.1 == TcpSegmentDescription::Normal); + assert!(ret.0.unwrap()[0].payload() == &[7, 8, 9]); + } - // standard handshake - let packet_handshake1 = new_raw(&client, &server, 0, 0, false, true, false, false, &[]); + #[test] + fn several_unordered_consecutive_segments() { + let packet_handshake1 = new_raw(&CLIENT, &SERVER, 0, 0, false, true, false, false, &[]); let mut connection = TcpConnection::try_new(&packet_handshake1).unwrap(); - let packet_handshake2 = new_raw(&server, &client, 0, 1, true, true, false, false, &[]); - connection.update(&packet_handshake2).unwrap(); - let packet_handshake3 = new_raw(&client, &server, 1, 1, true, false, false, false, &[]); - connection.update(&packet_handshake3).unwrap(); - - // send 3 segments from client - let packet_established_from_client1 = new_raw(&client, &server, 1, 1, true, false, false, false, &[1, 2, 3]); - let ret = connection.update(&packet_established_from_client1).unwrap(); - assert!(ret == TcpSessionOk::OTHER); - let packet_established_from_client2 = new_raw(&client, &server, 4, 1, true, false, false, false, &[4, 5, 6]); - let ret = connection.update(&packet_established_from_client2).unwrap(); - assert!(ret == TcpSessionOk::OTHER); - let packet_established_from_client3 = new_raw(&client, &server, 7, 1, true, false, false, false, &[7, 8, 9]); - let ret = connection.update(&packet_established_from_client3).unwrap(); - assert!(ret == TcpSessionOk::OTHER); - - // server ack - let packet_established_server_responce = new_raw(&server, &client, 1, 10, true, false, false, false, &[]); - let ret = connection.update(&packet_established_server_responce).unwrap(); - assert!(ret == TcpSessionOk::ACK_SEGMENT); - let seg: Vec<_> = connection.iter(true).collect(); - assert!(seg.len() == 3); - assert!(seg[0].0 == 0); - assert!(seg[1].0 == 0); - assert!(seg[2].0 == 0); - assert!(raw_packet_convert_to_my_packet(&seg[0].1).unwrap().payload == [1, 2, 3]); - assert!(raw_packet_convert_to_my_packet(&seg[1].1).unwrap().payload == [4, 5, 6]); - assert!(raw_packet_convert_to_my_packet(&seg[2].1).unwrap().payload == [7, 8, 9]); + let packet_handshake2 = new_raw(&SERVER, &CLIENT, 0, 1, true, true, false, false, &[]); + connection.update(&packet_handshake2); + + let packet_established_from_client2 = new_raw(&CLIENT, &SERVER, 4, 1, true, false, false, false, &[4, 5, 6]); + let ret = connection.update(&packet_established_from_client2); + assert!(ret.1 == TcpSegmentDescription::Unordered); + let packet_established_from_client = new_raw(&CLIENT, &SERVER, 1, 1, true, false, false, false, &[1, 2, 3]); + let ret = connection.update(&packet_established_from_client); + assert!(ret.1 == TcpSegmentDescription::Normal); + assert!(ret.0.as_ref().unwrap()[0].payload() == &[1, 2, 3]); + assert!(ret.0.unwrap()[1].payload() == &[4, 5, 6]); + let packet_established_from_client3 = new_raw(&CLIENT, &SERVER, 7, 1, true, false, false, false, &[7, 8, 9]); + let ret = connection.update(&packet_established_from_client3); + assert!(ret.1 == TcpSegmentDescription::Normal); + assert!(ret.0.unwrap()[0].payload() == &[7, 8, 9]); } #[test] - fn several_unordered_segments_in_one_ack() { - let client = PeerInTest { - addr: Ipv4Addr::new(192, 168, 1, 1), - port: 1234, - role: PeerRole::Client, - }; - let server = PeerInTest { - addr: Ipv4Addr::new(192, 168, 1, 2), - port: 80, - role: PeerRole::Server, - }; + fn several_unordered_inconsecutive_segments() { + let packet_handshake1 = new_raw(&CLIENT, &SERVER, 0, 0, false, true, false, false, &[]); + let mut connection = TcpConnection::try_new(&packet_handshake1).unwrap(); + let packet_handshake2 = new_raw(&SERVER, &CLIENT, 0, 1, true, true, false, false, &[]); + connection.update(&packet_handshake2); + + let packet_established_from_client2 = new_raw(&CLIENT, &SERVER, 4, 1, true, false, false, false, &[4, 5, 6]); + let ret = connection.update(&packet_established_from_client2); + println!("ret: {:?}", ret); + assert!(ret.1 == TcpSegmentDescription::Unordered); + let packet_from_server = new_raw(&SERVER, &CLIENT, 1, 1, true, false, false, false, &[11, 12, 13]); + let ret = connection.update(&packet_from_server); + assert!(ret.1 == TcpSegmentDescription::Normal); + assert!(ret.0.as_ref().unwrap()[0].payload() == &[11, 12, 13]); + + let packet_established_from_client = new_raw(&CLIENT, &SERVER, 1, 1, true, false, false, false, &[1, 2, 3]); + let ret = connection.update(&packet_established_from_client); + assert!(ret.1 == TcpSegmentDescription::Normal); + assert!(ret.0.as_ref().unwrap()[0].payload() == &[1, 2, 3]); + assert!(ret.0.unwrap()[1].payload() == &[4, 5, 6]); + let packet_established_from_client3 = new_raw(&CLIENT, &SERVER, 7, 1, true, false, false, false, &[7, 8, 9]); + let ret = connection.update(&packet_established_from_client3); + assert!(ret.1 == TcpSegmentDescription::Normal); + assert!(ret.0.unwrap()[0].payload() == &[7, 8, 9]); + } - // standard handshake - let packet_handshake1 = new_raw(&client, &server, 0, 0, false, true, false, false, &[]); + #[test] + fn duplicate_packet() { + let packet_handshake1 = new_raw(&CLIENT, &SERVER, 0, 0, false, true, false, false, &[]); let mut connection = TcpConnection::try_new(&packet_handshake1).unwrap(); - let packet_handshake2 = new_raw(&server, &client, 0, 1, true, true, false, false, &[]); - connection.update(&packet_handshake2).unwrap(); - let packet_handshake3 = new_raw(&client, &server, 1, 1, true, false, false, false, &[]); - connection.update(&packet_handshake3).unwrap(); - - // send 3 segments from client - let packet_established_from_client1 = new_raw(&client, &server, 1, 1, true, false, false, false, &[1, 2, 3]); - let packet_established_from_client2 = new_raw(&client, &server, 4, 1, true, false, false, false, &[4, 5, 6]); - let packet_established_from_client3 = new_raw(&client, &server, 7, 1, true, false, false, false, &[7, 8, 9]); - connection.update(&packet_established_from_client3).unwrap(); // unwrap: check not error - connection.update(&packet_established_from_client1).unwrap(); - connection.update(&packet_established_from_client2).unwrap(); - - // server ack - let packet_established_server_responce = new_raw(&server, &client, 1, 10, true, false, false, false, &[]); - let ret = connection.update(&packet_established_server_responce).unwrap(); - assert!(ret == TcpSessionOk::ACK_SEGMENT); - let seg: Vec<_> = connection.iter(true).collect(); - assert!(seg.len() == 3); - // assert!(ret.as_ref().unwrap()[0].payload.as_slice() == &[1, 2, 3]); - assert!(seg[0].0 == 0); - assert!(seg[1].0 == 0); - assert!(seg[2].0 == 0); - assert!(raw_packet_convert_to_my_packet(&seg[0].1).unwrap().payload == [1, 2, 3]); - assert!(raw_packet_convert_to_my_packet(&seg[1].1).unwrap().payload == [4, 5, 6]); - assert!(raw_packet_convert_to_my_packet(&seg[2].1).unwrap().payload == [7, 8, 9]); + let packet_handshake2 = new_raw(&SERVER, &CLIENT, 0, 1, true, true, false, false, &[]); + connection.update(&packet_handshake2); + + let packet_established_from_client = new_raw(&CLIENT, &SERVER, 1, 1, true, false, false, false, &[1, 2, 3]); + let ret = connection.update(&packet_established_from_client); + assert!(ret.1 == TcpSegmentDescription::Normal); + assert!(ret.0.as_ref().unwrap()[0].payload() == &[1, 2, 3]); + let ret = connection.update(&packet_established_from_client); + assert!(ret.1 == TcpSegmentDescription::DuplicateSeq); + assert!(ret.0.is_none()); } #[test] - fn wrong_ack_num() { - let client = PeerInTest { - addr: Ipv4Addr::new(192, 168, 1, 1), - port: 1234, - role: PeerRole::Client, - }; - let server = PeerInTest { - addr: Ipv4Addr::new(192, 168, 1, 2), - port: 80, - role: PeerRole::Server, - }; - // standard handshake - let packet_handshake1 = new_raw(&client, &server, 0, 0, false, true, false, false, &[]); + fn too_many_packet() { + let packet_handshake1 = new_raw(&CLIENT, &SERVER, 0, 0, false, true, false, false, &[]); let mut connection = TcpConnection::try_new(&packet_handshake1).unwrap(); - let packet_handshake2 = new_raw(&server, &client, 0, 1, true, true, false, false, &[]); - connection.update(&packet_handshake2).unwrap(); - let packet_handshake3 = new_raw(&client, &server, 1, 1, true, false, false, false, &[]); - connection.update(&packet_handshake3).unwrap(); - // send 3 segments from client - let packet_established_from_client1 = new_raw(&client, &server, 1, 1, true, false, false, false, &[1, 2, 3]); - connection.update(&packet_established_from_client1).unwrap(); - let packet_established_from_client2 = new_raw(&client, &server, 4, 1, true, false, false, false, &[4, 5, 6]); - connection.update(&packet_established_from_client2).unwrap(); - let packet_established_from_client3 = new_raw(&client, &server, 7, 1, true, false, false, false, &[7, 8, 9]); - connection.update(&packet_established_from_client3).unwrap(); - - // server ack twice - let packet_established_server_responce = new_raw(&server, &client, 1, 10, true, false, false, false, &[]); - let ret = connection.update(&packet_established_server_responce).unwrap(); - assert!(ret == TcpSessionOk::ACK_SEGMENT); - let ret = connection.update(&packet_established_server_responce).unwrap(); - assert!(ret == TcpSessionOk::OTHER); - - // server ack with wrong ack number(smaller) - let packet_established_from_client4 = new_raw(&client, &server, 10, 1, true, false, false, false, &[10, 11, 12]); - connection.update(&packet_established_from_client4).unwrap(); - let packet_established_server_responce_wrong_ack = new_raw(&server, &client, 1, 7, true, false, false, false, &[]); - let ret = connection.update(&packet_established_server_responce_wrong_ack); - assert!(ret == Err(TcpSessionErr::UnexpectedAckNumber)); - let seg: Vec<_> = connection.iter(true).collect(); - assert!(seg.len() == 3); - - // server ack with ack number(bigger). Bigger ack is Ok. - let packet_established_server_responce_wrong_ack = new_raw(&server, &client, 1, 20, true, false, false, false, &[]); - let ret = connection.update(&packet_established_server_responce_wrong_ack); - assert!(ret == Ok(TcpSessionOk::ACK_SEGMENT)); - let seg: Vec<_> = connection.iter(true).collect(); - assert!(seg.len() == 4); - assert!(seg[0].0 == 0); - assert!(seg[1].0 == 0); - assert!(seg[2].0 == 0); - assert!(seg[3].0 == 1); + let packet_handshake2 = new_raw(&SERVER, &CLIENT, 0, 1, true, true, false, false, &[]); + connection.update(&packet_handshake2); + + // let first packet drop , so that the queue will be filled until full + for i in 1..DEFAULT_MAX_PACKETS + 1{ + let packet_established_from_client = new_raw(&CLIENT, &SERVER, 1 + i as u32, 1, true, false, false, false, &[1]); + let ret = connection.update(&packet_established_from_client); + assert!(ret.1 == TcpSegmentDescription::Unordered); + } + let packet_established_from_client = new_raw(&CLIENT, &SERVER, 1 + DEFAULT_MAX_PACKETS as u32, 1, true, false, false, false, &[1]); + let ret = connection.update(&packet_established_from_client); + assert!(ret.1 == TcpSegmentDescription::TooManyPacket); + assert!(ret.0.unwrap().len() == DEFAULT_MAX_PACKETS); + + // the first packet come unexpectedly, just throw it away + let packet_established_from_client = new_raw(&CLIENT, &SERVER, 1, 1, true, false, false, false, &[1]); + let ret = connection.update(&packet_established_from_client); + assert!(ret.1 == TcpSegmentDescription::OldPacket); + assert!(ret.0.is_none()); + + // continue to send + let packet_established_from_client = new_raw(&CLIENT, &SERVER, 2 + DEFAULT_MAX_PACKETS as u32, 1, true, false, false, false, &[2,3,4]); + let ret = connection.update(&packet_established_from_client); + assert!(ret.1 == TcpSegmentDescription::Normal); + assert!(ret.0.unwrap()[0].payload() == &[2,3,4]); } - // #[test] - // todo: expired session - - // todo: many error flag - // todo: window full - // todo: many packet in queue(warning and discard) + #[test] + fn segment_in_syn() { + let packet_handshake1 = new_raw(&CLIENT, &SERVER, 0, 0, false, true, false, false, &[1,2,3]); + let mut connection = TcpConnection::try_new(&packet_handshake1).unwrap(); + let packet_handshake2 = new_raw(&SERVER, &CLIENT, 0, 4, true, true, false, false, &[]); + connection.update(&packet_handshake2); - + let packet_established_from_client = new_raw(&CLIENT, &SERVER, 4, 1, true, false, false, false, &[4,5,6]); + let ret = connection.update(&packet_established_from_client); + println!("ret: {:?}", ret); + assert!(ret.1 == TcpSegmentDescription::Normal); + assert!(ret.0.as_ref().unwrap()[0].payload() == &[1,2,3]); + assert!(ret.0.as_ref().unwrap()[1].payload() == &[4,5,6]); + } + + #[test] + fn wrong_flag_during_handshake_syn() { + let packet_handshake1 = new_raw(&CLIENT, &SERVER, 0, 0, false, false, false, false, &[]); + let mut connection = TcpConnection::try_new(&packet_handshake1); + assert!(connection.is_err()); + assert!(connection.unwrap_err() == TcpSegmentDescription::HandshakeFail("Not a SYN".to_string())); + + + let packet_handshake1 = new_raw(&CLIENT, &SERVER, 0, 0, false, true, true, false, &[]); + let mut connection = TcpConnection::try_new(&packet_handshake1); + assert!(connection.is_err()); + assert!(connection.unwrap_err() == TcpSegmentDescription::HandshakeFail("First packet is RST".to_string())); + } + + #[test] + fn wrong_flag_during_handshake_acksyn() { + let packet_handshake1 = new_raw(&CLIENT, &SERVER, 0, 0, false, true, false, false, &[1,2,3]); + let mut connection = TcpConnection::try_new(&packet_handshake1).unwrap(); + + let packet_handshake2 = new_raw(&SERVER, &CLIENT, 0, 3, true, true, false, false, &[]); // expected ack num is 4 + let ret = connection.update(&packet_handshake2); + assert!(ret.1 == TcpSegmentDescription::HandshakeFail("ack number is wrong".to_string())); + + let packet_handshake2 = new_raw(&SERVER, &CLIENT, 0, 4, false, true, false, false, &[]); // no ack + let ret = connection.update(&packet_handshake2); + assert!(ret.1 == TcpSegmentDescription::HandshakeFail("Not a SYN + ACK".to_string())); + } + + #[test] + fn not_tcp_ip_packet() { + let mut packet = new_raw(&CLIENT, &SERVER, 0, 0, false, true, false, false, &[]); + packet.encapsulation.pop(); + let mut connection = TcpConnection::try_new(&packet); + assert!(connection.is_err()); + assert!(connection.unwrap_err() == TcpSegmentDescription::NotIp4Tcp); + } + + #[test] + fn fin_with_data() { + let packet_handshake1 = new_raw(&CLIENT, &SERVER, 0, 0, false, true, false, false, &[]); + let mut connection = TcpConnection::try_new(&packet_handshake1).unwrap(); + let packet_handshake2 = new_raw(&SERVER, &CLIENT, 0, 1, true, true, false, false, &[]); + connection.update(&packet_handshake2); + + let packet_established_from_client = new_raw(&CLIENT, &SERVER, 2, 1, true, false, false, false, &[2,3]); + let ret = connection.update(&packet_established_from_client); + assert!(ret.1 == TcpSegmentDescription::Unordered); + + let packet_established_from_client = new_raw(&CLIENT, &SERVER, 4, 1, true, false, false, true, &[4,5,6]); + let ret = connection.update(&packet_established_from_client); + assert!(ret.1 == TcpSegmentDescription::FinTrigger); + assert!(ret.0.as_ref().unwrap()[0].payload() == &[2,3]); + assert!(ret.0.as_ref().unwrap()[1].payload() == &[4,5,6]); + } + + #[test] + fn overlap_partially_sent_before() { + let packet_handshake1 = new_raw(&CLIENT, &SERVER, 0, 0, false, true, false, false, &[]); + let mut connection = TcpConnection::try_new(&packet_handshake1).unwrap(); + let packet_handshake2 = new_raw(&SERVER, &CLIENT, 0, 1, true, true, false, false, &[]); + connection.update(&packet_handshake2); + + let packet_established_from_client = new_raw(&CLIENT, &SERVER, 1, 1, true, false, false, false, &[1,2,3]); + let ret = connection.update(&packet_established_from_client); + assert!(ret.1 == TcpSegmentDescription::Normal); + assert!(ret.0.as_ref().unwrap()[0].payload() == &[1,2,3]); + + let packet_established_from_client = new_raw(&CLIENT, &SERVER, 3, 1, true, false, false, false, &[3,4,5,6]); + let ret = connection.update(&packet_established_from_client); + println!("ret: {:?}", ret.0.as_ref().unwrap()[0].payload()); + assert!(ret.1 == TcpSegmentDescription::Normal); + assert!(ret.0.as_ref().unwrap()[0].payload() == &[4,5,6]); // [3] overlap, only send [4,5,6] + } + + // fn overlap_as_old_packet() { + // [1,2,3,4,5,6] + // [2,3,4] + // } + + // fn overlap_change_next_one() { + // [2,3,4] + // [3,4,5] + // [1] + // } + + // fn overlap_del_next_one() { + // [2,3,4, 5] + // [3,4,5] + // [1,2,3] + // } + + // fn overlap_surpass_all() { + // [2] + // [3] + // [1,2,3,4] + // } + + + + + // todo: 回绕 } \ No newline at end of file diff --git a/src/session/tcp_reassembly_with_status_deprecated.rs b/src/session/tcp_reassembly_with_status_deprecated.rs new file mode 100644 index 0000000..12b683e --- /dev/null +++ b/src/session/tcp_reassembly_with_status_deprecated.rs @@ -0,0 +1,1482 @@ +use std::collections::VecDeque; +use std::fmt; +use std::net::{Ipv4Addr}; +use std::num::Wrapping; +use std::vec::IntoIter; +use nom::Err; + +use super::duration::Duration; +use crate::protocol::ipv4::IPv4Header; +use crate::protocol::ipv6::IPv6Header; +use crate::protocol::udp::UdpHeader; +use crate::protocol::ethernet::EthernetFrame; +use crate::protocol::tcp::{TcpHeader, TcpOption, self}; +use crate::protocol::dns::DNS_MESSAGE; +use crate::protocol::http::HTTP_MESSAGE; +use crate::packet::packet::Encapsulation; +use crate::packet::packet::Packet as RawPacket; + +const DEFAULT_TIMEOUT: Duration = Duration{secs:7200, micros:0}; // 120 min timeout, currently do not support option 28 parse +// todo: const TCP_OPTION_TIMESTAMPS: u8 = 28; https://datatracker.ietf.org/doc/rfc5482/ +const DEFAULT_MAX_PACKETS: usize = 2048; + + +#[derive(Debug, Eq, PartialEq)] +pub enum TcpSessionErr { + PacketNotV4Tcp, + WrongFlags, // todo返回当前的会话状态,以及期望的flag, 为此: + // todo: TCP header to flag:u16 + ExpiredSession, + /// the seq number is wrong in the 2nd or 3rd handshake. + HandshakeFailed, + /// new a connection failed, because the packet is not a SYN packet, but the packet is valid + NewConnectionFailed, + /// ack number is even higher than the next expected seq number + UnexpectedAckNumber, + /// The packet itself is valid, but the window size is not enough to hold the packet. Tcp state machine will change as usual, while the packet payload is discarded. + SidePeerWindowFull, +} + +#[allow(non_snake_case)] +pub mod TcpSessionOk { // use as enum + pub const ESTABLISHED: u32 = 0x01; + pub const CLOSING: u32 = 0x02; + pub const CLOSED: u32 = 0x04; + pub const ACK_SEGMENT: u32 = 0x08; + pub const PACKET_IN_PAST: u32 = 0x10; + pub const TOO_MANY_PACKETS_WARNING: u32 = 0x20; + pub const TOO_MANY_PACKETS: u32 = 0x40; + pub const OTHER: u32 = 0x0; + + // todo: 标记是乱序,没有payload,重组。 +} + +#[derive(Debug, Eq, PartialEq, Clone)] +#[allow(dead_code)] +pub enum TcpStatus { + Closed = 0, + Listen, + SynSent, + SynRcv, + Established, + Closing, + CloseWait, + FinWait1, + FinWait2, + LastAck, + TimeWait, +} + +impl Default for TcpStatus { + fn default() -> Self { + TcpStatus::Closed + } +} + +enum TcpFlags { + FIN = 0x01, + SYN = 0x02, + RST = 0x04, + PSH = 0x08, + ACK = 0x10, + URG = 0x20, +} + +/* -------------------------------------------------------------------------- */ +/* iter */ +/* -------------------------------------------------------------------------- */ + +pub struct TcpIterator<'a> { + segments: IntoIter<(u32, &'a CopiedRawPacket)>, +} + +impl<'a> Iterator for TcpIterator<'a> { + type Item = (u32, RawPacket<'a>); + + fn next(&mut self) -> Option { + while let Some((index, packet)) = self.segments.next() { + let mut ret_encap = Vec::new(); + for encap in &packet.encapsulation { + match encap { + CopiedEncapsulation::L2_ETH(l2, seg) => { + ret_encap.push(Encapsulation::L2_ETH(l2.clone(), seg.as_slice())); + } + CopiedEncapsulation::L3_IP4(ipv4, seg) => { + ret_encap.push(Encapsulation::L3_IP4(ipv4.clone(), seg.as_slice())); + } + CopiedEncapsulation::L3_IP6(ipv6, seg) => { + ret_encap.push(Encapsulation::L3_IP6(ipv6.clone(), seg.as_slice())); + } + CopiedEncapsulation::L4_TCP(tcp, seg) => { + ret_encap.push(Encapsulation::L4_TCP(tcp.clone(), seg.as_slice())); + } + CopiedEncapsulation::L4_UDP(udp, seg) => { + ret_encap.push(Encapsulation::L4_UDP(udp.clone(), seg.as_slice())); + } + CopiedEncapsulation::L7_DNS(dns, seg) => { + ret_encap.push(Encapsulation::L7_DNS(dns.clone(), seg.as_slice())); + } + CopiedEncapsulation::L7_HTTP(http, seg) => { + ret_encap.push(Encapsulation::L7_HTTP(http.clone(), seg.as_slice())); + } + CopiedEncapsulation::Unsupported(seg) => { + ret_encap.push(Encapsulation::Unsupported(seg.as_slice())); + } + } + } + + return Some((index, RawPacket { + encapsulation: ret_encap, + orig_data: packet.orig_data.as_slice(), + orig_len: packet.orig_len, + })); + } + None + } +} + +// since the pub encapsulation has many reference of the original packet buffer, we have to copy them first +#[allow(non_camel_case_types)] +#[derive(Debug, Clone)] +enum CopiedEncapsulation { + L2_ETH(EthernetFrame, Vec), + L3_IP4(IPv4Header, Vec), + L3_IP6(IPv6Header, Vec), + L4_TCP(TcpHeader, Vec), + L4_UDP(UdpHeader, Vec), + L7_DNS(DNS_MESSAGE, Vec), + L7_HTTP(HTTP_MESSAGE, Vec), + Unsupported(Vec), +} + +#[derive(Debug, Clone)] +struct CopiedRawPacket { + encapsulation: Vec, + + orig_data: Vec, + orig_len: u32, +} + +impl From> for CopiedEncapsulation { + fn from(encap: Encapsulation<'_>) -> Self { + match encap { + Encapsulation::L2_ETH(l2, bytes) => CopiedEncapsulation::L2_ETH(l2, bytes.to_vec()), + Encapsulation::L3_IP4(ipv4, bytes) => CopiedEncapsulation::L3_IP4(ipv4, bytes.to_vec()), + Encapsulation::L3_IP6(ipv6, bytes) => CopiedEncapsulation::L3_IP6(ipv6, bytes.to_vec()), + Encapsulation::L4_TCP(tcp, bytes) => CopiedEncapsulation::L4_TCP(tcp, bytes.to_vec()), + Encapsulation::L4_UDP(udp, bytes) => CopiedEncapsulation::L4_UDP(udp, bytes.to_vec()), + Encapsulation::L7_DNS(dns, bytes) => CopiedEncapsulation::L7_DNS(dns, bytes.to_vec()), + Encapsulation::L7_HTTP(http, bytes) => CopiedEncapsulation::L7_HTTP(http, bytes.to_vec()), + Encapsulation::Unsupported(bytes) => CopiedEncapsulation::Unsupported(bytes.to_vec()), + } + } +} + +impl From<&RawPacket<'_>> for CopiedRawPacket { + fn from(packet: &RawPacket) -> Self { + CopiedRawPacket { + encapsulation: packet.encapsulation.clone().into_iter().map(CopiedEncapsulation::from).collect(), + orig_data: packet.orig_data.to_vec(), + orig_len: packet.orig_len, + } + } +} + +fn raw_packet_convert_to_my_packet(raw_packet: &RawPacket<'_>) -> Result { + let mut payload = Vec::new(); + let mut ipv4_header = Option::None; + let mut tcp_header = Option::None; + for encapsulation in &raw_packet.encapsulation { + match encapsulation { + Encapsulation::L3_IP4(ipv4, _) => { + ipv4_header = Some(ipv4); + } + Encapsulation::L4_TCP(tcp, data) => { + tcp_header = Some(tcp); + payload = data.to_vec(); + } + _ => {} + } + } + if ipv4_header.is_none() || tcp_header.is_none() { + return Err(TcpSessionErr::PacketNotV4Tcp); + } + + Ok(TcpPacket { + payload, + src_ip: ipv4_header.unwrap().source_address, + dst_ip: ipv4_header.unwrap().dest_address, + tcp_header: tcp_header.unwrap().clone(), + raw_packet: CopiedRawPacket::from(raw_packet), + }) +} + +/* -------------------------------------------------------------------------- */ +/* stream */ +/* -------------------------------------------------------------------------- */ + +#[derive(Debug, Clone)] +struct TcpPacket { + payload : Vec, + src_ip: Ipv4Addr, + dst_ip: Ipv4Addr, + tcp_header : TcpHeader, + + raw_packet: CopiedRawPacket, +} + +impl TcpPacket { + fn get_sequence(&self) -> u32 { + self.tcp_header.seq_num + } + fn get_acknowledgement(&self) -> u32 { + self.tcp_header.ack_num + } + fn has_flag(&self, flag: TcpFlags) -> bool { + match flag { + TcpFlags::URG => self.tcp_header.flag_urg, + TcpFlags::ACK => self.tcp_header.flag_ack, + TcpFlags::PSH => self.tcp_header.flag_psh, + TcpFlags::RST => self.tcp_header.flag_rst, + TcpFlags::SYN => self.tcp_header.flag_syn, + TcpFlags::FIN => self.tcp_header.flag_fin, + } + } + fn payload(&self) -> &[u8] { + self.payload.as_slice() + } + fn get_timestamp(&self) -> Duration { // todo: 感觉这个duration没什么用,改成u32 吧,单位是秒 + let mut time_val:u32 = 0; + if let Some(options) = &self.tcp_header.options { + for option in options { + if let TcpOption::TIMESTAMPS{length:_, ts_value, ts_reply:_} = option { + time_val = *ts_value; + } + } + } + Duration::new(time_val, 0) + } +} + +#[derive(Debug)] +struct TcpSegment { + rel_seq: Wrapping, + rel_ack: Wrapping, + payload: Vec, + + tcp_header: TcpHeader, + + raw_packet: CopiedRawPacket, +} + +impl TcpSegment { + /// Return the offset of the overlapping area if `self` (as left) overlaps on `right` + fn overlap_offset(&self, right: &TcpSegment) -> Option { + let next_seq = self.rel_seq + Wrapping(self.payload.len() as u32); + if next_seq > right.rel_seq { + let overlap_offset = (right.rel_seq - self.rel_seq).0 as usize; + Some(overlap_offset) + } else { + None + } + } + + /// Splits the segment into two at the given offset. + /// + /// # Panics + /// + /// Panics if `offset > self.payload.len()` + fn split_off(&mut self, offset: usize) -> TcpSegment { + assert!(offset < self.payload.len()); + let remaining = self.payload.split_off(offset); + let rel_seq = self.rel_seq + Wrapping(offset as u32); + TcpSegment { + payload: remaining, + rel_seq, + rel_ack: self.rel_ack, + raw_packet: self.raw_packet.clone(), + tcp_header: self.tcp_header.clone(), + } + } +} + +struct TcpPeer { + // Initial Seq number (absolute) + isn: Wrapping, + // Initial Ack number (absolute) + ian: Wrapping, + // Next Seq number + next_rel_seq: Wrapping, + // Last acknowledged number + last_rel_ack: Wrapping, + // Connection state + status: TcpStatus, + // The current list of segments that this peer is about to sent (ordered by rel_seq) + segments: VecDeque, + addr: Ipv4Addr, + port: u16, + + window_size: u32, + used_window_size: u32, +} + +impl TcpPeer { + fn insert_sorted(&mut self, s: TcpSegment) { + for (n, item) in self.segments.iter().enumerate() { + if item.rel_seq > s.rel_seq { + self.segments.insert(n, s); + return; + } + } + self.segments.push_back(s); + } +} + +#[derive(Debug)] +struct TcpStream { + pub client: TcpPeer, + pub server: TcpPeer, + pub status: TcpStatus, + // from packet.option or timeval passed by api argument. Used to check timeout. + // the free of session is NOT decided by this value. The api user should decide it. + pub last_seen_ts: Duration, +} + +#[derive(Debug)] +pub struct TcpConnection { + stream: TcpStream, + + packets_sent_by_client: Vec>, + packets_sent_by_server: Vec>, + max_packets: usize, + max_warning_packets: usize, + + pub timeout: Duration, // todo: 当前使用默认值 +} + +impl TcpPeer { + fn new(addr: &Ipv4Addr, port: u16) -> Self { + TcpPeer { + isn: Wrapping(0), + ian: Wrapping(0), + next_rel_seq: Wrapping(0), + last_rel_ack: Wrapping(0), + status: TcpStatus::Closed, + segments: VecDeque::new(), + addr: *addr, + port, + window_size: 0, + used_window_size: 0, + } + } +} + +impl TcpStream { + pub fn new(packet: &TcpPacket) -> Self { + TcpStream { + client: TcpPeer::new(&packet.src_ip, packet.tcp_header.source_port), + server: TcpPeer::new(&packet.dst_ip, packet.tcp_header.dest_port), + status: TcpStatus::Closed, + last_seen_ts: packet.get_timestamp(), + } + } + + fn handle_new_connection( + &mut self, + tcp: TcpPacket, + to_server: bool, + ) -> Result>, TcpSessionErr> { + let seq = Wrapping(tcp.get_sequence()); + let ack = Wrapping(tcp.get_acknowledgement()); + + let (src, dst) = if to_server { + (&mut self.client, &mut self.server) + } else { + (&mut self.server, &mut self.client) + }; + + match src.status { + // Client -- SYN --> Server + TcpStatus::Closed => { + if tcp.has_flag(TcpFlags::RST) { + // TODO check if destination.segments must be removed + // client sent a RST, this is expected + return Ok(None); + } + if !tcp.has_flag(TcpFlags::SYN) { + // not a SYN - usually happens at start of pcap if missed SYN + println!("First packet of a TCP stream is not a SYN"); + return Err(TcpSessionErr::WrongFlags); + } + if tcp.has_flag(TcpFlags::ACK) { + println!("First packet is SYN+ACK"); + return Err(TcpSessionErr::WrongFlags); + } + src.isn = seq; + src.next_rel_seq = Wrapping(1); + dst.ian = seq; + dst.window_size = cal_total_window_size(&tcp); // the server sliding window size(client receiving window size) + self.status = TcpStatus::SynSent; + src.status = TcpStatus::SynSent; + dst.status = TcpStatus::Listen; + + if !tcp.payload().is_empty() { + println!("Data in handshake SYN"); + // https://stackoverflow.com/questions/37994131/send-tcp-syn-packet-with-payload + // it is possible to have data in SYN, just queue it(the src window size is 0 currently) + let segment = TcpSegment { + rel_seq: seq - src.isn, + rel_ack: ack - dst.isn, + payload: tcp.payload().to_vec(), + tcp_header: tcp.tcp_header, + raw_packet: tcp.raw_packet, + }; + queue_segment(src, segment); + } + } + // Server -- SYN+ACK --> Client + TcpStatus::Listen => { + if !tcp.has_flag(TcpFlags::SYN) && !tcp.has_flag(TcpFlags::ACK) { + println!("Not a SYN + ACK"); + return Err(TcpSessionErr::WrongFlags); + } + // if we had data in SYN, add its length + let next_rel_seq = if dst.segments.is_empty() { + Wrapping(1) + } else { // + Wrapping(1) + Wrapping(dst.segments[0].payload.len() as u32) + }; + if ack != dst.isn + next_rel_seq { + println!("NEW/SYN-ACK: ack number is wrong"); + return Err(TcpSessionErr::HandshakeFailed); + } + src.isn = seq; + src.next_rel_seq = Wrapping(1); + dst.ian = seq; + dst.last_rel_ack = Wrapping(1); + dst.window_size = cal_total_window_size(&tcp); // the client sliding window size(server receiving window size) + + src.status = TcpStatus::SynRcv; + self.status = TcpStatus::SynRcv; + + // do not push data if we had some in SYN, it will be done after handshake succeeds + } + // Client -- ACK --> Server + TcpStatus::SynSent => { + if !tcp.has_flag(TcpFlags::ACK) { + if tcp.has_flag(TcpFlags::SYN) { + // can be a SYN resend + if seq == src.isn && ack.0 == 0 { + println!("SYN resend - ignoring"); + return Ok(None); + } + // can be a disordered handshake (receive S after SA) + if seq + Wrapping(1) == dst.ian { + println!("Likely received SA before S - ignoring"); + return Ok(None); + } + } + println!("Not an ACK"); + } + + if ack != dst.isn + Wrapping(1) { + println!("NEW/ACK: ack number is wrong"); + return Err(TcpSessionErr::HandshakeFailed); + } + src.status = TcpStatus::Established; + dst.status = TcpStatus::Established; + dst.last_rel_ack = Wrapping(1); + self.status = TcpStatus::Established; + // do we have data ? + if !tcp.payload().is_empty() { + println!("the payload len is {}, used window size: {}", tcp.payload().len(), src.used_window_size); + if src.used_window_size + tcp.payload().len() as u32 > src.window_size { + println!("NEW/ACK: received data but window is full"); + return Err(TcpSessionErr::SidePeerWindowFull); + } + src.used_window_size += tcp.payload().len() as u32; + + let segment = TcpSegment { + rel_seq: seq - src.isn, + rel_ack: ack - dst.isn, + payload: tcp.payload().to_vec(), // XXX data cloned here + tcp_header: tcp.tcp_header, + raw_packet: tcp.raw_packet, + }; + queue_segment(src, segment); + } + } + TcpStatus::SynRcv => { + // we received something while in SYN_RCV state - we should only have sent ACK + // this could be a SYN+ACK retransmit + if tcp.has_flag(TcpFlags::SYN) && tcp.has_flag(TcpFlags::ACK) { + // XXX compare SEQ numbers? + // ignore + return Ok(None); + } + println!("Received unexpected data in SYN_RCV state"); + return Err(TcpSessionErr::WrongFlags); + } + _ => unreachable!(), + } + Ok(None) + } + + fn handle_established_connection( + &mut self, + tcp: TcpPacket, + to_server: bool, + ) -> Result>, TcpSessionErr> { + let (origin, destination) = if to_server { + (&mut self.client, &mut self.server) + } else { + (&mut self.server, &mut self.client) + }; + + let rel_seq = Wrapping(tcp.get_sequence()) - origin.isn; + let rel_ack = Wrapping(tcp.get_acknowledgement()) - destination.isn; + let has_ack = tcp.has_flag(TcpFlags::ACK); // get it before borrowing tcp + + println!("EST: payload len={}", tcp.payload().len()); + println!( + " Tcp rel seq {} ack {} next seq {}", + rel_seq, + rel_ack, + origin.next_rel_seq + ); + + if !tcp.has_flag(TcpFlags::ACK) && tcp.get_acknowledgement() != 0 { + println!( + "Established state packet without ACK (broken TCP implementation or attack)", + ); + // ignore segment + return Err(TcpSessionErr::WrongFlags); + } + println!("the payload len is {}, used window size: {}, orin win size: {}", tcp.payload().len(), origin.used_window_size, origin.window_size); + if origin.used_window_size + tcp.payload().len() as u32 > origin.window_size { + println!("EST: received data but window is full"); + return Err(TcpSessionErr::SidePeerWindowFull); + } + origin.used_window_size += tcp.payload().len() as u32; + + let segment = TcpSegment { + rel_seq, + rel_ack, + payload: tcp.payload().to_vec(), // XXX data cloned here + tcp_header: tcp.tcp_header, + raw_packet: tcp.raw_packet, + }; + queue_segment(origin, segment); + + // if there is a ACK, check & send segments on the *other* side + let ret = if has_ack { + send_peer_segments(destination, rel_ack)? + } else { + None + }; + + println!( + " PEER EST rel next seq {} last_ack {}", + destination.next_rel_seq, + destination.last_rel_ack, + ); + + Ok(ret) + } + + fn handle_closing_connection( + &mut self, + tcp: TcpPacket, + to_server: bool, + ) -> Result>, TcpSessionErr> { + let (origin, destination) = if to_server { + (&mut self.client, &mut self.server) + } else { + (&mut self.server, &mut self.client) + }; + + let rel_seq = Wrapping(tcp.get_sequence()) - origin.isn; + let rel_ack = Wrapping(tcp.get_acknowledgement()) - destination.isn; + let has_ack = tcp.has_flag(TcpFlags::ACK); + let has_fin = tcp.has_flag(TcpFlags::FIN); + + let ret = if has_ack { + println!("ACKing segments up to {}", rel_ack); + send_peer_segments(destination, rel_ack)? + } else { + if tcp.get_acknowledgement() != 0 { + println!( + "EST/ packet without ACK (broken TCP implementation or attack)", + ); + // ignore segment + return Err(TcpSessionErr::WrongFlags); + } + None + }; + if tcp.has_flag(TcpFlags::RST) { + // if we get a RST, check the sequence number and remove matching segments + // todo: 这里是唯一用到了segment::rel_ack的地方,我看不懂这里的逻辑,不过既然我都不管flag 了,那么这个东西自然也不是我处理。 + // 这个地方相当神秘,因为s.rel_ack 的含义是之前这个包发出去的时候,sender 顺便给对方的ack,和segment本身的时序无关。 + // 可以理解它的意思是,这个rst包的sender后悔了,说我想重发,你把那些回退掉。但是一般来说,正常理解中的rst会中断Session,所以就很迷惑。 + // 不管怎么样,在最新的一版里,我把这个删掉了 + destination.segments.retain(|s| s.rel_ack != rel_seq); + println!( + "RST: {} remaining (undelivered) segments DESTINATION after removal", + destination.segments.len() + ); + origin.status = TcpStatus::Closed; // XXX except if ACK ? + return Ok(ret); + } + let mut ret = Ok(ret); + println!("the payload len is {}, used window size: {}", tcp.payload().len(), origin.used_window_size); + if origin.used_window_size + tcp.payload().len() as u32 > origin.window_size { + println!("EST: received data but window is full"); + ret = Err(TcpSessionErr::SidePeerWindowFull); + } else { + origin.used_window_size += tcp.payload().len() as u32; + // queue segment (even if FIN, to get correct seq numbers) + let rel_seq = Wrapping(tcp.get_sequence()) - origin.isn; + let rel_ack = Wrapping(tcp.get_acknowledgement()) - destination.isn; + let segment = TcpSegment { + rel_seq, + rel_ack, + payload: tcp.payload().to_vec(), // XXX data cloned here + tcp_header: tcp.tcp_header, + raw_packet: tcp.raw_packet, + }; + queue_segment(origin, segment); + } + + match origin.status { + TcpStatus::Established => { + // we know there is a FIN (tested in TcpConnection::update) + origin.status = TcpStatus::FinWait1; + destination.status = TcpStatus::CloseWait; // we are not sure it was received + } + TcpStatus::CloseWait => { + if !has_fin { + // if only an ACK, do nothing and stay in CloseWait status + if has_ack { + // println!("destination status: {:?}", destination.status); + if destination.status == TcpStatus::FinWait1 { + destination.status = TcpStatus::FinWait2; + } + } else { + println!("Origin should have sent a FIN and/or ACK"); + ret = Err(TcpSessionErr::WrongFlags); + } + } else { + origin.status = TcpStatus::LastAck; + // println!("destination status: {:?}", destination.status); + if has_ack || destination.status == TcpStatus::FinWait2 { + destination.status = TcpStatus::TimeWait; + } else { + destination.status = TcpStatus::Closing; + } + } + } + TcpStatus::TimeWait => { + // only an ACK should be sent (XXX nothing else, maybe PSH) + if has_ack { + // this is the end! + origin.status = TcpStatus::Closed; + destination.status = TcpStatus::Closed; + } + } + _ => { + println!( + "Unhandled closing transition: origin host {} status {:?}", + origin.addr, origin.status + ); + println!( + " dest host {} status {:?}", + destination.addr, destination.status + ); + } + } + + println!( + "TCP connection closing, {} remaining (undelivered) segments", + origin.segments.len() + ); + // println + for (n, s) in origin.segments.iter().enumerate() { + println!( + " s[{}]: seq={} len={}", + n, + s.rel_seq.0, + s.payload.len(), + ); + } + + return ret; + } + + // force expiration (for ex after timeout) of this stream + fn expire(&mut self) { + self.client.status = TcpStatus::Closed; + self.server.status = TcpStatus::Closed; + } +} // TcpStream + +fn queue_segment(peer: &mut TcpPeer, segment: TcpSegment) { + if segment.payload.is_empty() { + return; + } + //todo: 老代码有一个 EARLY_DETECT_OVERLAP 不知道干嘛的 + + if peer.segments.is_empty() { + println!("Pushing segment (front)"); + peer.segments.push_front(segment); + return; + } + + println!("Adding segment"); + peer.insert_sorted(segment); +} + +// let the peer send segments in its queue, update ack numbers, and pop segments that were sent +fn send_peer_segments(peer: &mut TcpPeer, rel_ack: Wrapping) -> Result>, TcpSessionErr> { + println!( + "Trying to send segments for {}:{} up to {} (last ack: {})", + peer.addr, + peer.port, + rel_ack, + peer.last_rel_ack + ); + if rel_ack == peer.last_rel_ack { + println!("re-acking last data, doing nothing"); + return Ok(None); + } + if peer.segments.is_empty() { + return Ok(None); + } + + // is ACK acceptable? + if rel_ack < peer.last_rel_ack { + println!("ACK request for already ACKed data (ack < last_ack)"); + return Err(TcpSessionErr::UnexpectedAckNumber); + } + + // check consistency of segment ACK numbers + order and/or missing fragments and/or overlap + + let mut acked = Vec::new(); + + while !peer.segments.is_empty() { + let segment = &peer.segments[0]; + if rel_ack <= segment.rel_seq { + // if packet is in the past (strictly less), we don't care + break; + } + + // safety: segments is just tested above + let mut segment = peer.segments.pop_front().unwrap(); + + if rel_ack < segment.rel_seq + Wrapping(segment.payload.len() as u32) { + println!("ACK for part of buffer"); + // split data and insert new dummy segment + let acked_len = (rel_ack - segment.rel_seq).0 as usize; + let new_segment = segment.split_off(acked_len); + println!( + "insert new segment from {} len {}", + new_segment.rel_ack, + new_segment.payload.len() + ); + peer.insert_sorted(new_segment); + } + + handle_overlap_linux(peer, &mut segment); + adjust_seq_numbers(peer, &segment); + peer.used_window_size -= segment.payload.len() as u32; + + if !segment.payload.is_empty() { + acked.push(segment); + } + } + + if peer.next_rel_seq != rel_ack { + // missed segments, or maybe received FIN ? + println!( + "TCP ACKed unseen segment next_seq {} != ack {} (Missed segments?)", + peer.next_rel_seq, rel_ack + ); + // TODO 这个正确吗?如果因为这个,把整个老segment 都删了肯定不对,具体怎么处理错误看需求吧 + // return Err(TcpSessionErr::UnexpectedAckNumber); + } + + peer.last_rel_ack = rel_ack; + Ok(Some(acked)) +} + +const FIRST_WINS: bool = false; + +// handle overlapping segments, using a linux-like policy +// Linux favors an original segment, EXCEPT when the subsequent begins before the original, +//or the subsequent segment begins the same and ends after the original segment. +#[allow(dead_code)] +fn handle_overlap_linux(peer: &mut TcpPeer, segment: &mut TcpSegment) { + // loop while segment has overlap + while let Some(next) = peer.segments.front() { + if let Some(overlap_offset) = segment.overlap_offset(next) { + println!( + "overlaps next candidate (at offset={})", + overlap_offset + ); + // we will modify the subsequent segment (next) + // safety: element presence was tested in outer loop + let next = peer.segments.pop_front().unwrap(); + + // split next + let overlap_size = segment.payload.len() - overlap_offset; + let min_overlap_size = std::cmp::min(overlap_size, next.payload.len()); + // compare overlap area + if next.payload[..min_overlap_size] + != segment.payload[overlap_offset..overlap_offset + min_overlap_size] + { + println!("Overlap area differs!"); + } + if overlap_size >= next.payload.len() { + // subsequent segment starts after and is smaller, so drop it + drop(next); + continue; + } + // otherwise, split next into left and right, drop left and accept right + let mut left = next; + let right = left.split_off(overlap_size); + // to accept right, merge it into segment + segment.payload.extend_from_slice(&right.payload); + } else { + // println!("no overlap, break"); + break; + } + } +} + +fn adjust_seq_numbers(origin: &mut TcpPeer, segment: &TcpSegment) { + if !segment.payload.is_empty() { + // adding length is wrong in case of overlap + // origin.next_rel_seq += Wrapping(segment.payload.len() as u32); + origin.next_rel_seq = segment.rel_seq + Wrapping(segment.payload.len() as u32); + } + + if segment.tcp_header.flag_fin { + // println!("Segment has FIN"); + origin.next_rel_seq += Wrapping(1); + } +} + +impl TcpConnection { + pub(crate) fn try_new(packet: &RawPacket) -> Result { + let simple_packet = raw_packet_convert_to_my_packet(packet)?; + Self::_try_new(simple_packet) + } + + pub(crate) fn update(&mut self, packet: &RawPacket) -> Result { + let simple_packet = raw_packet_convert_to_my_packet(packet)?; + self._update(simple_packet) + } + + /// Like std::Iter trait, but has an argument to choose which side to iterate + pub(crate) fn iter(&self, sent_by_client: bool) -> TcpIterator { + let target = { + if sent_by_client { + &self.packets_sent_by_client + } else { + &self.packets_sent_by_server + } + }; + + let mut ret: Vec<(u32, &CopiedRawPacket)> = Vec::new(); + for (index, packet_vec) in target.iter().enumerate() { + for packet in packet_vec.iter() { + ret.push((index as u32, &packet)); + } + } + + TcpIterator { + segments: ret.into_iter(), + } + } + + fn _try_new(packet: TcpPacket) -> Result { + let mut connection = TcpConnection { + stream: TcpStream::new(&packet), + packets_sent_by_client: Vec::new(), + packets_sent_by_server: Vec::new(), + timeout: DEFAULT_TIMEOUT, + max_packets: DEFAULT_MAX_PACKETS, + max_warning_packets: DEFAULT_MAX_PACKETS >> 1, + }; + connection.stream.last_seen_ts = packet.get_timestamp(); + + if !packet.has_flag(TcpFlags::SYN) { + return Err(TcpSessionErr::WrongFlags); + } + + connection._update(packet)?; + if connection.stream.client.status == TcpStatus::Closed || connection.stream.server.status == TcpStatus::Closed { + return Err(TcpSessionErr::NewConnectionFailed); + } + Ok(connection) + } + + fn _update(&mut self, tcp: TcpPacket) -> Result { + let stream = &mut self.stream; + println!("stream state: {:?}", stream.status); + + let client_status_before = stream.client.status.clone(); + let server_status_before = stream.server.status.clone(); + let mut ok_ret = TcpSessionOk::OTHER; + + // check time delay with previous packet before updating + let packet_ts = tcp.get_timestamp(); + if stream.last_seen_ts > packet_ts { + println!("packet received in past"); + ok_ret |= TcpSessionOk::PACKET_IN_PAST; + } else if packet_ts - stream.last_seen_ts > self.timeout { + println!("TCP stream received packet after timeout"); + stream.expire(); + return Err(TcpSessionErr::ExpiredSession); + } + stream.last_seen_ts = packet_ts; + + // get origin and destination + let to_server = tcp.dst_ip == stream.server.addr && + tcp.tcp_header.dest_port == stream.server.port; + println!("to_server: {}", to_server); + let (origin, _destination) = if to_server { + (&mut stream.client, &mut stream.server) + } else { + (&mut stream.server, &mut stream.client) + }; + + println!("origin: {}:{} status {:?}", + origin.addr, + origin.port, + origin.status + ); + let sent_packet = + match origin.status { + TcpStatus::Closed | TcpStatus::Listen | TcpStatus::SynSent | TcpStatus::SynRcv => { + stream.handle_new_connection(tcp, to_server) + } + TcpStatus::Established => { + // check for close request + if tcp.has_flag(TcpFlags::FIN) || tcp.has_flag(TcpFlags::RST) { + stream.handle_closing_connection(tcp, to_server) + } else { + stream.handle_established_connection(tcp, to_server) + } + } + _ => stream.handle_closing_connection(tcp, to_server), + }?; + + if let Some(sent_packet) = sent_packet { + ok_ret = ok_ret | TcpSessionOk::ACK_SEGMENT; + let send_queue = if to_server { // ack packet, so the previous packet if from the other side. + &mut self.packets_sent_by_server + } else { + &mut self.packets_sent_by_client + }; + println!("czz send_queue len: {}", send_queue.len()); + if send_queue.len() >= self.max_packets { + ok_ret |= TcpSessionOk::TOO_MANY_PACKETS; + } else { + if send_queue.len() >= self.max_warning_packets { + ok_ret |= TcpSessionOk::TOO_MANY_PACKETS_WARNING; + } + send_queue.push(sent_packet.into_iter().map(|s| s.raw_packet).collect()); + } + } + + if client_status_before != stream.client.status || server_status_before != stream.server.status { + println!("status changed: {:?} -> {:?} / {:?} -> {:?}", + client_status_before, stream.client.status, server_status_before, stream.server.status + ); + if stream.client.status == TcpStatus::Established { + ok_ret = ok_ret | TcpSessionOk::ESTABLISHED; + } else if stream.client.status == TcpStatus::Closed || stream.server.status == TcpStatus::Closed { + ok_ret |= TcpSessionOk::CLOSED; + } else if client_status_before == TcpStatus::Established { + ok_ret = TcpSessionOk::CLOSING; + } + } + Ok(ok_ret) + } + + // todo: refresh (删除所有的数据包,但是保留状态) + +} + +impl fmt::Debug for TcpPeer { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!(f, "Peer: {}:{}", self.addr, self.port)?; + writeln!(f, " status: {:?}", self.status)?; + writeln!(f, " isn: 0x{:x} ian: 0x{:x}", self.isn, self.ian)?; + writeln!(f, " next_rel_seq: {}", self.next_rel_seq)?; + writeln!(f, " last_rel_ack: {}", self.last_rel_ack)?; + writeln!(f, " #segments: {}", self.segments.len())?; + for (n, s) in self.segments.iter().enumerate() { + writeln!( + f, + " s[{}]: rel_seq={} len={}", + n, + s.rel_seq, + s.payload.len(), + )?; + } + Ok(()) + } +} + +fn cal_total_window_size(tcp: &TcpPacket) -> u32 { + println!("cal_total_window_size, window: {}", tcp.tcp_header.window); + let mut scale = 0; + if let Some(options) = &tcp.tcp_header.options { + for option in options { + if let TcpOption::WSCALE{length, shift_count} = option { + scale = *shift_count; + } + } + } + + let total_window_size = tcp.tcp_header.window as u32; + + total_window_size << scale +} + +#[cfg(test)] +mod tests { + use std::vec; + use crate::protocol::ip::IPProtocol; + + use super::*; + + static SLICE_DUMMY:&[u8] = &[42,42,42]; + + #[derive(Debug, Clone)] + enum PeerRole { + Client, + Server, + } + #[derive(Debug, Clone)] + struct PeerInTest { + addr: Ipv4Addr, + /// println: port + port: u16, + role: PeerRole, + } + + fn new_raw<'a>(from: &PeerInTest, to: &PeerInTest, seq_num: u32, ack_num: u32, + has_ack: bool, has_syn: bool, has_rst: bool, has_fin: bool, segment: &'a [u8]) + -> RawPacket<'a> { + let src_ip = from.addr; + let dst_ip = to.addr; + let header = TcpHeader { + source_port: from.port, + dest_port: to.port, + seq_num, + ack_num, + data_offset: 0, + reserved: 0, + flag_urg: false, + flag_ack: has_ack, + flag_psh: false, + flag_rst: has_rst, + flag_syn: has_syn, + flag_fin: has_fin, + window: 65535, + checksum: 0, + urgent_ptr: 0, + options: None, + }; + + let ip_header = IPv4Header { + version: 4, + ihl: 5, + tos: 0, + length: 0, + id: 0, + flags: 0, + frag_offset: 0, + ttl: 0, + protocol: IPProtocol::TCP, + checksum: 0, + source_address: src_ip, + dest_address: dst_ip, + }; + + let encap1 : Encapsulation = Encapsulation::L3_IP4(ip_header, SLICE_DUMMY); + let encap2 : Encapsulation = Encapsulation::L4_TCP(header, segment); + let encap_vec = vec![encap1, encap2]; + + RawPacket { + orig_data: SLICE_DUMMY, + orig_len: SLICE_DUMMY.len() as u32, + encapsulation: encap_vec, + } + } + + fn raw_packet_set_timestamp(packet: &mut RawPacket, ts: u32) { + if let Encapsulation::L4_TCP(ref mut header, _) = packet.encapsulation[1] { + if header.options.is_none() { + header.options = Some(Vec::new()); + } + let options = header.options.as_mut().unwrap(); + options.push(TcpOption::TIMESTAMPS{length: 10, ts_value: ts, ts_reply: 0}); + } else { + panic!("raw_packet_set_timestamp: not a TCP packet"); + } + } + + fn raw_packet_set_window(packet: &mut RawPacket, window: u16, scale: u8) { + if let Encapsulation::L4_TCP(ref mut header, _) = packet.encapsulation[1] { + header.window = window; + + if header.options.is_none() { + header.options = Some(Vec::new()); + } + let options = header.options.as_mut().unwrap(); + options.push(TcpOption::WSCALE{length: 3, shift_count: scale}); + } else { + panic!("raw_packet_set_window: not a TCP packet"); + } + } + + const CLIENT: PeerInTest = PeerInTest { + addr: Ipv4Addr::new(192, 168, 1, 1), + port: 1234, + role: PeerRole::Client, + }; + const SERVER: PeerInTest = PeerInTest { + addr: Ipv4Addr::new(192, 168, 1, 2), + port: 80, + role: PeerRole::Server, + }; + + #[test] + fn a_very_normal_connection() { + let packet_handshake1 = new_raw(&CLIENT, &SERVER, 0, 0, false, true, false, false, &[]); + let mut connection = TcpConnection::try_new(&packet_handshake1).unwrap(); + assert!(connection.stream.client.status == TcpStatus::SynSent); + assert!(connection.stream.server.status == TcpStatus::Listen); + + let packet_handshake2 = new_raw(&SERVER, &CLIENT, 0, 1, true, true, false, false, &[]); + let ret = connection.update(&packet_handshake2).unwrap(); + assert!(connection.stream.client.status == TcpStatus::SynSent); + assert!(connection.stream.server.status == TcpStatus::SynRcv); + assert!(ret == TcpSessionOk::OTHER); + + let packet_handshake3 = new_raw(&CLIENT, &SERVER, 1, 1, true, false, false, false, &[]); + let ret = connection.update(&packet_handshake3).unwrap(); + assert!(connection.stream.client.status == TcpStatus::Established); + assert!(connection.stream.server.status == TcpStatus::Established); + assert!(ret == TcpSessionOk::ESTABLISHED); + + let packet_established_from_client = new_raw(&CLIENT, &SERVER, 1, 1, true, false, false, false, &[1, 2, 3]); + let ret = connection.update(&packet_established_from_client).unwrap(); + assert!(ret == TcpSessionOk::OTHER); + assert!(connection.stream.client.segments.len() == 1); + + let packet_established_server_response = new_raw(&SERVER, &CLIENT, 1, 4, true, false, false, false, &[]); + let ret = connection.update(&packet_established_server_response).unwrap(); + assert!(connection.stream.client.segments.len() == 0); + assert!(ret == TcpSessionOk::ACK_SEGMENT); + let seg: Vec<_> = connection.iter(true).collect(); + assert!(seg.len() == 1); + assert!(raw_packet_convert_to_my_packet(&seg[0].1).unwrap().payload == [1, 2, 3]); + + let packet_established_from_server = new_raw(&SERVER, &CLIENT, 1, 4, true, false, false, false, &[4]); + let ret = connection.update(&packet_established_from_server).unwrap(); + assert!(connection.stream.server.segments.len() == 1); + assert!(ret == TcpSessionOk::OTHER); + + let packet_established_client_response = new_raw(&CLIENT, &SERVER, 4, 2, true, false, false, false, &[]); + let ret = connection.update(&packet_established_client_response).unwrap(); + assert!(ret == TcpSessionOk::ACK_SEGMENT); + let seg: Vec<_> = connection.iter(false).collect(); + assert!(seg.len() == 1); + assert!(raw_packet_convert_to_my_packet(&seg[0].1).unwrap().payload == [4]); + + assert!(connection.stream.client.status == TcpStatus::Established); + assert!(connection.stream.server.status == TcpStatus::Established); + let packet_close_by_client = new_raw(&CLIENT, &SERVER, 4, 2, true, false, false, true, &[]); + let ret = connection.update(&packet_close_by_client).unwrap(); + assert!(connection.stream.client.status == TcpStatus::FinWait1); + assert!(connection.stream.server.status == TcpStatus::CloseWait); + assert!(ret == TcpSessionOk::CLOSING); + + let packet_close_response_by_server = new_raw(&SERVER, &CLIENT, 2, 5, true, false, false, false, &[]); + let ret = connection.update(&packet_close_response_by_server).unwrap(); + assert!(connection.stream.client.status == TcpStatus::FinWait2); + assert!(connection.stream.server.status == TcpStatus::CloseWait); + assert!(ret == TcpSessionOk::OTHER); + + let packet_close_by_server = new_raw(&SERVER, &CLIENT, 2, 5, true, false, false, true, &[]); + let ret = connection.update(&packet_close_by_server).unwrap(); + assert!(connection.stream.client.status == TcpStatus::TimeWait); + assert!(connection.stream.server.status == TcpStatus::LastAck); + assert!(ret == TcpSessionOk::OTHER); + + let packet_close_response_by_client = new_raw(&CLIENT, &SERVER, 5, 3, true, false, false, false, &[]); + let ret = connection.update(&packet_close_response_by_client).unwrap(); + assert!(connection.stream.client.status == TcpStatus::Closed); + assert!(connection.stream.server.status == TcpStatus::Closed); + assert!(ret == TcpSessionOk::CLOSED); + } + + #[test] + fn several_ordered_segments_in_one_ack() { + // standard handshake + let packet_handshake1 = new_raw(&CLIENT, &SERVER, 0, 0, false, true, false, false, &[]); + let mut connection = TcpConnection::try_new(&packet_handshake1).unwrap(); + let packet_handshake2 = new_raw(&SERVER, &CLIENT, 0, 1, true, true, false, false, &[]); + connection.update(&packet_handshake2).unwrap(); + let packet_handshake3 = new_raw(&CLIENT, &SERVER, 1, 1, true, false, false, false, &[]); + connection.update(&packet_handshake3).unwrap(); + + // send 3 segments from client + let packet_established_from_client1 = new_raw(&CLIENT, &SERVER, 1, 1, true, false, false, false, &[1, 2, 3]); + let ret = connection.update(&packet_established_from_client1).unwrap(); + assert!(ret == TcpSessionOk::OTHER); + let packet_established_from_client2 = new_raw(&CLIENT, &SERVER, 4, 1, true, false, false, false, &[4, 5, 6]); + let ret = connection.update(&packet_established_from_client2).unwrap(); + assert!(ret == TcpSessionOk::OTHER); + let packet_established_from_client3 = new_raw(&CLIENT, &SERVER, 7, 1, true, false, false, false, &[7, 8, 9]); + let ret = connection.update(&packet_established_from_client3).unwrap(); + assert!(ret == TcpSessionOk::OTHER); + + // server ack + let packet_established_server_responce = new_raw(&SERVER, &CLIENT, 1, 10, true, false, false, false, &[]); + let ret = connection.update(&packet_established_server_responce).unwrap(); + assert!(ret == TcpSessionOk::ACK_SEGMENT); + let seg: Vec<_> = connection.iter(true).collect(); + assert!(seg.len() == 3); + assert!(seg[0].0 == 0); + assert!(seg[1].0 == 0); + assert!(seg[2].0 == 0); + assert!(raw_packet_convert_to_my_packet(&seg[0].1).unwrap().payload == [1, 2, 3]); + assert!(raw_packet_convert_to_my_packet(&seg[1].1).unwrap().payload == [4, 5, 6]); + assert!(raw_packet_convert_to_my_packet(&seg[2].1).unwrap().payload == [7, 8, 9]); + } + + #[test] + fn several_unordered_segments_in_one_ack() { + // standard handshake + let packet_handshake1 = new_raw(&CLIENT, &SERVER, 0, 0, false, true, false, false, &[]); + let mut connection = TcpConnection::try_new(&packet_handshake1).unwrap(); + let packet_handshake2 = new_raw(&SERVER, &CLIENT, 0, 1, true, true, false, false, &[]); + connection.update(&packet_handshake2).unwrap(); + let packet_handshake3 = new_raw(&CLIENT, &SERVER, 1, 1, true, false, false, false, &[]); + connection.update(&packet_handshake3).unwrap(); + + // send 3 segments from client + let packet_established_from_client1 = new_raw(&CLIENT, &SERVER, 1, 1, true, false, false, false, &[1, 2, 3]); + let packet_established_from_client2 = new_raw(&CLIENT, &SERVER, 4, 1, true, false, false, false, &[4, 5, 6]); + let packet_established_from_client3 = new_raw(&CLIENT, &SERVER, 7, 1, true, false, false, false, &[7, 8, 9]); + connection.update(&packet_established_from_client3).unwrap(); // unwrap: check not error + connection.update(&packet_established_from_client1).unwrap(); + connection.update(&packet_established_from_client2).unwrap(); + + // server ack + let packet_established_server_responce = new_raw(&SERVER, &CLIENT, 1, 10, true, false, false, false, &[]); + let ret = connection.update(&packet_established_server_responce).unwrap(); + assert!(ret == TcpSessionOk::ACK_SEGMENT); + let seg: Vec<_> = connection.iter(true).collect(); + assert!(seg.len() == 3); + // assert!(ret.as_ref().unwrap()[0].payload.as_slice() == &[1, 2, 3]); + assert!(seg[0].0 == 0); + assert!(seg[1].0 == 0); + assert!(seg[2].0 == 0); + assert!(raw_packet_convert_to_my_packet(&seg[0].1).unwrap().payload == [1, 2, 3]); + assert!(raw_packet_convert_to_my_packet(&seg[1].1).unwrap().payload == [4, 5, 6]); + assert!(raw_packet_convert_to_my_packet(&seg[2].1).unwrap().payload == [7, 8, 9]); + } + + #[test] + fn wrong_ack_num() { + // standard handshake + let packet_handshake1 = new_raw(&CLIENT, &SERVER, 0, 0, false, true, false, false, &[]); + let mut connection = TcpConnection::try_new(&packet_handshake1).unwrap(); + let packet_handshake2 = new_raw(&SERVER, &CLIENT, 0, 1, true, true, false, false, &[]); + connection.update(&packet_handshake2).unwrap(); + let packet_handshake3 = new_raw(&CLIENT, &SERVER, 1, 1, true, false, false, false, &[]); + connection.update(&packet_handshake3).unwrap(); + // send 3 segments from client + let packet_established_from_client1 = new_raw(&CLIENT, &SERVER, 1, 1, true, false, false, false, &[1, 2, 3]); + connection.update(&packet_established_from_client1).unwrap(); + let packet_established_from_client2 = new_raw(&CLIENT, &SERVER, 4, 1, true, false, false, false, &[4, 5, 6]); + connection.update(&packet_established_from_client2).unwrap(); + let packet_established_from_client3 = new_raw(&CLIENT, &SERVER, 7, 1, true, false, false, false, &[7, 8, 9]); + connection.update(&packet_established_from_client3).unwrap(); + + // server ack twice + let packet_established_server_responce = new_raw(&SERVER, &CLIENT, 1, 10, true, false, false, false, &[]); + let ret = connection.update(&packet_established_server_responce).unwrap(); + assert!(ret == TcpSessionOk::ACK_SEGMENT); + let ret = connection.update(&packet_established_server_responce).unwrap(); + assert!(ret == TcpSessionOk::OTHER); + + // server ack with wrong ack number(smaller) + let packet_established_from_client4 = new_raw(&CLIENT, &SERVER, 10, 1, true, false, false, false, &[10, 11, 12]); + connection.update(&packet_established_from_client4).unwrap(); + let packet_established_server_responce_wrong_ack = new_raw(&SERVER, &CLIENT, 1, 7, true, false, false, false, &[]); + let ret = connection.update(&packet_established_server_responce_wrong_ack); + assert!(ret == Err(TcpSessionErr::UnexpectedAckNumber)); + let seg: Vec<_> = connection.iter(true).collect(); + assert!(seg.len() == 3); + + // server ack with ack number(bigger). Bigger ack is Ok. + let packet_established_server_responce_wrong_ack = new_raw(&SERVER, &CLIENT, 1, 20, true, false, false, false, &[]); + let ret = connection.update(&packet_established_server_responce_wrong_ack); + assert!(ret == Ok(TcpSessionOk::ACK_SEGMENT)); + let seg: Vec<_> = connection.iter(true).collect(); + assert!(seg.len() == 4); + assert!(seg[0].0 == 0); + assert!(seg[1].0 == 0); + assert!(seg[2].0 == 0); + assert!(seg[3].0 == 1); + } + + #[test] + // 非error,正常处理,但是返回一个码 + fn err_expired_session() { + const FIRST_TS:u32 = 123456; + + let mut packet_handshake1 = new_raw(&CLIENT, &SERVER, 0, 0, false, true, false, false, &[]); + raw_packet_set_timestamp(&mut packet_handshake1, FIRST_TS); + let mut connection = TcpConnection::try_new(&packet_handshake1).unwrap(); + let mut packet_handshake2 = new_raw(&SERVER, &CLIENT, 0, 1, true, true, false, false, &[]); + raw_packet_set_timestamp(&mut packet_handshake2, FIRST_TS + 120 * 60 + 1); + let ret = connection.update(&packet_handshake2); + assert!(ret == Err(TcpSessionErr::ExpiredSession)); + } + + #[test] + // 不关心 + fn ok_session_with_almost_expired_sessions() { + let mut ts:u32 = 123456; + + let mut packet_handshake1 = new_raw(&CLIENT, &SERVER, 0, 0, false, true, false, false, &[]); + raw_packet_set_timestamp(&mut packet_handshake1, ts); + let mut connection = TcpConnection::try_new(&packet_handshake1).unwrap(); + let mut packet_handshake2 = new_raw(&SERVER, &CLIENT, 0, 1, true, true, false, false, &[]); + ts += 120 * 60; + raw_packet_set_timestamp(&mut packet_handshake2, ts); + connection.update(&packet_handshake2).unwrap(); + ts += 120 * 60; + let mut packet_handshake3 = new_raw(&CLIENT, &SERVER, 1, 1, true, false, false, false, &[]); + raw_packet_set_timestamp(&mut packet_handshake3, ts); + connection.update(&packet_handshake3).unwrap(); + let mut packet_established = new_raw(&CLIENT, &SERVER, 1, 1, true, false, false, false, &[1, 2, 3]); + raw_packet_set_timestamp(&mut packet_established, ts + 120 * 60); + connection.update(&packet_established).unwrap(); + } + + // + #[test] + fn ok_session_with_old_packet() { + let ts:u32 = 123456; + + let mut packet_handshake1 = new_raw(&CLIENT, &SERVER, 0, 0, false, true, false, false, &[]); + raw_packet_set_timestamp(&mut packet_handshake1, ts); + let mut connection = TcpConnection::try_new(&packet_handshake1).unwrap(); + let mut packet_handshake2 = new_raw(&SERVER, &CLIENT, 0, 1, true, true, false, false, &[]); + raw_packet_set_timestamp(&mut packet_handshake2, ts - 1); + let ret = connection.update(&packet_handshake2).unwrap(); + assert!(ret & TcpSessionOk::PACKET_IN_PAST > 0); + } + + #[test] + fn warn_many_packet() { + // todo: 再研究一下ack number 怎么涨的,把边缘条件弄对,另外server不要发segment了 + // standard handshake + let packet_handshake1 = new_raw(&CLIENT, &SERVER, 0, 0, false, true, false, false, &[]); + let mut connection = TcpConnection::try_new(&packet_handshake1).unwrap(); + let packet_handshake2 = new_raw(&SERVER, &CLIENT, 0, 1, true, true, false, false, &[]); + connection.update(&packet_handshake2).unwrap(); + let packet_handshake3 = new_raw(&CLIENT, &SERVER, 1, 1, true, false, false, false, &[]); + connection.update(&packet_handshake3).unwrap(); + + let mut client_seq = 1; + let mut server_ack = 1; + for _ in 0..DEFAULT_MAX_PACKETS / 2 + 1{ + let client_send_in_established = new_raw(&CLIENT, &SERVER, client_seq, 1, true, false, false, false, &[1]); + client_seq += 1; + let server_send_in_established = new_raw(&SERVER, &CLIENT, 1, server_ack, true, false, false, false, &[1]); + server_ack += 1; + connection.update(&client_send_in_established).unwrap(); + let ret = connection.update(&server_send_in_established).unwrap(); + assert!(ret & TcpSessionOk::TOO_MANY_PACKETS_WARNING == 0); // not half full yet + } + + // one more packet + let client_send_in_established = new_raw(&CLIENT, &SERVER, client_seq, 1, true, false, false, false, &[1]); + client_seq += 1; + let server_send_in_established = new_raw(&SERVER, &CLIENT, 1, server_ack, true, false, false, false, &[1]); + server_ack += 1; + connection.update(&client_send_in_established).unwrap(); + let ret = connection.update(&server_send_in_established).unwrap(); + assert!(ret & TcpSessionOk::TOO_MANY_PACKETS_WARNING > 0); // half full + + // more until full + for _ in 0..DEFAULT_MAX_PACKETS / 2 - 1{ + let client_send_in_established = new_raw(&CLIENT, &SERVER, client_seq, 1, true, false, false, false, &[1]); + client_seq += 1; + let server_send_in_established = new_raw(&SERVER, &CLIENT, 1, server_ack, true, false, false, false, &[1]); + server_ack += 1; + connection.update(&client_send_in_established).unwrap(); + let ret = connection.update(&server_send_in_established).unwrap(); + assert!(ret & TcpSessionOk::TOO_MANY_PACKETS == 0); // not full yet + } + + // one more packet + let client_send_in_established = new_raw(&CLIENT, &SERVER, client_seq, 1, true, false, false, false, &[2]); + client_seq += 1; + let server_send_in_established = new_raw(&SERVER, &CLIENT, 1, server_ack, true, false, false, false, &[2]); + server_ack += 1; + connection.update(&client_send_in_established).unwrap(); + let ret = connection.update(&server_send_in_established).unwrap(); + assert!(ret & TcpSessionOk::TOO_MANY_PACKETS > 0); // full + + // check if the last if dropped + let mut seg: Vec<_> = connection.iter(true).collect(); + assert!(seg.len() == DEFAULT_MAX_PACKETS); + + assert!(raw_packet_convert_to_my_packet(&seg.pop().unwrap().1).unwrap().payload == [1]); // the last packet payload is [2] + } + + #[test] + fn err_full_window() { + // standard handshake + let packet_handshake1 = new_raw(&CLIENT, &SERVER, 0, 0, false, true, false, false, &[]); + let mut connection = TcpConnection::try_new(&packet_handshake1).unwrap(); + let mut packet_handshake2 = new_raw(&SERVER, &CLIENT, 0, 1, true, true, false, false, &[]); + raw_packet_set_window(&mut packet_handshake2, 5, 0); // the server will receive 5 bytes, so the client can send 5 byte as most. + connection.update(&packet_handshake2).unwrap(); + let packet_handshake3 = new_raw(&CLIENT, &SERVER, 1, 1, true, false, false, false, &[]); + connection.update(&packet_handshake3).unwrap(); + + let client_send_in_established = new_raw(&CLIENT, &SERVER, 1, 1, true, false, false, false, &[1, 2, 3, 4, 5]); + let ret = connection.update(&client_send_in_established).unwrap(); + assert!(ret == TcpSessionOk::OTHER); + let client_send_another = new_raw(&CLIENT, &SERVER, 6, 1, true, false, false, false, &[6]); + println!("czzzzzzzz"); + let ret = connection.update(&client_send_another); + println!("czzzzz the ret is {:?}", ret); + assert!(ret == Err(TcpSessionErr::SidePeerWindowFull)); + + // response will clear the window + let server_response = new_raw(&SERVER, &CLIENT, 1, 6, true, false, false, false, &[]); + let ret = connection.update(&server_response).unwrap(); + assert!(ret == TcpSessionOk::ACK_SEGMENT); + let client_send_another = new_raw(&CLIENT, &SERVER, 6, 1, true, false, false, false, &[6]); + let ret = connection.update(&client_send_another).unwrap(); + assert!(ret == TcpSessionOk::OTHER); + } + + // todo: many error flag + // TODO: 回绕 + + +} \ No newline at end of file -- cgit v1.2.3