summaryrefslogtreecommitdiff
path: root/src/tcp.rs
blob: 94b876a9ce48fcf4c172a7325dd0a6d3edde981f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
use std::{
    cell::RefCell,
    io::{self, Read, Write},
    net::{SocketAddr, TcpListener as StdTcpListener, TcpStream as StdTcpStream, ToSocketAddrs},
    os::unix::prelude::AsRawFd,
    rc::{Rc, Weak},
    task::Poll,
};

use futures::Stream;
use socket2::{Domain, Protocol, Socket, Type};

use crate::{reactor::get_reactor, reactor::Reactor};

/// TCP listener whose listening fd is registered with the reactor returned by
/// `get_reactor()` at `bind` time.
#[derive(Debug)]
pub struct TcpListener {
    reactor: Weak<RefCell<Reactor>>, // weak handle to the reactor the listening fd was registered with in `bind`
    listener: StdTcpListener, // the wrapped standard-library listener
}

impl TcpListener {
    /// Binds a listening TCP socket to the first address `addr` resolves to and
    /// registers its fd with the reactor returned by `get_reactor()`.
    ///
    /// The socket gets `SO_REUSEADDR`, a listen backlog of 1024, and is switched
    /// to non-blocking mode. That way `accept` in `poll_next` reports `WouldBlock`
    /// instead of stalling the executor thread.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - address resolution fails;
    /// - `addr` resolves to no addresses (`ErrorKind::Other`, "empty address");
    /// - any socket operation fails (create, setsockopt, bind, listen, or
    ///   switching to non-blocking mode).
    pub fn bind<A: ToSocketAddrs>(addr: A) -> Result<Self, io::Error> {
        // Only the first resolved address is used.
        let addr = addr
            .to_socket_addrs()?
            .next()
            .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "empty address"))?;

        // Choose the address family matching the resolved address.
        let domain = if addr.is_ipv6() {
            Domain::IPV6
        } else {
            Domain::IPV4
        };
        let sk = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?;

        // Bind and start listening.
        let addr = socket2::SockAddr::from(addr);
        sk.set_reuse_address(true)?;
        sk.bind(&addr)?;
        sk.listen(1024)?;
        // `poll_next` depends on `accept` returning `WouldBlock`. Set the mode
        // explicitly rather than relying on how the reactor registers the fd.
        sk.set_nonblocking(true)?;

        // Register the listening fd with the reactor.
        let reactor = get_reactor();
        reactor.borrow_mut().add(sk.as_raw_fd());

        println!("tcp bind with fd {}", sk.as_raw_fd());
        Ok(Self {
            reactor: Rc::downgrade(&reactor),
            listener: sk.into(),
        })
    }
}

impl Drop for TcpListener {
    /// Removes the listening fd from the reactor before the socket is closed,
    /// mirroring what `Drop for TcpStream` does for connection fds.
    ///
    /// If the reactor has already been dropped there is nothing to unregister.
    fn drop(&mut self) {
        if let Some(reactor) = self.reactor.upgrade() {
            reactor.borrow_mut().delete(self.listener.as_raw_fd());
        }
    }
}
//Stream 流
impl Stream for TcpListener { //TcpStream 和 TcpListener 在这链接
    type Item = std::io::Result<(TcpStream, SocketAddr)>; // 

    fn poll_next(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        match self.listener.accept() {
            Ok((stream, addr)) => Poll::Ready(Some(Ok((stream.into(), addr)))),
            Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { // 继续阻塞
                // 修改反应器以注册感兴趣的事件
                let reactor = self.reactor.upgrade().unwrap();
                reactor
                    .borrow_mut()
                    .modify_readable(self.listener.as_raw_fd(), cx); // 可读事件
                Poll::Pending
            }
            Err(e) => std::task::Poll::Ready(Some(Err(e))),
        }
    }
}

/// TCP stream wrapper. Its fd is registered with the reactor when built via
/// `From<StdTcpStream>` and removed from the reactor when dropped.
#[derive(Debug)]
pub struct TcpStream {
    stream: StdTcpStream, // the wrapped standard-library stream
}

impl From<StdTcpStream> for TcpStream {
    // 从标准库的 TcpStream 转换为自定义的 TcpStream
    fn from(stream: StdTcpStream) -> Self {
        let reactor: Rc<RefCell<Reactor>> = get_reactor(); // 获取 reactor
        reactor.borrow_mut().add(stream.as_raw_fd()); // 将 fd 添加到 reactor
        Self { stream } // 返回包装后的 TcpStream
    }
}

impl Drop for TcpStream {
    /// Prints a debug marker and removes this stream's fd from the reactor
    /// before the underlying socket is closed.
    fn drop(&mut self) {
        println!("drop");
        let fd = self.stream.as_raw_fd();
        get_reactor().borrow_mut().delete(fd);
    }
}

// `tokio::io::AsyncRead` implementation for `TcpStream`.
impl tokio::io::AsyncRead for TcpStream {
    /// Attempts one `read` from the underlying socket into `buf`.
    ///
    /// - `Ok(n)`: the `n` bytes are appended to `buf`'s filled region and
    ///   `Ready(Ok(()))` is returned. With a non-empty buffer, `n == 0` means
    ///   the peer closed the connection.
    /// - `WouldBlock`: this fd and `cx` are passed to the reactor's
    ///   `modify_readable` and `Pending` is returned.
    /// - Any other error is returned as `Ready(Err(e))`.
    fn poll_read(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let fd = self.stream.as_raw_fd();
        println!("read for fd {}", fd);

        // `initialize_unfilled` zero-fills any not-yet-initialised tail of the
        // buffer and returns a safe `&mut [u8]`. Casting the
        // `[MaybeUninit<u8>]` slice to `[u8]` instead (as before) is undefined
        // behaviour, because `Read::read` may observe uninitialised bytes.
        let dst = buf.initialize_unfilled();
        match self.stream.read(dst) {
            Ok(n) => {
                println!("read for fd {} done, {}", fd, n);
                // The whole unfilled region is now initialised, so marking
                // `n <= dst.len()` bytes as filled is sound.
                buf.advance(n);
                Poll::Ready(Ok(()))
            }
            Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                println!("read for fd {} done WouldBlock", fd);
                // No data yet: ask the reactor to wake us when the fd is readable.
                let reactor = get_reactor();
                reactor.borrow_mut().modify_readable(fd, cx);
                Poll::Pending
            }
            Err(e) => {
                println!("read for fd {} done err", fd);
                Poll::Ready(Err(e))
            }
        }
    }
}

impl tokio::io::AsyncWrite for TcpStream {
    /// Attempts one `write` of `buf` to the underlying socket.
    ///
    /// On `WouldBlock`, this fd and `cx` are passed to the reactor's
    /// `modify_writable` and `Pending` is returned. Every other outcome,
    /// success or error, is returned immediately as `Ready`.
    fn poll_write(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, io::Error>> {
        let outcome = self.stream.write(buf);
        match outcome {
            Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => {
                let fd = self.stream.as_raw_fd();
                get_reactor().borrow_mut().modify_writable(fd, cx);
                Poll::Pending
            }
            finished => Poll::Ready(finished),
        }
    }

    /// Writes go straight to the socket with no userspace buffering in this
    /// type, so there is nothing to flush.
    fn poll_flush(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> Poll<Result<(), io::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Shuts down the write half of the socket and reports the result.
    fn poll_shutdown(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> Poll<Result<(), io::Error>> {
        Poll::Ready(self.stream.shutdown(std::net::Shutdown::Write))
    }
}