summaryrefslogtreecommitdiff
path: root/src/net/tcp.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/net/tcp.rs')
-rw-r--r--src/net/tcp.rs173
1 files changed, 173 insertions, 0 deletions
diff --git a/src/net/tcp.rs b/src/net/tcp.rs
new file mode 100644
index 0000000..94b876a
--- /dev/null
+++ b/src/net/tcp.rs
@@ -0,0 +1,173 @@
+use std::{
+ cell::RefCell,
+ io::{self, Read, Write},
+ net::{SocketAddr, TcpListener as StdTcpListener, TcpStream as StdTcpStream, ToSocketAddrs},
+ os::unix::prelude::AsRawFd,
+ rc::{Rc, Weak},
+ task::Poll,
+};
+
+use futures::Stream;
+use socket2::{Domain, Protocol, Socket, Type};
+
+use crate::{reactor::get_reactor, reactor::Reactor};
+
+/// TCP 监听器
+#[derive(Debug)]
+pub struct TcpListener {
+ reactor: Weak<RefCell<Reactor>>, // reactor
+ listener: StdTcpListener, // 标准库的 TcpListener | 包装一层
+}
+
+impl TcpListener {
+ /// 绑定地址并返回一个 `TcpListener` 实例
+ pub fn bind<A: ToSocketAddrs>(addr: A) -> Result<Self, io::Error> {
+ // 解析地址
+ let addr = addr
+ .to_socket_addrs()?
+ .next()
+ .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "empty address"))?;
+
+ // 创建 socket
+ let domain = if addr.is_ipv6() {
+ Domain::IPV6
+ } else {
+ Domain::IPV4
+ };
+ let sk = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?;
+
+ // 绑定地址并监听
+ let addr = socket2::SockAddr::from(addr);
+ sk.set_reuse_address(true)?;
+ sk.bind(&addr)?;
+ sk.listen(1024)?;
+
+ // 将 fd 添加到 reactor 中
+ let reactor = get_reactor();
+ reactor.borrow_mut().add(sk.as_raw_fd());
+
+ println!("tcp bind with fd {}", sk.as_raw_fd());
+ Ok(Self {
+ reactor: Rc::downgrade(&reactor),
+ listener: sk.into(),
+ })
+ }
+}
+//Stream 流
+impl Stream for TcpListener { //TcpStream 和 TcpListener 在这链接
+ type Item = std::io::Result<(TcpStream, SocketAddr)>; //
+
+ fn poll_next(
+ self: std::pin::Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ ) -> std::task::Poll<Option<Self::Item>> {
+ match self.listener.accept() {
+ Ok((stream, addr)) => Poll::Ready(Some(Ok((stream.into(), addr)))),
+ Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { // 继续阻塞
+ // 修改反应器以注册感兴趣的事件
+ let reactor = self.reactor.upgrade().unwrap();
+ reactor
+ .borrow_mut()
+ .modify_readable(self.listener.as_raw_fd(), cx); // 可读事件
+ Poll::Pending
+ }
+ Err(e) => std::task::Poll::Ready(Some(Err(e))),
+ }
+ }
+}
+
+/// TCP 流
+#[derive(Debug)]
+pub struct TcpStream {
+ stream: StdTcpStream, // 标准库的 TcpStream | 包装一层
+}
+
+impl From<StdTcpStream> for TcpStream {
+ // 从标准库的 TcpStream 转换为自定义的 TcpStream
+ fn from(stream: StdTcpStream) -> Self {
+ let reactor: Rc<RefCell<Reactor>> = get_reactor(); // 获取 reactor
+ reactor.borrow_mut().add(stream.as_raw_fd()); // 将 fd 添加到 reactor
+ Self { stream } // 返回包装后的 TcpStream
+ }
+}
+
+impl Drop for TcpStream {
+ fn drop(&mut self) {
+ // 可变引用
+ println!("drop");
+ let reactor = get_reactor(); // 获取 reactor
+ reactor.borrow_mut().delete(self.stream.as_raw_fd()); // 将 fd 从 reactor 中删除
+ }
+}
+
+// 为 TcpStream 实现 tokio::io::AsyncRead
+impl tokio::io::AsyncRead for TcpStream {
+ fn poll_read(
+ mut self: std::pin::Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ buf: &mut tokio::io::ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
+ let fd = self.stream.as_raw_fd(); // 获取 stream 对应的 fd
+ unsafe {
+ // 将 ReadBuf 转换为 [u8] , stream.read 需要
+ let b = &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]);
+ println!("read for fd {}", fd);
+ match self.stream.read(b) {
+ Ok(n) => { // 读取成功
+ println!("read for fd {} done, {}", fd, n);
+ buf.assume_init(n); // 初始化 n 个字节
+ buf.advance(n); // 指针前进 n 个字节
+ Poll::Ready(Ok(())) // 返回结果
+ }
+ Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { // 读取失败,且错误类型为 WouldBlock
+ println!("read for fd {} done WouldBlock", fd);
+ // 修改反应器以注册感兴趣的事件
+ let reactor = get_reactor(); // 获取 reactor
+ reactor
+ .borrow_mut()
+ .modify_readable(self.stream.as_raw_fd(), cx); // 注册到 reactor 可读事件
+ Poll::Pending // 等待
+ }
+ Err(e) => {
+ println!("read for fd {} done err", fd);
+ Poll::Ready(Err(e)) // 结束(错误)
+ }
+ }
+ }
+ }
+}
+
+impl tokio::io::AsyncWrite for TcpStream {
+ fn poll_write(
+ mut self: std::pin::Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ buf: &[u8],
+ ) -> Poll<Result<usize, io::Error>> {
+ match self.stream.write(buf) {
+ Ok(n) => Poll::Ready(Ok(n)), // 写入成功
+ Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { // 写入失败,且错误类型为 WouldBlock
+ let reactor: Rc<RefCell<Reactor>> = get_reactor();
+ reactor
+ .borrow_mut()
+ .modify_writable(self.stream.as_raw_fd(), cx); // 注册可写事件
+ Poll::Pending
+ }
+ Err(e) => Poll::Ready(Err(e)), // 写入失败,返回结果(错误信息)
+ }
+ }
+
+ fn poll_flush( // 或许是仅占位吧.
+ self: std::pin::Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ ) -> Poll<Result<(), io::Error>> {
+ Poll::Ready(Ok(()))
+ }
+
+ fn poll_shutdown( // 关闭时,将 stream 的写入关闭
+ self: std::pin::Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ ) -> Poll<Result<(), io::Error>> {
+ self.stream.shutdown(std::net::Shutdown::Write)?; // 关闭写入,出错则传递错误
+ Poll::Ready(Ok(()))
+ }
+}