diff options
Diffstat (limited to 'src/net/tcp.rs')
| -rw-r--r-- | src/net/tcp.rs | 173 |
1 files changed, 173 insertions, 0 deletions
diff --git a/src/net/tcp.rs b/src/net/tcp.rs new file mode 100644 index 0000000..94b876a --- /dev/null +++ b/src/net/tcp.rs @@ -0,0 +1,173 @@ +use std::{ + cell::RefCell, + io::{self, Read, Write}, + net::{SocketAddr, TcpListener as StdTcpListener, TcpStream as StdTcpStream, ToSocketAddrs}, + os::unix::prelude::AsRawFd, + rc::{Rc, Weak}, + task::Poll, +}; + +use futures::Stream; +use socket2::{Domain, Protocol, Socket, Type}; + +use crate::{reactor::get_reactor, reactor::Reactor}; + +/// TCP 监听器 +#[derive(Debug)] +pub struct TcpListener { + reactor: Weak<RefCell<Reactor>>, // reactor + listener: StdTcpListener, // 标准库的 TcpListener | 包装一层 +} + +impl TcpListener { + /// 绑定地址并返回一个 `TcpListener` 实例 + pub fn bind<A: ToSocketAddrs>(addr: A) -> Result<Self, io::Error> { + // 解析地址 + let addr = addr + .to_socket_addrs()? + .next() + .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "empty address"))?; + + // 创建 socket + let domain = if addr.is_ipv6() { + Domain::IPV6 + } else { + Domain::IPV4 + }; + let sk = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?; + + // 绑定地址并监听 + let addr = socket2::SockAddr::from(addr); + sk.set_reuse_address(true)?; + sk.bind(&addr)?; + sk.listen(1024)?; + + // 将 fd 添加到 reactor 中 + let reactor = get_reactor(); + reactor.borrow_mut().add(sk.as_raw_fd()); + + println!("tcp bind with fd {}", sk.as_raw_fd()); + Ok(Self { + reactor: Rc::downgrade(&reactor), + listener: sk.into(), + }) + } +} +//Stream 流 +impl Stream for TcpListener { //TcpStream 和 TcpListener 在这链接 + type Item = std::io::Result<(TcpStream, SocketAddr)>; // + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll<Option<Self::Item>> { + match self.listener.accept() { + Ok((stream, addr)) => Poll::Ready(Some(Ok((stream.into(), addr)))), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { // 继续阻塞 + // 修改反应器以注册感兴趣的事件 + let reactor = self.reactor.upgrade().unwrap(); + reactor + .borrow_mut() + .modify_readable(self.listener.as_raw_fd(), cx); // 可读事件 + Poll::Pending + } + Err(e) => std::task::Poll::Ready(Some(Err(e))), + } + } +} + +/// TCP 流 +#[derive(Debug)] +pub struct TcpStream { + stream: StdTcpStream, // 标准库的 TcpStream | 包装一层 +} + +impl From<StdTcpStream> for TcpStream { + // 从标准库的 TcpStream 转换为自定义的 TcpStream + fn from(stream: StdTcpStream) -> Self { + let reactor: Rc<RefCell<Reactor>> = get_reactor(); // 获取 reactor + reactor.borrow_mut().add(stream.as_raw_fd()); // 将 fd 添加到 reactor + Self { stream } // 返回包装后的 TcpStream + } +} + +impl Drop for TcpStream { + fn drop(&mut self) { + // 可变引用 + println!("drop"); + let reactor = get_reactor(); // 获取 reactor + reactor.borrow_mut().delete(self.stream.as_raw_fd()); // 将 fd 从 reactor 中删除 + } +} + +// 为 TcpStream 实现 tokio::io::AsyncRead +impl tokio::io::AsyncRead for TcpStream { + fn poll_read( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + let fd = self.stream.as_raw_fd(); // 获取 stream 对应的 fd + unsafe { + // 将 ReadBuf 转换为 [u8] , stream.read 需要 + let b = &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]); + println!("read for fd {}", fd); + match self.stream.read(b) { + Ok(n) => { // 读取成功 + println!("read for fd {} done, {}", fd, n); + buf.assume_init(n); // 初始化 n 个字节 + buf.advance(n); // 指针前进 n 个字节 + Poll::Ready(Ok(())) // 返回结果 + } + Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { // 读取失败,且错误类型为 WouldBlock + println!("read for fd {} done WouldBlock", fd); + // 修改反应器以注册感兴趣的事件 + let reactor = get_reactor(); // 获取 reactor + reactor + .borrow_mut() + .modify_readable(self.stream.as_raw_fd(), cx); // 注册到 reactor 可读事件 + Poll::Pending // 等待 + } + Err(e) => { + println!("read for fd {} done err", fd); + Poll::Ready(Err(e)) // 结束(错误) + } + } + } + } +} + +impl tokio::io::AsyncWrite for TcpStream { + fn poll_write( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> Poll<Result<usize, io::Error>> { + match self.stream.write(buf) { + Ok(n) => Poll::Ready(Ok(n)), // 写入成功 + Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { // 写入失败,且错误类型为 WouldBlock + let reactor: Rc<RefCell<Reactor>> = get_reactor(); + reactor + .borrow_mut() + .modify_writable(self.stream.as_raw_fd(), cx); // 注册可写事件 + Poll::Pending + } + Err(e) => Poll::Ready(Err(e)), // 写入失败,返回结果(错误信息) + } + } + + fn poll_flush( // 或许是仅占位吧. + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll<Result<(), io::Error>> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown( // 关闭时,将 stream 的写入关闭 + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll<Result<(), io::Error>> { + self.stream.shutdown(std::net::Shutdown::Write)?; // 关闭写入,出错则传递错误 + Poll::Ready(Ok(())) + } +} |
