summaryrefslogtreecommitdiff
path: root/src/net
diff options
context:
space:
mode:
Diffstat (limited to 'src/net')
-rw-r--r--src/net/mod.rs2
-rw-r--r--src/net/tcp.rs173
-rw-r--r--src/net/udp.rs317
3 files changed, 492 insertions, 0 deletions
diff --git a/src/net/mod.rs b/src/net/mod.rs
new file mode 100644
index 0000000..a6f6844
--- /dev/null
+++ b/src/net/mod.rs
@@ -0,0 +1,2 @@
+pub mod tcp;
+pub mod udp; \ No newline at end of file
diff --git a/src/net/tcp.rs b/src/net/tcp.rs
new file mode 100644
index 0000000..94b876a
--- /dev/null
+++ b/src/net/tcp.rs
@@ -0,0 +1,173 @@
+use std::{
+ cell::RefCell,
+ io::{self, Read, Write},
+ net::{SocketAddr, TcpListener as StdTcpListener, TcpStream as StdTcpStream, ToSocketAddrs},
+ os::unix::prelude::AsRawFd,
+ rc::{Rc, Weak},
+ task::Poll,
+};
+
+use futures::Stream;
+use socket2::{Domain, Protocol, Socket, Type};
+
+use crate::{reactor::get_reactor, reactor::Reactor};
+
+/// TCP 监听器
+#[derive(Debug)]
+pub struct TcpListener {
+ reactor: Weak<RefCell<Reactor>>, // reactor
+ listener: StdTcpListener, // 标准库的 TcpListener | 包装一层
+}
+
+impl TcpListener {
+ /// 绑定地址并返回一个 `TcpListener` 实例
+ pub fn bind<A: ToSocketAddrs>(addr: A) -> Result<Self, io::Error> {
+ // 解析地址
+ let addr = addr
+ .to_socket_addrs()?
+ .next()
+ .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "empty address"))?;
+
+ // 创建 socket
+ let domain = if addr.is_ipv6() {
+ Domain::IPV6
+ } else {
+ Domain::IPV4
+ };
+ let sk = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?;
+
+ // 绑定地址并监听
+ let addr = socket2::SockAddr::from(addr);
+ sk.set_reuse_address(true)?;
+ sk.bind(&addr)?;
+ sk.listen(1024)?;
+
+ // 将 fd 添加到 reactor 中
+ let reactor = get_reactor();
+ reactor.borrow_mut().add(sk.as_raw_fd());
+
+ println!("tcp bind with fd {}", sk.as_raw_fd());
+ Ok(Self {
+ reactor: Rc::downgrade(&reactor),
+ listener: sk.into(),
+ })
+ }
+}
+//Stream 流
+impl Stream for TcpListener { //TcpStream 和 TcpListener 在这链接
+ type Item = std::io::Result<(TcpStream, SocketAddr)>; //
+
+ fn poll_next(
+ self: std::pin::Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ ) -> std::task::Poll<Option<Self::Item>> {
+ match self.listener.accept() {
+ Ok((stream, addr)) => Poll::Ready(Some(Ok((stream.into(), addr)))),
+ Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { // 继续阻塞
+ // 修改反应器以注册感兴趣的事件
+ let reactor = self.reactor.upgrade().unwrap();
+ reactor
+ .borrow_mut()
+ .modify_readable(self.listener.as_raw_fd(), cx); // 可读事件
+ Poll::Pending
+ }
+ Err(e) => std::task::Poll::Ready(Some(Err(e))),
+ }
+ }
+}
+
+/// TCP 流
+#[derive(Debug)]
+pub struct TcpStream {
+ stream: StdTcpStream, // 标准库的 TcpStream | 包装一层
+}
+
+impl From<StdTcpStream> for TcpStream {
+ // 从标准库的 TcpStream 转换为自定义的 TcpStream
+ fn from(stream: StdTcpStream) -> Self {
+ let reactor: Rc<RefCell<Reactor>> = get_reactor(); // 获取 reactor
+ reactor.borrow_mut().add(stream.as_raw_fd()); // 将 fd 添加到 reactor
+ Self { stream } // 返回包装后的 TcpStream
+ }
+}
+
+impl Drop for TcpStream {
+ fn drop(&mut self) {
+ // 可变引用
+ println!("drop");
+ let reactor = get_reactor(); // 获取 reactor
+ reactor.borrow_mut().delete(self.stream.as_raw_fd()); // 将 fd 从 reactor 中删除
+ }
+}
+
+// 为 TcpStream 实现 tokio::io::AsyncRead
+impl tokio::io::AsyncRead for TcpStream {
+ fn poll_read(
+ mut self: std::pin::Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ buf: &mut tokio::io::ReadBuf<'_>,
+ ) -> Poll<io::Result<()>> {
+ let fd = self.stream.as_raw_fd(); // 获取 stream 对应的 fd
+ unsafe {
+ // 将 ReadBuf 转换为 [u8] , stream.read 需要
+ let b = &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]);
+ println!("read for fd {}", fd);
+ match self.stream.read(b) {
+ Ok(n) => { // 读取成功
+ println!("read for fd {} done, {}", fd, n);
+ buf.assume_init(n); // 初始化 n 个字节
+ buf.advance(n); // 指针前进 n 个字节
+ Poll::Ready(Ok(())) // 返回结果
+ }
+ Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { // 读取失败,且错误类型为 WouldBlock
+ println!("read for fd {} done WouldBlock", fd);
+ // 修改反应器以注册感兴趣的事件
+ let reactor = get_reactor(); // 获取 reactor
+ reactor
+ .borrow_mut()
+ .modify_readable(self.stream.as_raw_fd(), cx); // 注册到 reactor 可读事件
+ Poll::Pending // 等待
+ }
+ Err(e) => {
+ println!("read for fd {} done err", fd);
+ Poll::Ready(Err(e)) // 结束(错误)
+ }
+ }
+ }
+ }
+}
+
+impl tokio::io::AsyncWrite for TcpStream {
+ fn poll_write(
+ mut self: std::pin::Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ buf: &[u8],
+ ) -> Poll<Result<usize, io::Error>> {
+ match self.stream.write(buf) {
+ Ok(n) => Poll::Ready(Ok(n)), // 写入成功
+ Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { // 写入失败,且错误类型为 WouldBlock
+ let reactor: Rc<RefCell<Reactor>> = get_reactor();
+ reactor
+ .borrow_mut()
+ .modify_writable(self.stream.as_raw_fd(), cx); // 注册可写事件
+ Poll::Pending
+ }
+ Err(e) => Poll::Ready(Err(e)), // 写入失败,返回结果(错误信息)
+ }
+ }
+
+ fn poll_flush( // 或许是仅占位吧.
+ self: std::pin::Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ ) -> Poll<Result<(), io::Error>> {
+ Poll::Ready(Ok(()))
+ }
+
+ fn poll_shutdown( // 关闭时,将 stream 的写入关闭
+ self: std::pin::Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ ) -> Poll<Result<(), io::Error>> {
+ self.stream.shutdown(std::net::Shutdown::Write)?; // 关闭写入,出错则传递错误
+ Poll::Ready(Ok(()))
+ }
+}
diff --git a/src/net/udp.rs b/src/net/udp.rs
new file mode 100644
index 0000000..0a6a391
--- /dev/null
+++ b/src/net/udp.rs
@@ -0,0 +1,317 @@
+use std::cell::RefCell;
+use std::io;
+use std::net::{SocketAddr, ToSocketAddrs, UdpSocket as StdUdpSocket};
+use std::os::unix::prelude::AsRawFd;
+use std::rc::{Rc, Weak};
+use std::task::Poll;
+
+use futures::{Future, Stream};
+use socket2::{Domain, Protocol, SockAddr, Socket, Type};
+
+use crate::reactor::{get_reactor, Reactor};
+
+/// UDP 套接字
+#[derive(Debug)]
+pub struct UdpSocket {
+ reactor: Weak<RefCell<Reactor>>, // reactor
+ socket: StdUdpSocket, // 标准库的 UdpSocket | 包装一层
+}
+
+impl UdpSocket {
+ /// 绑定地址并返回一个 `UdpSocket` 实例
+ pub fn bind<A: ToSocketAddrs>(addr: A) -> Result<Self, io::Error> {
+ // 解析地址
+ let addr = addr
+ .to_socket_addrs()?
+ .next()
+ .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "empty address"))?;
+
+ // 地址解析
+ let domain = if addr.is_ipv6() {
+ Domain::IPV6
+ } else {
+ Domain::IPV4
+ };
+ let sk = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?;
+
+ // 绑定地址
+ let addr: SockAddr = SockAddr::from(addr);
+ sk.bind(&addr)?;
+
+ // 将 fd 添加到 reactor 中
+ let reactor = get_reactor();
+ reactor.borrow_mut().add(sk.as_raw_fd());
+
+ println!("udp bind with fd {}", sk.as_raw_fd());
+ Ok(Self {
+ reactor: Rc::downgrade(&reactor),
+ socket: sk.into(),
+ })
+ }
+
+ pub fn recv_from_async<'a>(
+ &'a self,
+ buf: &'a mut [u8],
+ ) -> impl Future<Output = io::Result<(usize, SocketAddr)>> + 'a {
+ struct RecvFromFuture<'a> {
+ socket: &'a StdUdpSocket,
+ buf: &'a mut [u8],
+ }
+ impl<'a> Future for RecvFromFuture<'a> {
+ type Output = io::Result<(usize, SocketAddr)>;
+
+ fn poll(
+ mut self: std::pin::Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ ) -> Poll<Self::Output> {
+ match self.socket.recv_from(self.buf) {
+ Ok((n, addr)) => {
+ println!("recv_from {} {} bytes", addr, n);
+ Poll::Ready(Ok((n, addr)))
+ }
+ Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
+ // 读取失败,且错误类型为 WouldBlock
+ println!("read for fd {} done WouldBlock", self.socket.as_raw_fd());
+ // 修改反应器以注册感兴趣的事件
+ let reactor = get_reactor(); // 获取 reactor
+ reactor
+ .borrow_mut()
+ .modify_readable(self.socket.as_raw_fd(), cx); // 注册到 reactor 可读事件
+ Poll::Pending // 等待
+ }
+ Err(e) => {
+ println!("read for fd {} done err", self.socket.as_raw_fd());
+ Poll::Ready(Err(e)) // 结束(错误)
+ }
+ }
+ }
+ }
+ RecvFromFuture {
+ socket: &self.socket,
+ buf,
+ }
+ }
+
+ pub fn send_to_async<'a>(
+ &'a self,
+ buf: &'a [u8],
+ addr: &'a SocketAddr,
+ ) -> impl Future<Output = io::Result<usize>> + 'a {
+ struct SendToFuture<'a> {
+ socket: &'a StdUdpSocket,
+ buf: &'a [u8],
+ addr: &'a SocketAddr,
+ }
+ impl<'a> Future for SendToFuture<'a> {
+ type Output = io::Result<usize>;
+
+ fn poll(
+ mut self: std::pin::Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ ) -> Poll<Self::Output> {
+ match self.socket.send_to(self.buf, self.addr) {
+ Ok(n) => Poll::Ready(Ok(n)), // 写入成功
+ Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
+ // 写入失败,且错误类型为 WouldBlock
+ let reactor: Rc<RefCell<Reactor>> = get_reactor();
+ reactor
+ .borrow_mut()
+ .modify_writable(self.socket.as_raw_fd(), cx); // 注册可写事件
+ Poll::Pending
+ }
+ Err(e) => Poll::Ready(Err(e)), // 写入失败,返回结果(错误信息)
+ }
+ }
+ }
+ SendToFuture {
+ socket: &self.socket,
+ buf,
+ addr,
+ }
+ }
+}
+
+// impl From<StdUdpSocket> for UdpSocket {
+// fn from(socket: StdUdpSocket) -> Self {
+// Self {
+// reactor: Weak::new(),
+// socket,
+// }
+// }
+// }
+
+impl Drop for UdpSocket {
+ fn drop(&mut self) {
+ println!("drop udp socket");
+ if let Some(reactor) = self.reactor.upgrade() {
+ // 从 reactor 中移除
+ reactor.borrow_mut().delete(self.socket.as_raw_fd());
+ }
+ }
+}
+
+// impl Future for UdpSocket {
+// type Output = io::Result<(usize, SocketAddr)>;
+
+// fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
+// let mut buf = [0; 1024];
+// match self.socket.recv_from(&mut buf) {
+// Ok((n, addr)) => {
+// println!("recv_from {} {} bytes", addr, n);
+// Poll::Ready(Ok((n, addr)))
+// }
+// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
+// // 读取失败,且错误类型为 WouldBlock
+// println!("read for fd {} done WouldBlock", self.socket.as_raw_fd());
+// // 修改反应器以注册感兴趣的事件
+// let reactor = get_reactor(); // 获取 reactor
+// reactor
+// .borrow_mut()
+// .modify_readable(self.socket.as_raw_fd(), cx); // 注册到 reactor 可读事件
+// Poll::Pending // 等待
+// }
+// Err(e) => {
+// println!("read for fd {} done err", self.socket.as_raw_fd());
+// Poll::Ready(Err(e)) // 结束(错误)
+// }
+// }
+// }
+// }
+
+// impl Stream for UdpSocket {
+// type Item = io::Result<(usize, SocketAddr)>;
+
+// fn poll_next(
+// self: std::pin::Pin<&mut Self>,
+// cx: &mut std::task::Context<'_>,
+// ) -> std::task::Poll<Option<Self::Item>> {
+// match self.socket.recv_from(&mut self.buf) {
+// Ok((n, addr)) => {
+// println!("recv_from {} {} bytes", addr, n);
+// Poll::Ready(Some(Ok((n, addr))))
+// }
+// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
+// // 读取失败,且错误类型为 WouldBlock
+// println!("read for fd {} done WouldBlock", self.socket.as_raw_fd());
+// // 修改反应器以注册感兴趣的事件
+// let reactor = get_reactor(); // 获取 reactor
+// reactor
+// .borrow_mut()
+// .modify_readable(self.socket.as_raw_fd(), cx); // 注册到 reactor 可读事件
+// Poll::Pending // 等待
+// }
+// Err(e) => {
+// println!("read for fd {} done err", self.socket.as_raw_fd());
+// Poll::Ready(Some(Err(e))) // 结束(错误)
+// }
+// }
+// }
+// }
+
+// impl UdpSocket{
+// pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
+// let mut buf = tokio::io::ReadBuf::new(buf);
+// self.read(&mut buf).await?;
+// Ok((buf.filled().len(), buf.filled().len()))
+// }
+// }
+
+// impl Future for UdpSocket {
+// type Output = io::Result<(Vec<u8>, SocketAddr)>;
+
+// fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
+// let mut buf = vec![0u8; 1024];
+// match self.socket.recv_from(&mut buf) {
+// Ok((n, addr)) => {
+// println!("recv_from {} {} bytes", addr, n);
+// Poll::Ready(Ok((buf, addr)))
+// }
+// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
+// // 读取失败,且错误类型为 WouldBlock
+// println!("read for fd {} done WouldBlock", self.socket.as_raw_fd());
+// // 修改反应器以注册感兴趣的事件
+// let reactor = get_reactor(); // 获取 reactor
+// reactor
+// .borrow_mut()
+// .modify_readable(self.socket.as_raw_fd(), cx); // 注册到 reactor 可读事件
+// Poll::Pending // 等待
+// }
+// Err(e) => {
+// println!("read for fd {} done err", self.socket.as_raw_fd());
+// Poll::Ready(Err(e)) // 结束(错误)
+// }
+// }
+// }
+// }
+
+// // 为 UdpSocket 实现 tokio::io::AsyncWrite
+// impl tokio::io::AsyncWrite for UdpSocket {
+// fn poll_write(
+// mut self: std::pin::Pin<&mut Self>,
+// cx: &mut std::task::Context<'_>,
+// buf: &[u8],
+// ) -> Poll<Result<usize, io::Error>> {
+// match self.send_to(buf, &"127.0.0.1:8080".parse().unwrap()) {
+// Ok(n) => Poll::Ready(Ok(n)), // 写入成功
+// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { // 写入失败,且错误类型为 WouldBlock
+// let reactor: Rc<RefCell<Reactor>> = get_reactor();
+// reactor
+// .borrow_mut()
+// .modify_writable(self.socket.as_raw_fd(), cx); // 注册可写事件
+// Poll::Pending
+// }
+// Err(e) => Poll::Ready(Err(e)), // 写入失败,返回结果(错误信息)
+// }
+// }
+
+// fn poll_flush(
+// self: std::pin::Pin<&mut Self>,
+// cx: &mut std::task::Context<'_>,
+// ) -> Poll<Result<(), io::Error>> {
+// Poll::Ready(Ok(()))
+// }
+
+// fn poll_shutdown(
+// self: std::pin::Pin<&mut Self>,
+// cx: &mut std::task::Context<'_>,
+// ) -> Poll<Result<(), io::Error>> {
+// Poll::Ready(Ok(()))
+// }
+// }
+
+// 为 UdpSocket 实现 tokio::io::AsyncRead
+// impl tokio::io::AsyncRead for UdpSocket {
+// fn poll_read(
+// mut self: std::pin::Pin<&mut Self>,
+// cx: &mut std::task::Context<'_>,
+// buf: &mut tokio::io::ReadBuf<'_>,
+// ) -> Poll<Result<(), std::io::Error>> {
+// let fd: i32 = self.socket.as_raw_fd(); // 获取 socket 对应的 fd
+// unsafe {
+// // 将 ReadBuf 转换为 [u8] , socket.recv_from 需要
+// let b = &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]);
+// println!("read for fd {}", fd);
+// match self.socket.recv_from(b) {
+// Ok((size, addr)) => {
+// // 读取成功
+// println!("read for fd {} done, {}", fd, size);
+// Poll::Ready(Ok((size, addr)))
+// }
+// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
+// // 读取失败,且错误类型为 WouldBlock
+// println!("read for fd {} done WouldBlock", fd);
+// // 修改反应器以注册感兴趣的事件
+// let reactor = get_reactor(); // 获取 reactor
+// reactor
+// .borrow_mut()
+// .modify_readable(self.socket.as_raw_fd(), cx); // 注册到 reactor 可读事件
+// Poll::Pending // 等待
+// }
+// Err(e) => {
+// println!("read for fd {} done err", fd);
+// Poll::Ready(Err(e)) // 结束(错误)
+// }
+// }
+// }
+// }
+// }