diff options
Diffstat (limited to 'src/net')
| -rw-r--r-- | src/net/mod.rs | 2 | ||||
| -rw-r--r-- | src/net/tcp.rs | 173 | ||||
| -rw-r--r-- | src/net/udp.rs | 317 |
3 files changed, 492 insertions, 0 deletions
diff --git a/src/net/mod.rs b/src/net/mod.rs new file mode 100644 index 0000000..a6f6844 --- /dev/null +++ b/src/net/mod.rs @@ -0,0 +1,2 @@ +pub mod tcp; +pub mod udp;
\ No newline at end of file diff --git a/src/net/tcp.rs b/src/net/tcp.rs new file mode 100644 index 0000000..94b876a --- /dev/null +++ b/src/net/tcp.rs @@ -0,0 +1,173 @@ +use std::{ + cell::RefCell, + io::{self, Read, Write}, + net::{SocketAddr, TcpListener as StdTcpListener, TcpStream as StdTcpStream, ToSocketAddrs}, + os::unix::prelude::AsRawFd, + rc::{Rc, Weak}, + task::Poll, +}; + +use futures::Stream; +use socket2::{Domain, Protocol, Socket, Type}; + +use crate::{reactor::get_reactor, reactor::Reactor}; + +/// TCP 监听器 +#[derive(Debug)] +pub struct TcpListener { + reactor: Weak<RefCell<Reactor>>, // reactor + listener: StdTcpListener, // 标准库的 TcpListener | 包装一层 +} + +impl TcpListener { + /// 绑定地址并返回一个 `TcpListener` 实例 + pub fn bind<A: ToSocketAddrs>(addr: A) -> Result<Self, io::Error> { + // 解析地址 + let addr = addr + .to_socket_addrs()? + .next() + .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "empty address"))?; + + // 创建 socket + let domain = if addr.is_ipv6() { + Domain::IPV6 + } else { + Domain::IPV4 + }; + let sk = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?; + + // 绑定地址并监听 + let addr = socket2::SockAddr::from(addr); + sk.set_reuse_address(true)?; + sk.bind(&addr)?; + sk.listen(1024)?; + + // 将 fd 添加到 reactor 中 + let reactor = get_reactor(); + reactor.borrow_mut().add(sk.as_raw_fd()); + + println!("tcp bind with fd {}", sk.as_raw_fd()); + Ok(Self { + reactor: Rc::downgrade(&reactor), + listener: sk.into(), + }) + } +} +//Stream 流 +impl Stream for TcpListener { //TcpStream 和 TcpListener 在这链接 + type Item = std::io::Result<(TcpStream, SocketAddr)>; // + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll<Option<Self::Item>> { + match self.listener.accept() { + Ok((stream, addr)) => Poll::Ready(Some(Ok((stream.into(), addr)))), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { // 继续阻塞 + // 修改反应器以注册感兴趣的事件 + let reactor = self.reactor.upgrade().unwrap(); + reactor + .borrow_mut() + .modify_readable(self.listener.as_raw_fd(), cx); // 可读事件 + Poll::Pending + } + Err(e) => std::task::Poll::Ready(Some(Err(e))), + } + } +} + +/// TCP 流 +#[derive(Debug)] +pub struct TcpStream { + stream: StdTcpStream, // 标准库的 TcpStream | 包装一层 +} + +impl From<StdTcpStream> for TcpStream { + // 从标准库的 TcpStream 转换为自定义的 TcpStream + fn from(stream: StdTcpStream) -> Self { + let reactor: Rc<RefCell<Reactor>> = get_reactor(); // 获取 reactor + reactor.borrow_mut().add(stream.as_raw_fd()); // 将 fd 添加到 reactor + Self { stream } // 返回包装后的 TcpStream + } +} + +impl Drop for TcpStream { + fn drop(&mut self) { + // 可变引用 + println!("drop"); + let reactor = get_reactor(); // 获取 reactor + reactor.borrow_mut().delete(self.stream.as_raw_fd()); // 将 fd 从 reactor 中删除 + } +} + +// 为 TcpStream 实现 tokio::io::AsyncRead +impl tokio::io::AsyncRead for TcpStream { + fn poll_read( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + let fd = self.stream.as_raw_fd(); // 获取 stream 对应的 fd + unsafe { + // 将 ReadBuf 转换为 [u8] , stream.read 需要 + let b = &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]); + println!("read for fd {}", fd); + match self.stream.read(b) { + Ok(n) => { // 读取成功 + println!("read for fd {} done, {}", fd, n); + buf.assume_init(n); // 初始化 n 个字节 + buf.advance(n); // 指针前进 n 个字节 + Poll::Ready(Ok(())) // 返回结果 + } + Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { // 读取失败,且错误类型为 WouldBlock + println!("read for fd {} done WouldBlock", fd); + // 修改反应器以注册感兴趣的事件 + let reactor = get_reactor(); // 获取 reactor + reactor + .borrow_mut() + .modify_readable(self.stream.as_raw_fd(), cx); // 注册到 reactor 可读事件 + Poll::Pending // 等待 + } + Err(e) => { + println!("read for fd {} done err", fd); + Poll::Ready(Err(e)) // 结束(错误) + } + } + } + } +} + +impl tokio::io::AsyncWrite for TcpStream { + fn poll_write( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> Poll<Result<usize, io::Error>> { + match self.stream.write(buf) { + Ok(n) => Poll::Ready(Ok(n)), // 写入成功 + Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { // 写入失败,且错误类型为 WouldBlock + let reactor: Rc<RefCell<Reactor>> = get_reactor(); + reactor + .borrow_mut() + .modify_writable(self.stream.as_raw_fd(), cx); // 注册可写事件 + Poll::Pending + } + Err(e) => Poll::Ready(Err(e)), // 写入失败,返回结果(错误信息) + } + } + + fn poll_flush( // 或许是仅占位吧. + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll<Result<(), io::Error>> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown( // 关闭时,将 stream 的写入关闭 + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll<Result<(), io::Error>> { + self.stream.shutdown(std::net::Shutdown::Write)?; // 关闭写入,出错则传递错误 + Poll::Ready(Ok(())) + } +} diff --git a/src/net/udp.rs b/src/net/udp.rs new file mode 100644 index 0000000..0a6a391 --- /dev/null +++ b/src/net/udp.rs @@ -0,0 +1,317 @@ +use std::cell::RefCell; +use std::io; +use std::net::{SocketAddr, ToSocketAddrs, UdpSocket as StdUdpSocket}; +use std::os::unix::prelude::AsRawFd; +use std::rc::{Rc, Weak}; +use std::task::Poll; + +use futures::{Future, Stream}; +use socket2::{Domain, Protocol, SockAddr, Socket, Type}; + +use crate::reactor::{get_reactor, Reactor}; + +/// UDP 套接字 +#[derive(Debug)] +pub struct UdpSocket { + reactor: Weak<RefCell<Reactor>>, // reactor + socket: StdUdpSocket, // 标准库的 UdpSocket | 包装一层 +} + +impl UdpSocket { + /// 绑定地址并返回一个 `UdpSocket` 实例 + pub fn bind<A: ToSocketAddrs>(addr: A) -> Result<Self, io::Error> { + // 解析地址 + let addr = addr + .to_socket_addrs()? + .next() + .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "empty address"))?; + + // 地址解析 + let domain = if addr.is_ipv6() { + Domain::IPV6 + } else { + Domain::IPV4 + }; + let sk = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?; + + // 绑定地址 + let addr: SockAddr = SockAddr::from(addr); + sk.bind(&addr)?; + + // 将 fd 添加到 reactor 中 + let reactor = get_reactor(); + reactor.borrow_mut().add(sk.as_raw_fd()); + + println!("udp bind with fd {}", sk.as_raw_fd()); + Ok(Self { + reactor: Rc::downgrade(&reactor), + socket: sk.into(), + }) + } + + pub fn recv_from_async<'a>( + &'a self, + buf: &'a mut [u8], + ) -> impl Future<Output = io::Result<(usize, SocketAddr)>> + 'a { + struct RecvFromFuture<'a> { + socket: &'a StdUdpSocket, + buf: &'a mut [u8], + } + impl<'a> Future for RecvFromFuture<'a> { + type Output = io::Result<(usize, SocketAddr)>; + + fn poll( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll<Self::Output> { + match self.socket.recv_from(self.buf) { + Ok((n, addr)) => { + println!("recv_from {} {} bytes", addr, n); + Poll::Ready(Ok((n, addr))) + } + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + // 读取失败,且错误类型为 WouldBlock + println!("read for fd {} done WouldBlock", self.socket.as_raw_fd()); + // 修改反应器以注册感兴趣的事件 + let reactor = get_reactor(); // 获取 reactor + reactor + .borrow_mut() + .modify_readable(self.socket.as_raw_fd(), cx); // 注册到 reactor 可读事件 + Poll::Pending // 等待 + } + Err(e) => { + println!("read for fd {} done err", self.socket.as_raw_fd()); + Poll::Ready(Err(e)) // 结束(错误) + } + } + } + } + RecvFromFuture { + socket: &self.socket, + buf, + } + } + + pub fn send_to_async<'a>( + &'a self, + buf: &'a [u8], + addr: &'a SocketAddr, + ) -> impl Future<Output = io::Result<usize>> + 'a { + struct SendToFuture<'a> { + socket: &'a StdUdpSocket, + buf: &'a [u8], + addr: &'a SocketAddr, + } + impl<'a> Future for SendToFuture<'a> { + type Output = io::Result<usize>; + + fn poll( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll<Self::Output> { + match self.socket.send_to(self.buf, self.addr) { + Ok(n) => Poll::Ready(Ok(n)), // 写入成功 + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + // 写入失败,且错误类型为 WouldBlock + let reactor: Rc<RefCell<Reactor>> = get_reactor(); + reactor + .borrow_mut() + .modify_writable(self.socket.as_raw_fd(), cx); // 注册可写事件 + Poll::Pending + } + Err(e) => Poll::Ready(Err(e)), // 写入失败,返回结果(错误信息) + } + } + } + SendToFuture { + socket: &self.socket, + buf, + addr, + } + } +} + +// impl From<StdUdpSocket> for UdpSocket { +// fn from(socket: StdUdpSocket) -> Self { +// Self { +// reactor: Weak::new(), +// socket, +// } +// } +// } + +impl Drop for UdpSocket { + fn drop(&mut self) { + println!("drop udp socket"); + if let Some(reactor) = self.reactor.upgrade() { + // 从 reactor 中移除 + reactor.borrow_mut().delete(self.socket.as_raw_fd()); + } + } +} + +// impl Future for UdpSocket { +// type Output = io::Result<(usize, SocketAddr)>; + +// fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> { +// let mut buf = [0; 1024]; +// match self.socket.recv_from(&mut buf) { +// Ok((n, addr)) => { +// println!("recv_from {} {} bytes", addr, n); +// Poll::Ready(Ok((n, addr))) +// } +// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { +// // 读取失败,且错误类型为 WouldBlock +// println!("read for fd {} done WouldBlock", self.socket.as_raw_fd()); +// // 修改反应器以注册感兴趣的事件 +// let reactor = get_reactor(); // 获取 reactor +// reactor +// .borrow_mut() +// .modify_readable(self.socket.as_raw_fd(), cx); // 注册到 reactor 可读事件 +// Poll::Pending // 等待 +// } +// Err(e) => { +// println!("read for fd {} done err", self.socket.as_raw_fd()); +// Poll::Ready(Err(e)) // 结束(错误) +// } +// } +// } +// } + +// impl Stream for UdpSocket { +// type Item = io::Result<(usize, SocketAddr)>; + +// fn poll_next( +// self: std::pin::Pin<&mut Self>, +// cx: &mut std::task::Context<'_>, +// ) -> std::task::Poll<Option<Self::Item>> { +// match self.socket.recv_from(&mut self.buf) { +// Ok((n, addr)) => { +// println!("recv_from {} {} bytes", addr, n); +// Poll::Ready(Some(Ok((n, addr)))) +// } +// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { +// // 读取失败,且错误类型为 WouldBlock +// println!("read for fd {} done WouldBlock", self.socket.as_raw_fd()); +// // 修改反应器以注册感兴趣的事件 +// let reactor = get_reactor(); // 获取 reactor +// reactor +// .borrow_mut() +// .modify_readable(self.socket.as_raw_fd(), cx); // 注册到 reactor 可读事件 +// Poll::Pending // 等待 +// } +// Err(e) => { +// println!("read for fd {} done err", self.socket.as_raw_fd()); +// Poll::Ready(Some(Err(e))) // 结束(错误) +// } +// } +// } +// } + +// impl UdpSocket{ +// pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { +// let mut buf = tokio::io::ReadBuf::new(buf); +// self.read(&mut buf).await?; +// Ok((buf.filled().len(), buf.filled().len())) +// } +// } + +// impl Future for UdpSocket { +// type Output = io::Result<(Vec<u8>, SocketAddr)>; + +// fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> { +// let mut buf = vec![0u8; 1024]; +// match self.socket.recv_from(&mut buf) { +// Ok((n, addr)) => { +// println!("recv_from {} {} bytes", addr, n); +// Poll::Ready(Ok((buf, addr))) +// } +// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { +// // 读取失败,且错误类型为 WouldBlock +// println!("read for fd {} done WouldBlock", self.socket.as_raw_fd()); +// // 修改反应器以注册感兴趣的事件 +// let reactor = get_reactor(); // 获取 reactor +// reactor +// .borrow_mut() +// .modify_readable(self.socket.as_raw_fd(), cx); // 注册到 reactor 可读事件 +// Poll::Pending // 等待 +// } +// Err(e) => { +// println!("read for fd {} done err", self.socket.as_raw_fd()); +// Poll::Ready(Err(e)) // 结束(错误) +// } +// } +// } +// } + +// // 为 UdpSocket 实现 tokio::io::AsyncWrite +// impl tokio::io::AsyncWrite for UdpSocket { +// fn poll_write( +// mut self: std::pin::Pin<&mut Self>, +// cx: &mut std::task::Context<'_>, +// buf: &[u8], +// ) -> Poll<Result<usize, io::Error>> { +// match self.send_to(buf, &"127.0.0.1:8080".parse().unwrap()) { +// Ok(n) => Poll::Ready(Ok(n)), // 写入成功 +// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { // 写入失败,且错误类型为 WouldBlock +// let reactor: Rc<RefCell<Reactor>> = get_reactor(); +// reactor +// .borrow_mut() +// .modify_writable(self.socket.as_raw_fd(), cx); // 注册可写事件 +// Poll::Pending +// } +// Err(e) => Poll::Ready(Err(e)), // 写入失败,返回结果(错误信息) +// } +// } + +// fn poll_flush( +// self: std::pin::Pin<&mut Self>, +// cx: &mut std::task::Context<'_>, +// ) -> Poll<Result<(), io::Error>> { +// Poll::Ready(Ok(())) +// } + +// fn poll_shutdown( +// self: std::pin::Pin<&mut Self>, +// cx: &mut std::task::Context<'_>, +// ) -> Poll<Result<(), io::Error>> { +// Poll::Ready(Ok(())) +// } +// } + +// 为 UdpSocket 实现 tokio::io::AsyncRead +// impl tokio::io::AsyncRead for UdpSocket { +// fn poll_read( +// mut self: std::pin::Pin<&mut Self>, +// cx: &mut std::task::Context<'_>, +// buf: &mut tokio::io::ReadBuf<'_>, +// ) -> Poll<Result<(), std::io::Error>> { +// let fd: i32 = self.socket.as_raw_fd(); // 获取 socket 对应的 fd +// unsafe { +// // 将 ReadBuf 转换为 [u8] , socket.recv_from 需要 +// let b = &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]); +// println!("read for fd {}", fd); +// match self.socket.recv_from(b) { +// Ok((size, addr)) => { +// // 读取成功 +// println!("read for fd {} done, {}", fd, size); +// Poll::Ready(Ok((size, addr))) +// } +// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { +// // 读取失败,且错误类型为 WouldBlock +// println!("read for fd {} done WouldBlock", fd); +// // 修改反应器以注册感兴趣的事件 +// let reactor = get_reactor(); // 获取 reactor +// reactor +// .borrow_mut() +// .modify_readable(self.socket.as_raw_fd(), cx); // 注册到 reactor 可读事件 +// Poll::Pending // 等待 +// } +// Err(e) => { +// println!("read for fd {} done err", fd); +// Poll::Ready(Err(e)) // 结束(错误) +// } +// } +// } +// } +// } |
