diff options
Diffstat (limited to 'src/net/udp.rs')
| -rw-r--r-- | src/net/udp.rs | 317 |
1 files changed, 317 insertions, 0 deletions
diff --git a/src/net/udp.rs b/src/net/udp.rs new file mode 100644 index 0000000..0a6a391 --- /dev/null +++ b/src/net/udp.rs @@ -0,0 +1,317 @@ +use std::cell::RefCell; +use std::io; +use std::net::{SocketAddr, ToSocketAddrs, UdpSocket as StdUdpSocket}; +use std::os::unix::prelude::AsRawFd; +use std::rc::{Rc, Weak}; +use std::task::Poll; + +use futures::{Future, Stream}; +use socket2::{Domain, Protocol, SockAddr, Socket, Type}; + +use crate::reactor::{get_reactor, Reactor}; + +/// UDP 套接字 +#[derive(Debug)] +pub struct UdpSocket { + reactor: Weak<RefCell<Reactor>>, // reactor + socket: StdUdpSocket, // 标准库的 UdpSocket | 包装一层 +} + +impl UdpSocket { + /// 绑定地址并返回一个 `UdpSocket` 实例 + pub fn bind<A: ToSocketAddrs>(addr: A) -> Result<Self, io::Error> { + // 解析地址 + let addr = addr + .to_socket_addrs()? + .next() + .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "empty address"))?; + + // 地址解析 + let domain = if addr.is_ipv6() { + Domain::IPV6 + } else { + Domain::IPV4 + }; + let sk = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?; + + // 绑定地址 + let addr: SockAddr = SockAddr::from(addr); + sk.bind(&addr)?; + + // 将 fd 添加到 reactor 中 + let reactor = get_reactor(); + reactor.borrow_mut().add(sk.as_raw_fd()); + + println!("udp bind with fd {}", sk.as_raw_fd()); + Ok(Self { + reactor: Rc::downgrade(&reactor), + socket: sk.into(), + }) + } + + pub fn recv_from_async<'a>( + &'a self, + buf: &'a mut [u8], + ) -> impl Future<Output = io::Result<(usize, SocketAddr)>> + 'a { + struct RecvFromFuture<'a> { + socket: &'a StdUdpSocket, + buf: &'a mut [u8], + } + impl<'a> Future for RecvFromFuture<'a> { + type Output = io::Result<(usize, SocketAddr)>; + + fn poll( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll<Self::Output> { + match self.socket.recv_from(self.buf) { + Ok((n, addr)) => { + println!("recv_from {} {} bytes", addr, n); + Poll::Ready(Ok((n, addr))) + } + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + // 读取失败,且错误类型为 WouldBlock + println!("read for fd {} done WouldBlock", self.socket.as_raw_fd()); + // 修改反应器以注册感兴趣的事件 + let reactor = get_reactor(); // 获取 reactor + reactor + .borrow_mut() + .modify_readable(self.socket.as_raw_fd(), cx); // 注册到 reactor 可读事件 + Poll::Pending // 等待 + } + Err(e) => { + println!("read for fd {} done err", self.socket.as_raw_fd()); + Poll::Ready(Err(e)) // 结束(错误) + } + } + } + } + RecvFromFuture { + socket: &self.socket, + buf, + } + } + + pub fn send_to_async<'a>( + &'a self, + buf: &'a [u8], + addr: &'a SocketAddr, + ) -> impl Future<Output = io::Result<usize>> + 'a { + struct SendToFuture<'a> { + socket: &'a StdUdpSocket, + buf: &'a [u8], + addr: &'a SocketAddr, + } + impl<'a> Future for SendToFuture<'a> { + type Output = io::Result<usize>; + + fn poll( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll<Self::Output> { + match self.socket.send_to(self.buf, self.addr) { + Ok(n) => Poll::Ready(Ok(n)), // 写入成功 + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + // 写入失败,且错误类型为 WouldBlock + let reactor: Rc<RefCell<Reactor>> = get_reactor(); + reactor + .borrow_mut() + .modify_writable(self.socket.as_raw_fd(), cx); // 注册可写事件 + Poll::Pending + } + Err(e) => Poll::Ready(Err(e)), // 写入失败,返回结果(错误信息) + } + } + } + SendToFuture { + socket: &self.socket, + buf, + addr, + } + } +} + +// impl From<StdUdpSocket> for UdpSocket { +// fn from(socket: StdUdpSocket) -> Self { +// Self { +// reactor: Weak::new(), +// socket, +// } +// } +// } + +impl Drop for UdpSocket { + fn drop(&mut self) { + println!("drop udp socket"); + if let Some(reactor) = self.reactor.upgrade() { + // 从 reactor 中移除 + reactor.borrow_mut().delete(self.socket.as_raw_fd()); + } + } +} + +// impl Future for UdpSocket { +// type Output = io::Result<(usize, SocketAddr)>; + +// fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> { +// let mut buf = [0; 1024]; +// match self.socket.recv_from(&mut buf) { +// Ok((n, addr)) => { +// println!("recv_from {} {} bytes", addr, n); +// Poll::Ready(Ok((n, addr))) +// } +// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { +// // 读取失败,且错误类型为 WouldBlock +// println!("read for fd {} done WouldBlock", self.socket.as_raw_fd()); +// // 修改反应器以注册感兴趣的事件 +// let reactor = get_reactor(); // 获取 reactor +// reactor +// .borrow_mut() +// .modify_readable(self.socket.as_raw_fd(), cx); // 注册到 reactor 可读事件 +// Poll::Pending // 等待 +// } +// Err(e) => { +// println!("read for fd {} done err", self.socket.as_raw_fd()); +// Poll::Ready(Err(e)) // 结束(错误) +// } +// } +// } +// } + +// impl Stream for UdpSocket { +// type Item = io::Result<(usize, SocketAddr)>; + +// fn poll_next( +// self: std::pin::Pin<&mut Self>, +// cx: &mut std::task::Context<'_>, +// ) -> std::task::Poll<Option<Self::Item>> { +// match self.socket.recv_from(&mut self.buf) { +// Ok((n, addr)) => { +// println!("recv_from {} {} bytes", addr, n); +// Poll::Ready(Some(Ok((n, addr)))) +// } +// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { +// // 读取失败,且错误类型为 WouldBlock +// println!("read for fd {} done WouldBlock", self.socket.as_raw_fd()); +// // 修改反应器以注册感兴趣的事件 +// let reactor = get_reactor(); // 获取 reactor +// reactor +// .borrow_mut() +// .modify_readable(self.socket.as_raw_fd(), cx); // 注册到 reactor 可读事件 +// Poll::Pending // 等待 +// } +// Err(e) => { +// println!("read for fd {} done err", self.socket.as_raw_fd()); +// Poll::Ready(Some(Err(e))) // 结束(错误) +// } +// } +// } +// } + +// impl UdpSocket{ +// pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { +// let mut buf = tokio::io::ReadBuf::new(buf); +// self.read(&mut buf).await?; +// Ok((buf.filled().len(), buf.filled().len())) +// } +// } + +// impl Future for UdpSocket { +// type Output = io::Result<(Vec<u8>, SocketAddr)>; + +// fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> { +// let mut buf = vec![0u8; 1024]; +// match self.socket.recv_from(&mut buf) { +// Ok((n, addr)) => { +// println!("recv_from {} {} bytes", addr, n); +// Poll::Ready(Ok((buf, addr))) +// } +// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { +// // 读取失败,且错误类型为 WouldBlock +// println!("read for fd {} done WouldBlock", self.socket.as_raw_fd()); +// // 修改反应器以注册感兴趣的事件 +// let reactor = get_reactor(); // 获取 reactor +// reactor +// .borrow_mut() +// .modify_readable(self.socket.as_raw_fd(), cx); // 注册到 reactor 可读事件 +// Poll::Pending // 等待 +// } +// Err(e) => { +// println!("read for fd {} done err", self.socket.as_raw_fd()); +// Poll::Ready(Err(e)) // 结束(错误) +// } +// } +// } +// } + +// // 为 UdpSocket 实现 tokio::io::AsyncWrite +// impl tokio::io::AsyncWrite for UdpSocket { +// fn poll_write( +// mut self: std::pin::Pin<&mut Self>, +// cx: &mut std::task::Context<'_>, +// buf: &[u8], +// ) -> Poll<Result<usize, io::Error>> { +// match self.send_to(buf, &"127.0.0.1:8080".parse().unwrap()) { +// Ok(n) => Poll::Ready(Ok(n)), // 写入成功 +// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { // 写入失败,且错误类型为 WouldBlock +// let reactor: Rc<RefCell<Reactor>> = get_reactor(); +// reactor +// .borrow_mut() +// .modify_writable(self.socket.as_raw_fd(), cx); // 注册可写事件 +// Poll::Pending +// } +// Err(e) => Poll::Ready(Err(e)), // 写入失败,返回结果(错误信息) +// } +// } + +// fn poll_flush( +// self: std::pin::Pin<&mut Self>, +// cx: &mut std::task::Context<'_>, +// ) -> Poll<Result<(), io::Error>> { +// Poll::Ready(Ok(())) +// } + +// fn poll_shutdown( +// self: std::pin::Pin<&mut Self>, +// cx: &mut std::task::Context<'_>, +// ) -> Poll<Result<(), io::Error>> { +// Poll::Ready(Ok(())) +// } +// } + +// 为 UdpSocket 实现 tokio::io::AsyncRead +// impl tokio::io::AsyncRead for UdpSocket { +// fn poll_read( +// mut self: std::pin::Pin<&mut Self>, +// cx: &mut std::task::Context<'_>, +// buf: &mut tokio::io::ReadBuf<'_>, +// ) -> Poll<Result<(), std::io::Error>> { +// let fd: i32 = self.socket.as_raw_fd(); // 获取 socket 对应的 fd +// unsafe { +// // 将 ReadBuf 转换为 [u8] , socket.recv_from 需要 +// let b = &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]); +// println!("read for fd {}", fd); +// match self.socket.recv_from(b) { +// Ok((size, addr)) => { +// // 读取成功 +// println!("read for fd {} done, {}", fd, size); +// Poll::Ready(Ok((size, addr))) +// } +// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { +// // 读取失败,且错误类型为 WouldBlock +// println!("read for fd {} done WouldBlock", fd); +// // 修改反应器以注册感兴趣的事件 +// let reactor = get_reactor(); // 获取 reactor +// reactor +// .borrow_mut() +// .modify_readable(self.socket.as_raw_fd(), cx); // 注册到 reactor 可读事件 +// Poll::Pending // 等待 +// } +// Err(e) => { +// println!("read for fd {} done err", fd); +// Poll::Ready(Err(e)) // 结束(错误) +// } +// } +// } +// } +// } |
