diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/executor.rs | 174 | ||||
| -rw-r--r-- | src/lib.rs | 6 | ||||
| -rw-r--r-- | src/reactor.rs | 117 | ||||
| -rw-r--r-- | src/tcp.rs | 162 |
4 files changed, 459 insertions, 0 deletions
diff --git a/src/executor.rs b/src/executor.rs new file mode 100644 index 0000000..42cd00f --- /dev/null +++ b/src/executor.rs @@ -0,0 +1,174 @@ +use std::{ + cell::RefCell, + collections::VecDeque, + marker::PhantomData, + mem, + rc::Rc, + task::{RawWaker, RawWakerVTable, Waker, Context}, pin::Pin, +}; + +use futures::{future::LocalBoxFuture, Future, FutureExt}; + +use crate::reactor::Reactor; + +scoped_tls::scoped_thread_local!(pub(crate) static EX: Executor); + +pub struct Executor { + local_queue: TaskQueue, + pub(crate) reactor: Rc<RefCell<Reactor>>, + + /// Make sure the type is `!Send` and `!Sync`. + _marker: PhantomData<Rc<()>>, +} + +impl Default for Executor { + fn default() -> Self { + Self::new() + } +} + + +impl Executor { + pub fn new() -> Self { + Self { + local_queue: TaskQueue::default(), + reactor: Rc::new(RefCell::new(Reactor::default())), + + _marker: PhantomData, + } + } + + pub fn spawn(fut: impl Future<Output = ()> + 'static) { + let t = Rc::new(Task { + future: RefCell::new(fut.boxed_local()), + }); + EX.with(|ex| ex.local_queue.push(t)); + } + + pub fn block_on<F, T, O>(&self, f: F) -> O + where + F: Fn() -> T, + T: Future<Output = O> + 'static, + { + let _waker = waker_fn::waker_fn(|| {}); + let cx = &mut Context::from_waker(&_waker); + + EX.set(self, || { + let fut = f(); + pin_utils::pin_mut!(fut); + loop { + // return if the outer future is ready + if let std::task::Poll::Ready(t) = fut.as_mut().poll(cx) { + break t; + } + + // consume all tasks + while let Some(t) = self.local_queue.pop() { + let future = t.future.borrow_mut(); + let w = waker(t.clone()); + let mut context = Context::from_waker(&w); + let _ = Pin::new(future).as_mut().poll(&mut context); + } + + // no task to execute now, it may ready + if let std::task::Poll::Ready(t) = fut.as_mut().poll(cx) { + break t; + } + + // block for io + self.reactor.borrow_mut().wait(); + } + }) + } +} + +pub struct TaskQueue { + queue: RefCell<VecDeque<Rc<Task>>>, +} + +impl Default for TaskQueue { + fn default() -> Self { + Self::new() + } +} + +impl TaskQueue { + pub fn new() -> Self { + const DEFAULT_TASK_QUEUE_SIZE: usize = 4096; + Self::new_with_capacity(DEFAULT_TASK_QUEUE_SIZE) + } + pub fn new_with_capacity(capacity: usize) -> Self { + Self { + queue: RefCell::new(VecDeque::with_capacity(capacity)), + } + } + + pub(crate) fn push(&self, runnable: Rc<Task>) { + println!("add task"); + self.queue.borrow_mut().push_back(runnable); + } + + pub(crate) fn pop(&self) -> Option<Rc<Task>> { + println!("remove task"); + self.queue.borrow_mut().pop_front() + } +} + +pub struct Task { + future: RefCell<LocalBoxFuture<'static, ()>>, +} + +fn waker(wake: Rc<Task>) -> Waker { + let ptr = Rc::into_raw(wake) as *const (); + let vtable = &Helper::VTABLE; + unsafe { Waker::from_raw(RawWaker::new(ptr, vtable)) } +} + +impl Task { + fn wake_(self: Rc<Self>) { + Self::wake_by_ref_(&self) + } + + fn wake_by_ref_(self: &Rc<Self>) { + EX.with(|ex| ex.local_queue.push(self.clone())); + } +} + +struct Helper; + +impl Helper { + const VTABLE: RawWakerVTable = RawWakerVTable::new( + Self::clone_waker, + Self::wake, + Self::wake_by_ref, + Self::drop_waker, + ); + + unsafe fn clone_waker(data: *const ()) -> RawWaker { + increase_refcount(data); + let vtable = &Self::VTABLE; + RawWaker::new(data, vtable) + } + + unsafe fn wake(ptr: *const ()) { + let rc = Rc::from_raw(ptr as *const Task); + rc.wake_(); + } + + unsafe fn wake_by_ref(ptr: *const ()) { + let rc = mem::ManuallyDrop::new(Rc::from_raw(ptr as *const Task)); + rc.wake_by_ref_(); + } + + unsafe fn drop_waker(ptr: *const ()) { + drop(Rc::from_raw(ptr as *const Task)); + } +} + +#[allow(clippy::redundant_clone)] // The clone here isn't actually redundant. +unsafe fn increase_refcount(data: *const ()) { + // Retain Rc, but don't touch refcount by wrapping in ManuallyDrop + let rc = mem::ManuallyDrop::new(Rc::<Task>::from_raw(data as *const Task)); + // Now increase refcount, but don't drop new refcount either + let _rc_clone: mem::ManuallyDrop<_> = rc.clone(); +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..6ebc7f8 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,6 @@ +#![allow(unused)] + +pub mod executor; +pub mod tcp; + +mod reactor; diff --git a/src/reactor.rs b/src/reactor.rs new file mode 100644 index 0000000..51133ed --- /dev/null +++ b/src/reactor.rs @@ -0,0 +1,117 @@ +use std::{ + cell::RefCell, + os::unix::prelude::{AsRawFd, RawFd}, + rc::Rc, + task::{Context, Waker}, +}; + +use polling::{Event, Poller}; + +#[inline] +pub(crate) fn get_reactor() -> Rc<RefCell<Reactor>> { + crate::executor::EX.with(|ex| ex.reactor.clone()) +} + +#[derive(Debug)] +pub struct Reactor { + poller: Poller, + waker_mapping: rustc_hash::FxHashMap<u64, Waker>, + + buffer: Vec<Event>, +} + +impl Reactor { + pub fn new() -> Self { + Self { + poller: Poller::new().unwrap(), + waker_mapping: Default::default(), + + buffer: Vec::with_capacity(2048), + } + } + + // Epoll related + pub fn add(&mut self, fd: RawFd) { + println!("[reactor] add fd {}", fd); + + let flags = + nix::fcntl::OFlag::from_bits(nix::fcntl::fcntl(fd, nix::fcntl::F_GETFL).unwrap()) + .unwrap(); + let flags_nonblocking = flags | nix::fcntl::OFlag::O_NONBLOCK; + nix::fcntl::fcntl(fd, nix::fcntl::F_SETFL(flags_nonblocking)).unwrap(); + self.poller + .add(fd, polling::Event::none(fd as usize)) + .unwrap(); + } + + pub fn modify_readable(&mut self, fd: RawFd, cx: &mut Context) { + println!("[reactor] modify_readable fd {} token {}", fd, fd * 2); + + self.push_completion(fd as u64 * 2, cx); + let event = polling::Event::readable(fd as usize); + self.poller.modify(fd, event); + } + + pub fn modify_writable(&mut self, fd: RawFd, cx: &mut Context) { + println!("[reactor] modify_writable fd {}, token {}", fd, fd * 2 + 1); + + self.push_completion(fd as u64 * 2 + 1, cx); + let event = polling::Event::writable(fd as usize); + self.poller.modify(fd, event); + } + + pub fn wait(&mut self) { + println!("[reactor] waiting"); + self.poller.wait(&mut self.buffer, None); + println!("[reactor] wait done"); + + for i in 0..self.buffer.len() { + let event = self.buffer.swap_remove(0); + if event.readable { + if let Some(waker) = self.waker_mapping.remove(&(event.key as u64 * 2)) { + println!( + "[reactor token] fd {} read waker token {} removed and woken", + event.key, + event.key * 2 + ); + waker.wake(); + } + } + if event.writable { + if let Some(waker) = self.waker_mapping.remove(&(event.key as u64 * 2 + 1)) { + println!( + "[reactor token] fd {} write waker token {} removed and woken", + event.key, + event.key * 2 + 1 + ); + waker.wake(); + } + } + } + } + + pub fn delete(&mut self, fd: RawFd) { + println!("[reactor] delete fd {}", fd); + + self.waker_mapping.remove(&(fd as u64 * 2)); + self.waker_mapping.remove(&(fd as u64 * 2 + 1)); + println!( + "[reactor token] fd {} wakers token {}, {} removed", + fd, + fd * 2, + fd * 2 + 1 + ); + } + + fn push_completion(&mut self, token: u64, cx: &mut Context) { + println!("[reactor token] token {} waker saved", token); + + self.waker_mapping.insert(token, cx.waker().clone()); + } +} + +impl Default for Reactor { + fn default() -> Self { + Self::new() + } +} diff --git a/src/tcp.rs b/src/tcp.rs new file mode 100644 index 0000000..46ef06a --- /dev/null +++ b/src/tcp.rs @@ -0,0 +1,162 @@ +use std::{ + cell::RefCell, + io::{self, Read, Write}, + net::{SocketAddr, TcpListener as StdTcpListener, TcpStream as StdTcpStream, ToSocketAddrs}, + os::unix::prelude::AsRawFd, + rc::{Rc, Weak}, + task::Poll, +}; + +use futures::Stream; +use socket2::{Domain, Protocol, Socket, Type}; + +use crate::{reactor::get_reactor, reactor::Reactor}; + +#[derive(Debug)] +pub struct TcpListener { + reactor: Weak<RefCell<Reactor>>, + listener: StdTcpListener, +} + +impl TcpListener { + pub fn bind<A: ToSocketAddrs>(addr: A) -> Result<Self, io::Error> { + let addr = addr + .to_socket_addrs()? + .next() + .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "empty address"))?; + + let domain = if addr.is_ipv6() { + Domain::IPV6 + } else { + Domain::IPV4 + }; + let sk = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?; + let addr = socket2::SockAddr::from(addr); + sk.set_reuse_address(true)?; + sk.bind(&addr)?; + sk.listen(1024)?; + + // add fd to reactor + let reactor = get_reactor(); + reactor.borrow_mut().add(sk.as_raw_fd()); + + println!("tcp bind with fd {}", sk.as_raw_fd()); + Ok(Self { + reactor: Rc::downgrade(&reactor), + listener: sk.into(), + }) + } +} + +impl Stream for TcpListener { + type Item = std::io::Result<(TcpStream, SocketAddr)>; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll<Option<Self::Item>> { + match self.listener.accept() { + Ok((stream, addr)) => Poll::Ready(Some(Ok((stream.into(), addr)))), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + // modify reactor to register interest + let reactor = self.reactor.upgrade().unwrap(); + reactor + .borrow_mut() + .modify_readable(self.listener.as_raw_fd(), cx); + Poll::Pending + } + Err(e) => std::task::Poll::Ready(Some(Err(e))), + } + } +} + +#[derive(Debug)] +pub struct TcpStream { + stream: StdTcpStream, +} + +impl From<StdTcpStream> for TcpStream { + fn from(stream: StdTcpStream) -> Self { + let reactor = get_reactor(); + reactor.borrow_mut().add(stream.as_raw_fd()); + Self { stream } + } +} + +impl Drop for TcpStream { + fn drop(&mut self) { + println!("drop"); + let reactor = get_reactor(); + reactor.borrow_mut().delete(self.stream.as_raw_fd()); + } +} + +impl tokio::io::AsyncRead for TcpStream { + fn poll_read( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + let fd = self.stream.as_raw_fd(); + unsafe { + let b = &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]); + println!("read for fd {}", fd); + match self.stream.read(b) { + Ok(n) => { + println!("read for fd {} done, {}", fd, n); + buf.assume_init(n); + buf.advance(n); + Poll::Ready(Ok(())) + } + Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { + println!("read for fd {} done WouldBlock", fd); + // modify reactor to register interest + let reactor = get_reactor(); + reactor + .borrow_mut() + .modify_readable(self.stream.as_raw_fd(), cx); + Poll::Pending + } + Err(e) => { + println!("read for fd {} done err", fd); + Poll::Ready(Err(e)) + } + } + } + } +} + +impl tokio::io::AsyncWrite for TcpStream { + fn poll_write( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> Poll<Result<usize, io::Error>> { + match self.stream.write(buf) { + Ok(n) => Poll::Ready(Ok(n)), + Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { + let reactor = get_reactor(); + reactor + .borrow_mut() + .modify_writable(self.stream.as_raw_fd(), cx); + Poll::Pending + } + Err(e) => Poll::Ready(Err(e)), + } + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll<Result<(), io::Error>> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll<Result<(), io::Error>> { + self.stream.shutdown(std::net::Shutdown::Write)?; + Poll::Ready(Ok(())) + } +} |
