diff options
| author | ihc童鞋@提不起劲 <[email protected]> | 2023-07-03 15:16:00 +0800 |
|---|---|---|
| committer | GitHub <[email protected]> | 2023-07-03 15:16:00 +0800 |
| commit | 50147fad39f1e68546a12c9a6b4066aed4f44cee (patch) | |
| tree | cbba5c29a6969eebf0f8cbb88df6d00d56e401cc | |
| parent | aba979c378d72380a0deb59c7b7454d2dfc0bbfa (diff) | |
feat: support unix datagram (#183)
* feat: support unix datagram
* fix: clippy
| -rw-r--r-- | monoio-compat/src/tcp_unsafe.rs | 2 | ||||
| -rw-r--r-- | monoio/src/driver/op/recv.rs | 80 | ||||
| -rw-r--r-- | monoio/src/driver/op/send.rs | 73 | ||||
| -rw-r--r-- | monoio/src/macros/select.rs | 13 | ||||
| -rw-r--r-- | monoio/src/net/tcp/listener.rs | 4 | ||||
| -rw-r--r-- | monoio/src/net/tcp/stream.rs | 2 | ||||
| -rw-r--r-- | monoio/src/net/unix/datagram/mod.rs | 43 | ||||
| -rw-r--r-- | monoio/src/net/unix/socket_addr.rs | 32 | ||||
| -rw-r--r-- | monoio/src/net/unix/stream.rs | 2 | ||||
| -rw-r--r-- | monoio/src/task/raw.rs | 2 | ||||
| -rw-r--r-- | monoio/src/time/driver/mod.rs | 4 | ||||
| -rw-r--r-- | monoio/tests/unix_datagram.rs | 27 |
12 files changed, 265 insertions, 19 deletions
diff --git a/monoio-compat/src/tcp_unsafe.rs b/monoio-compat/src/tcp_unsafe.rs index 7f9e0cb..30dfca8 100644 --- a/monoio-compat/src/tcp_unsafe.rs +++ b/monoio-compat/src/tcp_unsafe.rs @@ -108,7 +108,7 @@ impl tokio::io::AsyncWrite for TcpStreamCompat { buf: &[u8], ) -> std::task::Poll<Result<usize, std::io::Error>> { let this = self.get_mut(); - let (ptr, len) = (buf.as_ptr() as *const u8, buf.len()); + let (ptr, len) = (buf.as_ptr(), buf.len()); // Set or check write_dst // Note: the check can not prevent memory crash when user misuse it. diff --git a/monoio/src/driver/op/recv.rs b/monoio/src/driver/op/recv.rs index 3f8885f..ce6ecc8 100644 --- a/monoio/src/driver/op/recv.rs +++ b/monoio/src/driver/op/recv.rs @@ -13,7 +13,7 @@ use { }; use super::{super::shared_fd::SharedFd, Op, OpAble}; -use crate::{buf::IoBufMut, BufResult}; +use crate::{buf::IoBufMut, net::unix::SocketAddr as UnixSocketAddr, BufResult}; pub(crate) struct Recv<T> { /// Holds a strong ref to the FD, preventing the file from being closed @@ -181,3 +181,81 @@ impl<T: IoBufMut> OpAble for RecvMsg<T> { syscall_u32!(recvmsg(fd, &mut self.info.2 as *mut _, 0)) } } + +pub(crate) struct RecvMsgUnix<T> { + /// Holds a strong ref to the FD, preventing the file from being closed + /// while the operation is in-flight. + #[allow(unused)] + fd: SharedFd, + + /// Reference to the in-flight buffer. + pub(crate) buf: T, + pub(crate) info: Box<( + MaybeUninit<libc::sockaddr_storage>, + [libc::iovec; 1], + libc::msghdr, + )>, +} + +impl<T: IoBufMut> Op<RecvMsgUnix<T>> { + pub(crate) fn recv_msg_unix(fd: SharedFd, mut buf: T) -> io::Result<Self> { + let iovec = [libc::iovec { + iov_base: buf.write_ptr() as *mut _, + iov_len: buf.bytes_total(), + }]; + let mut info: Box<( + MaybeUninit<libc::sockaddr_storage>, + [libc::iovec; 1], + libc::msghdr, + )> = Box::new((MaybeUninit::uninit(), iovec, unsafe { std::mem::zeroed() })); + + info.2.msg_iov = info.1.as_mut_ptr(); + info.2.msg_iovlen = 1; + info.2.msg_name = &mut info.0 as *mut _ as *mut libc::c_void; + info.2.msg_namelen = std::mem::size_of::<libc::sockaddr_storage>() as libc::socklen_t; + + Op::submit_with(RecvMsgUnix { fd, buf, info }) + } + + pub(crate) async fn wait(self) -> BufResult<(usize, UnixSocketAddr), T> { + let complete = self.await; + let res = complete.meta.result.map(|v| v as _); + let mut buf = complete.data.buf; + + let res = res.map(|n| { + let storage = unsafe { complete.data.info.0.assume_init() }; + let name_len = complete.data.info.2.msg_namelen; + + let addr = unsafe { + let addr: &libc::sockaddr_un = transmute(&storage); + UnixSocketAddr::from_parts(*addr, name_len) + }; + + // Safety: the kernel wrote `n` bytes to the buffer. + unsafe { + buf.set_init(n); + } + + (n, addr) + }); + (res, buf) + } +} + +impl<T: IoBufMut> OpAble for RecvMsgUnix<T> { + #[cfg(all(target_os = "linux", feature = "iouring"))] + fn uring_op(&mut self) -> io_uring::squeue::Entry { + opcode::RecvMsg::new(types::Fd(self.fd.raw_fd()), &mut self.info.2 as *mut _).build() + } + + #[cfg(all(unix, feature = "legacy"))] + fn legacy_interest(&self) -> Option<(Direction, usize)> { + self.fd.registered_index().map(|idx| (Direction::Read, idx)) + } + + #[cfg(all(unix, feature = "legacy"))] + fn legacy_call(&mut self) -> io::Result<u32> { + let fd = self.fd.as_raw_fd(); + syscall_u32!(recvmsg(fd, &mut self.info.2 as *mut _, 0)) + } +} diff --git a/monoio/src/driver/op/send.rs b/monoio/src/driver/op/send.rs index bd7222b..0f3aeed 100644 --- a/monoio/src/driver/op/send.rs +++ b/monoio/src/driver/op/send.rs @@ -10,7 +10,7 @@ use { }; use super::{super::shared_fd::SharedFd, Op, OpAble}; -use crate::{buf::IoBuf, BufResult}; +use crate::{buf::IoBuf, net::unix::SocketAddr as UnixSocketAddr, BufResult}; pub(crate) struct Send<T> { /// Holds a strong ref to the FD, preventing the file from being closed @@ -171,3 +171,74 @@ impl<T: IoBuf> OpAble for SendMsg<T> { syscall_u32!(sendmsg(fd, &mut self.info.2 as *mut _, 0)) } } + +pub(crate) struct SendMsgUnix<T> { + /// Holds a strong ref to the FD, preventing the file from being closed + /// while the operation is in-flight. + #[allow(unused)] + fd: SharedFd, + + /// Reference to the in-flight buffer. + pub(crate) buf: T, + pub(crate) info: Box<(Option<UnixSocketAddr>, [libc::iovec; 1], libc::msghdr)>, +} + +impl<T: IoBuf> Op<SendMsgUnix<T>> { + pub(crate) fn send_msg_unix( + fd: SharedFd, + buf: T, + socket_addr: Option<UnixSocketAddr>, + ) -> io::Result<Self> { + let iovec = [libc::iovec { + iov_base: buf.read_ptr() as *const _ as *mut _, + iov_len: buf.bytes_init(), + }]; + let mut info: Box<(Option<UnixSocketAddr>, [libc::iovec; 1], libc::msghdr)> = + Box::new((socket_addr.map(Into::into), iovec, unsafe { + std::mem::zeroed() + })); + + info.2.msg_iov = info.1.as_mut_ptr(); + info.2.msg_iovlen = 1; + + match info.0.as_ref() { + Some(socket_addr) => { + info.2.msg_name = socket_addr.as_ptr() as *mut libc::c_void; + info.2.msg_namelen = socket_addr.len(); + } + None => { + info.2.msg_name = std::ptr::null_mut(); + info.2.msg_namelen = 0; + } + } + + Op::submit_with(SendMsgUnix { fd, buf, info }) + } + + pub(crate) async fn wait(self) -> BufResult<usize, T> { + let complete = self.await; + let res = complete.meta.result.map(|v| v as _); + let buf = complete.data.buf; + (res, buf) + } +} + +impl<T: IoBuf> OpAble for SendMsgUnix<T> { + #[cfg(all(target_os = "linux", feature = "iouring"))] + fn uring_op(&mut self) -> io_uring::squeue::Entry { + opcode::SendMsg::new(types::Fd(self.fd.raw_fd()), &mut self.info.2 as *mut _).build() + } + + #[cfg(all(unix, feature = "legacy"))] + fn legacy_interest(&self) -> Option<(Direction, usize)> { + self.fd + .registered_index() + .map(|idx| (Direction::Write, idx)) + } + + #[cfg(all(unix, feature = "legacy"))] + fn legacy_call(&mut self) -> io::Result<u32> { + let fd = self.fd.as_raw_fd(); + syscall_u32!(sendmsg(fd, &mut self.info.2 as *mut _, 0)) + } +} diff --git a/monoio/src/macros/select.rs b/monoio/src/macros/select.rs index 3cddd12..8736ac8 100644 --- a/monoio/src/macros/select.rs +++ b/monoio/src/macros/select.rs @@ -30,13 +30,12 @@ /// /// The complete lifecycle of a `select!` expression is as follows: /// -/// 1. Evaluate all provided `<precondition>` expressions. If the precondition -/// returns `false`, disable the branch for the remainder of the current call -/// to `select!`. Re-entering `select!` due to a loop clears the "disabled" -/// state. -/// 2. Aggregate the `<async expression>`s from each branch, including the -/// disabled ones. If the branch is disabled, `<async expression>` is still -/// evaluated, but the resulting future is not polled. +/// 1. Evaluate all provided `<precondition>` expressions. If the precondition returns `false`, +/// disable the branch for the remainder of the current call to `select!`. Re-entering `select!` +/// due to a loop clears the "disabled" state. +/// 2. Aggregate the `<async expression>`s from each branch, including the disabled ones. If the +/// branch is disabled, `<async expression>` is still evaluated, but the resulting future is not +/// polled. /// 3. Concurrently await on the results for all remaining `<async /// expression>`s. 4. Once an `<async expression>` returns a value, attempt to /// apply the value to the provided `<pattern>`, if the pattern matches, diff --git a/monoio/src/net/tcp/listener.rs b/monoio/src/net/tcp/listener.rs index 476e679..af29509 100644 --- a/monoio/src/net/tcp/listener.rs +++ b/monoio/src/net/tcp/listener.rs @@ -114,7 +114,7 @@ impl TcpListener { let stream = TcpStream::from_shared_fd(SharedFd::new(fd as _)?); // Construct SocketAddr - let storage = completion.data.addr.0.as_ptr() as *const _ as *const libc::sockaddr_storage; + let storage = completion.data.addr.0.as_ptr(); let addr = unsafe { match (*storage).ss_family as libc::c_int { libc::AF_INET => { @@ -167,7 +167,7 @@ impl TcpListener { let stream = TcpStream::from_shared_fd(SharedFd::new(fd as _)?); // Construct SocketAddr - let storage = completion.data.addr.0.as_ptr() as *const _ as *const libc::sockaddr_storage; + let storage = completion.data.addr.0.as_ptr(); let addr = unsafe { match (*storage).ss_family as libc::c_int { libc::AF_INET => { diff --git a/monoio/src/net/tcp/stream.rs b/monoio/src/net/tcp/stream.rs index 359fc83..598d353 100644 --- a/monoio/src/net/tcp/stream.rs +++ b/monoio/src/net/tcp/stream.rs @@ -503,7 +503,7 @@ impl tokio::io::AsyncWrite for TcpStream { buf: &[u8], ) -> std::task::Poll<Result<usize, io::Error>> { unsafe { - let raw_buf = crate::buf::RawBuf::new(buf.as_ptr() as *const u8, buf.len()); + let raw_buf = crate::buf::RawBuf::new(buf.as_ptr(), buf.len()); let mut send = Op::send_raw(&self.fd, raw_buf); let ret = ready!(crate::driver::op::PollLegacy::poll_legacy(&mut send, cx)); diff --git a/monoio/src/net/unix/datagram/mod.rs b/monoio/src/net/unix/datagram/mod.rs index 67ea3d9..eb07981 100644 --- a/monoio/src/net/unix/datagram/mod.rs +++ b/monoio/src/net/unix/datagram/mod.rs @@ -14,6 +14,7 @@ use super::{ SocketAddr, }; use crate::{ + buf::{IoBuf, IoBufMut}, driver::{op::Op, shared_fd::SharedFd}, net::new_socket, }; @@ -61,7 +62,7 @@ impl UnixDatagram { sockaddr: libc::sockaddr_un, socklen: libc::socklen_t, ) -> io::Result<Self> { - let socket = new_socket(libc::AF_UNIX, libc::SOCK_STREAM)?; + let socket = new_socket(libc::AF_UNIX, libc::SOCK_DGRAM)?; let op = Op::connect_unix(SharedFd::new(socket)?, sockaddr, socklen)?; let completion = op.await; completion.meta.result?; @@ -119,6 +120,46 @@ impl UnixDatagram { let op = Op::poll_write(&self.fd, relaxed).unwrap(); op.wait().await } + + /// Sends data on the socket to the given address. On success, returns the + /// number of bytes written. + pub async fn send_to<T: IoBuf, P: AsRef<Path>>( + &self, + buf: T, + path: P, + ) -> crate::BufResult<usize, T> { + let addr = match crate::net::unix::socket_addr::socket_addr(path.as_ref()) { + Ok(addr) => addr, + Err(e) => return (Err(e), buf), + }; + let op = Op::send_msg_unix( + self.fd.clone(), + buf, + Some(SocketAddr::from_parts(addr.0, addr.1)), + ) + .unwrap(); + op.wait().await + } + + /// Receives a single datagram message on the socket. On success, returns the number + /// of bytes read and the origin. + pub async fn recv_from<T: IoBufMut>(&self, buf: T) -> crate::BufResult<(usize, SocketAddr), T> { + let op = Op::recv_msg_unix(self.fd.clone(), buf).unwrap(); + op.wait().await + } + + /// Sends data on the socket to the remote address to which it is connected. + pub async fn send<T: IoBuf>(&self, buf: T) -> crate::BufResult<usize, T> { + let op = Op::send_msg_unix(self.fd.clone(), buf, None).unwrap(); + op.wait().await + } + + /// Receives a single datagram message on the socket from the remote address to + /// which it is connected. On success, returns the number of bytes read. + pub async fn recv<T: IoBufMut>(&self, buf: T) -> crate::BufResult<usize, T> { + let op = Op::recv(self.fd.clone(), buf).unwrap(); + op.read().await + } } impl AsRawFd for UnixDatagram { diff --git a/monoio/src/net/unix/socket_addr.rs b/monoio/src/net/unix/socket_addr.rs index 70defc2..7169552 100644 --- a/monoio/src/net/unix/socket_addr.rs +++ b/monoio/src/net/unix/socket_addr.rs @@ -64,7 +64,24 @@ impl SocketAddr { Ok(SocketAddr::from_parts(sockaddr, socklen)) } - pub(crate) fn from_parts(sockaddr: libc::sockaddr_un, socklen: libc::socklen_t) -> SocketAddr { + pub(crate) fn from_parts( + sockaddr: libc::sockaddr_un, + mut socklen: libc::socklen_t, + ) -> SocketAddr { + fn sun_path_offset(addr: &libc::sockaddr_un) -> usize { + let base: usize = (addr as *const libc::sockaddr_un).cast::<()>() as usize; + let path: usize = (&addr.sun_path as *const libc::c_char).cast::<()>() as usize; + path - base + } + + if socklen == 0 { + // When there is a datagram from unnamed unix socket + // linux returns zero bytes of address + socklen = sun_path_offset(&sockaddr) as libc::socklen_t; // i.e., zero-length address + } else if sockaddr.sun_family != libc::AF_UNIX as libc::sa_family_t { + panic!("file descriptor did not correspond to a Unix socket"); + } + SocketAddr { sockaddr, socklen } } @@ -77,6 +94,7 @@ impl SocketAddr { /// Documentation reflected in [`SocketAddr`] /// /// [`SocketAddr`]: std::os::unix::net::SocketAddr + #[cfg(target_os = "linux")] #[inline] pub fn is_unnamed(&self) -> bool { matches!(self.address(), AddressKind::Unnamed) @@ -87,6 +105,7 @@ impl SocketAddr { /// Documentation reflected in [`SocketAddr`] /// /// [`SocketAddr`]: std::os::unix::net::SocketAddr + #[cfg(target_os = "linux")] #[inline] pub fn as_pathname(&self) -> Option<&Path> { if let AddressKind::Pathname(path) = self.address() { @@ -100,6 +119,7 @@ impl SocketAddr { /// without the leading null byte. // Link to std::os::unix::net::SocketAddr pending // https://github.com/rust-lang/rust/issues/85410. + #[cfg(target_os = "linux")] #[inline] pub fn as_abstract_namespace(&self) -> Option<&[u8]> { if let AddressKind::Abstract(path) = self.address() { @@ -108,6 +128,16 @@ impl SocketAddr { None } } + + #[inline] + pub(crate) fn as_ptr(&self) -> *const libc::sockaddr_un { + &self.sockaddr as *const _ + } + + #[inline] + pub(crate) fn len(&self) -> libc::socklen_t { + self.socklen + } } impl fmt::Debug for SocketAddr { diff --git a/monoio/src/net/unix/stream.rs b/monoio/src/net/unix/stream.rs index 4cb61e5..89b1b1c 100644 --- a/monoio/src/net/unix/stream.rs +++ b/monoio/src/net/unix/stream.rs @@ -369,7 +369,7 @@ impl tokio::io::AsyncWrite for UnixStream { buf: &[u8], ) -> std::task::Poll<Result<usize, io::Error>> { unsafe { - let raw_buf = crate::buf::RawBuf::new(buf.as_ptr() as *const u8, buf.len()); + let raw_buf = crate::buf::RawBuf::new(buf.as_ptr(), buf.len()); let mut send = Op::send_raw(&self.fd, raw_buf); let ret = ready!(crate::driver::op::PollLegacy::poll_legacy(&mut send, cx)); diff --git a/monoio/src/task/raw.rs b/monoio/src/task/raw.rs index 71c54a3..ccdc8e2 100644 --- a/monoio/src/task/raw.rs +++ b/monoio/src/task/raw.rs @@ -12,7 +12,7 @@ pub(crate) struct RawTask { impl Clone for RawTask { fn clone(&self) -> Self { - RawTask { ptr: self.ptr } + *self } } diff --git a/monoio/src/time/driver/mod.rs b/monoio/src/time/driver/mod.rs index 71286c8..78dfc51 100644 --- a/monoio/src/time/driver/mod.rs +++ b/monoio/src/time/driver/mod.rs @@ -7,7 +7,7 @@ //! Time driver mod entry; -pub(self) use self::entry::{EntryList, TimerEntry, TimerHandle, TimerShared}; +use self::entry::{EntryList, TimerEntry, TimerHandle, TimerShared}; mod handle; pub(crate) use self::handle::Handle; @@ -93,7 +93,7 @@ pub struct TimeDriver<D: 'static> { /// A structure which handles conversion from Instants to u64 timestamps. #[derive(Debug, Clone)] -pub(self) struct ClockTime { +struct ClockTime { clock: super::clock::Clock, start_time: Instant, } diff --git a/monoio/tests/unix_datagram.rs b/monoio/tests/unix_datagram.rs new file mode 100644 index 0000000..58cf6b0 --- /dev/null +++ b/monoio/tests/unix_datagram.rs @@ -0,0 +1,27 @@ +use monoio::net::unix::UnixDatagram; + +#[monoio::test_all] +async fn accept_send_recv() -> std::io::Result<()> { + let dir = tempfile::Builder::new() + .prefix("monoio-unix-datagram-tests") + .tempdir() + .unwrap(); + let sock_path = dir.path().join("dgram.sock"); + + let dgram1 = UnixDatagram::bind(&sock_path)?; + let dgram2 = UnixDatagram::connect(&sock_path).await?; + + dgram2.send(b"hello").await.0.unwrap(); + let (_res, buf) = dgram1.recv_from(vec![0; 100]).await; + assert_eq!(buf, b"hello"); + + #[cfg(target_os = "linux")] + assert!(_res.unwrap().1.is_unnamed()); + + let dgram3 = UnixDatagram::unbound()?; + dgram3.send_to(b"hello2", &sock_path).await.0.unwrap(); + let (res, buf) = dgram1.recv(vec![0; 100]).await; + assert_eq!(buf, b"hello2"); + assert_eq!(res.unwrap(), 6); + Ok(()) +} |
