summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorihc童鞋@提不起劲 <[email protected]>2023-07-03 15:16:00 +0800
committerGitHub <[email protected]>2023-07-03 15:16:00 +0800
commit50147fad39f1e68546a12c9a6b4066aed4f44cee (patch)
treecbba5c29a6969eebf0f8cbb88df6d00d56e401cc
parentaba979c378d72380a0deb59c7b7454d2dfc0bbfa (diff)
feat: support unix datagram (#183)
* feat: support unix datagram * fix: clippy
-rw-r--r--monoio-compat/src/tcp_unsafe.rs2
-rw-r--r--monoio/src/driver/op/recv.rs80
-rw-r--r--monoio/src/driver/op/send.rs73
-rw-r--r--monoio/src/macros/select.rs13
-rw-r--r--monoio/src/net/tcp/listener.rs4
-rw-r--r--monoio/src/net/tcp/stream.rs2
-rw-r--r--monoio/src/net/unix/datagram/mod.rs43
-rw-r--r--monoio/src/net/unix/socket_addr.rs32
-rw-r--r--monoio/src/net/unix/stream.rs2
-rw-r--r--monoio/src/task/raw.rs2
-rw-r--r--monoio/src/time/driver/mod.rs4
-rw-r--r--monoio/tests/unix_datagram.rs27
12 files changed, 265 insertions, 19 deletions
diff --git a/monoio-compat/src/tcp_unsafe.rs b/monoio-compat/src/tcp_unsafe.rs
index 7f9e0cb..30dfca8 100644
--- a/monoio-compat/src/tcp_unsafe.rs
+++ b/monoio-compat/src/tcp_unsafe.rs
@@ -108,7 +108,7 @@ impl tokio::io::AsyncWrite for TcpStreamCompat {
buf: &[u8],
) -> std::task::Poll<Result<usize, std::io::Error>> {
let this = self.get_mut();
- let (ptr, len) = (buf.as_ptr() as *const u8, buf.len());
+ let (ptr, len) = (buf.as_ptr(), buf.len());
// Set or check write_dst
// Note: the check can not prevent memory crash when user misuse it.
diff --git a/monoio/src/driver/op/recv.rs b/monoio/src/driver/op/recv.rs
index 3f8885f..ce6ecc8 100644
--- a/monoio/src/driver/op/recv.rs
+++ b/monoio/src/driver/op/recv.rs
@@ -13,7 +13,7 @@ use {
};
use super::{super::shared_fd::SharedFd, Op, OpAble};
-use crate::{buf::IoBufMut, BufResult};
+use crate::{buf::IoBufMut, net::unix::SocketAddr as UnixSocketAddr, BufResult};
pub(crate) struct Recv<T> {
/// Holds a strong ref to the FD, preventing the file from being closed
@@ -181,3 +181,81 @@ impl<T: IoBufMut> OpAble for RecvMsg<T> {
syscall_u32!(recvmsg(fd, &mut self.info.2 as *mut _, 0))
}
}
+
+pub(crate) struct RecvMsgUnix<T> {
+ /// Holds a strong ref to the FD, preventing the file from being closed
+ /// while the operation is in-flight.
+ #[allow(unused)]
+ fd: SharedFd,
+
+ /// Reference to the in-flight buffer.
+ pub(crate) buf: T,
+ pub(crate) info: Box<(
+ MaybeUninit<libc::sockaddr_storage>,
+ [libc::iovec; 1],
+ libc::msghdr,
+ )>,
+}
+
+impl<T: IoBufMut> Op<RecvMsgUnix<T>> {
+ pub(crate) fn recv_msg_unix(fd: SharedFd, mut buf: T) -> io::Result<Self> {
+ let iovec = [libc::iovec {
+ iov_base: buf.write_ptr() as *mut _,
+ iov_len: buf.bytes_total(),
+ }];
+ let mut info: Box<(
+ MaybeUninit<libc::sockaddr_storage>,
+ [libc::iovec; 1],
+ libc::msghdr,
+ )> = Box::new((MaybeUninit::uninit(), iovec, unsafe { std::mem::zeroed() }));
+
+ info.2.msg_iov = info.1.as_mut_ptr();
+ info.2.msg_iovlen = 1;
+ info.2.msg_name = &mut info.0 as *mut _ as *mut libc::c_void;
+ info.2.msg_namelen = std::mem::size_of::<libc::sockaddr_storage>() as libc::socklen_t;
+
+ Op::submit_with(RecvMsgUnix { fd, buf, info })
+ }
+
+ pub(crate) async fn wait(self) -> BufResult<(usize, UnixSocketAddr), T> {
+ let complete = self.await;
+ let res = complete.meta.result.map(|v| v as _);
+ let mut buf = complete.data.buf;
+
+ let res = res.map(|n| {
+ let storage = unsafe { complete.data.info.0.assume_init() };
+ let name_len = complete.data.info.2.msg_namelen;
+
+ let addr = unsafe {
+ let addr: &libc::sockaddr_un = transmute(&storage);
+ UnixSocketAddr::from_parts(*addr, name_len)
+ };
+
+ // Safety: the kernel wrote `n` bytes to the buffer.
+ unsafe {
+ buf.set_init(n);
+ }
+
+ (n, addr)
+ });
+ (res, buf)
+ }
+}
+
+impl<T: IoBufMut> OpAble for RecvMsgUnix<T> {
+ #[cfg(all(target_os = "linux", feature = "iouring"))]
+ fn uring_op(&mut self) -> io_uring::squeue::Entry {
+ opcode::RecvMsg::new(types::Fd(self.fd.raw_fd()), &mut self.info.2 as *mut _).build()
+ }
+
+ #[cfg(all(unix, feature = "legacy"))]
+ fn legacy_interest(&self) -> Option<(Direction, usize)> {
+ self.fd.registered_index().map(|idx| (Direction::Read, idx))
+ }
+
+ #[cfg(all(unix, feature = "legacy"))]
+ fn legacy_call(&mut self) -> io::Result<u32> {
+ let fd = self.fd.as_raw_fd();
+ syscall_u32!(recvmsg(fd, &mut self.info.2 as *mut _, 0))
+ }
+}
diff --git a/monoio/src/driver/op/send.rs b/monoio/src/driver/op/send.rs
index bd7222b..0f3aeed 100644
--- a/monoio/src/driver/op/send.rs
+++ b/monoio/src/driver/op/send.rs
@@ -10,7 +10,7 @@ use {
};
use super::{super::shared_fd::SharedFd, Op, OpAble};
-use crate::{buf::IoBuf, BufResult};
+use crate::{buf::IoBuf, net::unix::SocketAddr as UnixSocketAddr, BufResult};
pub(crate) struct Send<T> {
/// Holds a strong ref to the FD, preventing the file from being closed
@@ -171,3 +171,74 @@ impl<T: IoBuf> OpAble for SendMsg<T> {
syscall_u32!(sendmsg(fd, &mut self.info.2 as *mut _, 0))
}
}
+
+pub(crate) struct SendMsgUnix<T> {
+ /// Holds a strong ref to the FD, preventing the file from being closed
+ /// while the operation is in-flight.
+ #[allow(unused)]
+ fd: SharedFd,
+
+ /// Reference to the in-flight buffer.
+ pub(crate) buf: T,
+ pub(crate) info: Box<(Option<UnixSocketAddr>, [libc::iovec; 1], libc::msghdr)>,
+}
+
+impl<T: IoBuf> Op<SendMsgUnix<T>> {
+ pub(crate) fn send_msg_unix(
+ fd: SharedFd,
+ buf: T,
+ socket_addr: Option<UnixSocketAddr>,
+ ) -> io::Result<Self> {
+ let iovec = [libc::iovec {
+ iov_base: buf.read_ptr() as *const _ as *mut _,
+ iov_len: buf.bytes_init(),
+ }];
+ let mut info: Box<(Option<UnixSocketAddr>, [libc::iovec; 1], libc::msghdr)> =
+ Box::new((socket_addr.map(Into::into), iovec, unsafe {
+ std::mem::zeroed()
+ }));
+
+ info.2.msg_iov = info.1.as_mut_ptr();
+ info.2.msg_iovlen = 1;
+
+ match info.0.as_ref() {
+ Some(socket_addr) => {
+ info.2.msg_name = socket_addr.as_ptr() as *mut libc::c_void;
+ info.2.msg_namelen = socket_addr.len();
+ }
+ None => {
+ info.2.msg_name = std::ptr::null_mut();
+ info.2.msg_namelen = 0;
+ }
+ }
+
+ Op::submit_with(SendMsgUnix { fd, buf, info })
+ }
+
+ pub(crate) async fn wait(self) -> BufResult<usize, T> {
+ let complete = self.await;
+ let res = complete.meta.result.map(|v| v as _);
+ let buf = complete.data.buf;
+ (res, buf)
+ }
+}
+
+impl<T: IoBuf> OpAble for SendMsgUnix<T> {
+ #[cfg(all(target_os = "linux", feature = "iouring"))]
+ fn uring_op(&mut self) -> io_uring::squeue::Entry {
+ opcode::SendMsg::new(types::Fd(self.fd.raw_fd()), &mut self.info.2 as *mut _).build()
+ }
+
+ #[cfg(all(unix, feature = "legacy"))]
+ fn legacy_interest(&self) -> Option<(Direction, usize)> {
+ self.fd
+ .registered_index()
+ .map(|idx| (Direction::Write, idx))
+ }
+
+ #[cfg(all(unix, feature = "legacy"))]
+ fn legacy_call(&mut self) -> io::Result<u32> {
+ let fd = self.fd.as_raw_fd();
+ syscall_u32!(sendmsg(fd, &mut self.info.2 as *mut _, 0))
+ }
+}
diff --git a/monoio/src/macros/select.rs b/monoio/src/macros/select.rs
index 3cddd12..8736ac8 100644
--- a/monoio/src/macros/select.rs
+++ b/monoio/src/macros/select.rs
@@ -30,13 +30,12 @@
///
/// The complete lifecycle of a `select!` expression is as follows:
///
-/// 1. Evaluate all provided `<precondition>` expressions. If the precondition
-/// returns `false`, disable the branch for the remainder of the current call
-/// to `select!`. Re-entering `select!` due to a loop clears the "disabled"
-/// state.
-/// 2. Aggregate the `<async expression>`s from each branch, including the
-/// disabled ones. If the branch is disabled, `<async expression>` is still
-/// evaluated, but the resulting future is not polled.
+/// 1. Evaluate all provided `<precondition>` expressions. If the precondition returns `false`,
+/// disable the branch for the remainder of the current call to `select!`. Re-entering `select!`
+/// due to a loop clears the "disabled" state.
+/// 2. Aggregate the `<async expression>`s from each branch, including the disabled ones. If the
+/// branch is disabled, `<async expression>` is still evaluated, but the resulting future is not
+/// polled.
/// 3. Concurrently await on the results for all remaining `<async
/// expression>`s. 4. Once an `<async expression>` returns a value, attempt to
/// apply the value to the provided `<pattern>`, if the pattern matches,
diff --git a/monoio/src/net/tcp/listener.rs b/monoio/src/net/tcp/listener.rs
index 476e679..af29509 100644
--- a/monoio/src/net/tcp/listener.rs
+++ b/monoio/src/net/tcp/listener.rs
@@ -114,7 +114,7 @@ impl TcpListener {
let stream = TcpStream::from_shared_fd(SharedFd::new(fd as _)?);
// Construct SocketAddr
- let storage = completion.data.addr.0.as_ptr() as *const _ as *const libc::sockaddr_storage;
+ let storage = completion.data.addr.0.as_ptr();
let addr = unsafe {
match (*storage).ss_family as libc::c_int {
libc::AF_INET => {
@@ -167,7 +167,7 @@ impl TcpListener {
let stream = TcpStream::from_shared_fd(SharedFd::new(fd as _)?);
// Construct SocketAddr
- let storage = completion.data.addr.0.as_ptr() as *const _ as *const libc::sockaddr_storage;
+ let storage = completion.data.addr.0.as_ptr();
let addr = unsafe {
match (*storage).ss_family as libc::c_int {
libc::AF_INET => {
diff --git a/monoio/src/net/tcp/stream.rs b/monoio/src/net/tcp/stream.rs
index 359fc83..598d353 100644
--- a/monoio/src/net/tcp/stream.rs
+++ b/monoio/src/net/tcp/stream.rs
@@ -503,7 +503,7 @@ impl tokio::io::AsyncWrite for TcpStream {
buf: &[u8],
) -> std::task::Poll<Result<usize, io::Error>> {
unsafe {
- let raw_buf = crate::buf::RawBuf::new(buf.as_ptr() as *const u8, buf.len());
+ let raw_buf = crate::buf::RawBuf::new(buf.as_ptr(), buf.len());
let mut send = Op::send_raw(&self.fd, raw_buf);
let ret = ready!(crate::driver::op::PollLegacy::poll_legacy(&mut send, cx));
diff --git a/monoio/src/net/unix/datagram/mod.rs b/monoio/src/net/unix/datagram/mod.rs
index 67ea3d9..eb07981 100644
--- a/monoio/src/net/unix/datagram/mod.rs
+++ b/monoio/src/net/unix/datagram/mod.rs
@@ -14,6 +14,7 @@ use super::{
SocketAddr,
};
use crate::{
+ buf::{IoBuf, IoBufMut},
driver::{op::Op, shared_fd::SharedFd},
net::new_socket,
};
@@ -61,7 +62,7 @@ impl UnixDatagram {
sockaddr: libc::sockaddr_un,
socklen: libc::socklen_t,
) -> io::Result<Self> {
- let socket = new_socket(libc::AF_UNIX, libc::SOCK_STREAM)?;
+ let socket = new_socket(libc::AF_UNIX, libc::SOCK_DGRAM)?;
let op = Op::connect_unix(SharedFd::new(socket)?, sockaddr, socklen)?;
let completion = op.await;
completion.meta.result?;
@@ -119,6 +120,46 @@ impl UnixDatagram {
let op = Op::poll_write(&self.fd, relaxed).unwrap();
op.wait().await
}
+
+ /// Sends data on the socket to the given address. On success, returns the
+ /// number of bytes written.
+ pub async fn send_to<T: IoBuf, P: AsRef<Path>>(
+ &self,
+ buf: T,
+ path: P,
+ ) -> crate::BufResult<usize, T> {
+ let addr = match crate::net::unix::socket_addr::socket_addr(path.as_ref()) {
+ Ok(addr) => addr,
+ Err(e) => return (Err(e), buf),
+ };
+ let op = Op::send_msg_unix(
+ self.fd.clone(),
+ buf,
+ Some(SocketAddr::from_parts(addr.0, addr.1)),
+ )
+ .unwrap();
+ op.wait().await
+ }
+
+ /// Receives a single datagram message on the socket. On success, returns the number
+ /// of bytes read and the origin.
+ pub async fn recv_from<T: IoBufMut>(&self, buf: T) -> crate::BufResult<(usize, SocketAddr), T> {
+ let op = Op::recv_msg_unix(self.fd.clone(), buf).unwrap();
+ op.wait().await
+ }
+
+ /// Sends data on the socket to the remote address to which it is connected.
+ pub async fn send<T: IoBuf>(&self, buf: T) -> crate::BufResult<usize, T> {
+ let op = Op::send_msg_unix(self.fd.clone(), buf, None).unwrap();
+ op.wait().await
+ }
+
+ /// Receives a single datagram message on the socket from the remote address to
+ /// which it is connected. On success, returns the number of bytes read.
+ pub async fn recv<T: IoBufMut>(&self, buf: T) -> crate::BufResult<usize, T> {
+ let op = Op::recv(self.fd.clone(), buf).unwrap();
+ op.read().await
+ }
}
impl AsRawFd for UnixDatagram {
diff --git a/monoio/src/net/unix/socket_addr.rs b/monoio/src/net/unix/socket_addr.rs
index 70defc2..7169552 100644
--- a/monoio/src/net/unix/socket_addr.rs
+++ b/monoio/src/net/unix/socket_addr.rs
@@ -64,7 +64,24 @@ impl SocketAddr {
Ok(SocketAddr::from_parts(sockaddr, socklen))
}
- pub(crate) fn from_parts(sockaddr: libc::sockaddr_un, socklen: libc::socklen_t) -> SocketAddr {
+ pub(crate) fn from_parts(
+ sockaddr: libc::sockaddr_un,
+ mut socklen: libc::socklen_t,
+ ) -> SocketAddr {
+ fn sun_path_offset(addr: &libc::sockaddr_un) -> usize {
+ let base: usize = (addr as *const libc::sockaddr_un).cast::<()>() as usize;
+ let path: usize = (&addr.sun_path as *const libc::c_char).cast::<()>() as usize;
+ path - base
+ }
+
+ if socklen == 0 {
+ // When there is a datagram from unnamed unix socket
+ // linux returns zero bytes of address
+ socklen = sun_path_offset(&sockaddr) as libc::socklen_t; // i.e., zero-length address
+ } else if sockaddr.sun_family != libc::AF_UNIX as libc::sa_family_t {
+ panic!("file descriptor did not correspond to a Unix socket");
+ }
+
SocketAddr { sockaddr, socklen }
}
@@ -77,6 +94,7 @@ impl SocketAddr {
/// Documentation reflected in [`SocketAddr`]
///
/// [`SocketAddr`]: std::os::unix::net::SocketAddr
+ #[cfg(target_os = "linux")]
#[inline]
pub fn is_unnamed(&self) -> bool {
matches!(self.address(), AddressKind::Unnamed)
@@ -87,6 +105,7 @@ impl SocketAddr {
/// Documentation reflected in [`SocketAddr`]
///
/// [`SocketAddr`]: std::os::unix::net::SocketAddr
+ #[cfg(target_os = "linux")]
#[inline]
pub fn as_pathname(&self) -> Option<&Path> {
if let AddressKind::Pathname(path) = self.address() {
@@ -100,6 +119,7 @@ impl SocketAddr {
/// without the leading null byte.
// Link to std::os::unix::net::SocketAddr pending
// https://github.com/rust-lang/rust/issues/85410.
+ #[cfg(target_os = "linux")]
#[inline]
pub fn as_abstract_namespace(&self) -> Option<&[u8]> {
if let AddressKind::Abstract(path) = self.address() {
@@ -108,6 +128,16 @@ impl SocketAddr {
None
}
}
+
+ #[inline]
+ pub(crate) fn as_ptr(&self) -> *const libc::sockaddr_un {
+ &self.sockaddr as *const _
+ }
+
+ #[inline]
+ pub(crate) fn len(&self) -> libc::socklen_t {
+ self.socklen
+ }
}
impl fmt::Debug for SocketAddr {
diff --git a/monoio/src/net/unix/stream.rs b/monoio/src/net/unix/stream.rs
index 4cb61e5..89b1b1c 100644
--- a/monoio/src/net/unix/stream.rs
+++ b/monoio/src/net/unix/stream.rs
@@ -369,7 +369,7 @@ impl tokio::io::AsyncWrite for UnixStream {
buf: &[u8],
) -> std::task::Poll<Result<usize, io::Error>> {
unsafe {
- let raw_buf = crate::buf::RawBuf::new(buf.as_ptr() as *const u8, buf.len());
+ let raw_buf = crate::buf::RawBuf::new(buf.as_ptr(), buf.len());
let mut send = Op::send_raw(&self.fd, raw_buf);
let ret = ready!(crate::driver::op::PollLegacy::poll_legacy(&mut send, cx));
diff --git a/monoio/src/task/raw.rs b/monoio/src/task/raw.rs
index 71c54a3..ccdc8e2 100644
--- a/monoio/src/task/raw.rs
+++ b/monoio/src/task/raw.rs
@@ -12,7 +12,7 @@ pub(crate) struct RawTask {
impl Clone for RawTask {
fn clone(&self) -> Self {
- RawTask { ptr: self.ptr }
+ *self
}
}
diff --git a/monoio/src/time/driver/mod.rs b/monoio/src/time/driver/mod.rs
index 71286c8..78dfc51 100644
--- a/monoio/src/time/driver/mod.rs
+++ b/monoio/src/time/driver/mod.rs
@@ -7,7 +7,7 @@
//! Time driver
mod entry;
-pub(self) use self::entry::{EntryList, TimerEntry, TimerHandle, TimerShared};
+use self::entry::{EntryList, TimerEntry, TimerHandle, TimerShared};
mod handle;
pub(crate) use self::handle::Handle;
@@ -93,7 +93,7 @@ pub struct TimeDriver<D: 'static> {
/// A structure which handles conversion from Instants to u64 timestamps.
#[derive(Debug, Clone)]
-pub(self) struct ClockTime {
+struct ClockTime {
clock: super::clock::Clock,
start_time: Instant,
}
diff --git a/monoio/tests/unix_datagram.rs b/monoio/tests/unix_datagram.rs
new file mode 100644
index 0000000..58cf6b0
--- /dev/null
+++ b/monoio/tests/unix_datagram.rs
@@ -0,0 +1,27 @@
+use monoio::net::unix::UnixDatagram;
+
+#[monoio::test_all]
+async fn accept_send_recv() -> std::io::Result<()> {
+ let dir = tempfile::Builder::new()
+ .prefix("monoio-unix-datagram-tests")
+ .tempdir()
+ .unwrap();
+ let sock_path = dir.path().join("dgram.sock");
+
+ let dgram1 = UnixDatagram::bind(&sock_path)?;
+ let dgram2 = UnixDatagram::connect(&sock_path).await?;
+
+ dgram2.send(b"hello").await.0.unwrap();
+ let (_res, buf) = dgram1.recv_from(vec![0; 100]).await;
+ assert_eq!(buf, b"hello");
+
+ #[cfg(target_os = "linux")]
+ assert!(_res.unwrap().1.is_unnamed());
+
+ let dgram3 = UnixDatagram::unbound()?;
+ dgram3.send_to(b"hello2", &sock_path).await.0.unwrap();
+ let (res, buf) = dgram1.recv(vec![0; 100]).await;
+ assert_eq!(buf, b"hello2");
+ assert_eq!(res.unwrap(), 6);
+ Ok(())
+}