diff options
| author | ihciah <[email protected]> | 2021-11-28 01:46:23 +0800 |
|---|---|---|
| committer | ihciah <[email protected]> | 2021-11-29 16:05:45 +0800 |
| commit | 54d9885220d2e0cd0167f6cbb10c7b0d9e762df2 (patch) | |
| tree | 392992a5ee3d531018bf55c12305e0ebc2ce984a /src/tcp.rs | |
init
Diffstat (limited to 'src/tcp.rs')
| -rw-r--r-- | src/tcp.rs | 162 |
1 files changed, 162 insertions, 0 deletions
diff --git a/src/tcp.rs b/src/tcp.rs new file mode 100644 index 0000000..46ef06a --- /dev/null +++ b/src/tcp.rs @@ -0,0 +1,162 @@ +use std::{ + cell::RefCell, + io::{self, Read, Write}, + net::{SocketAddr, TcpListener as StdTcpListener, TcpStream as StdTcpStream, ToSocketAddrs}, + os::unix::prelude::AsRawFd, + rc::{Rc, Weak}, + task::Poll, +}; + +use futures::Stream; +use socket2::{Domain, Protocol, Socket, Type}; + +use crate::{reactor::get_reactor, reactor::Reactor}; + +#[derive(Debug)] +pub struct TcpListener { + reactor: Weak<RefCell<Reactor>>, + listener: StdTcpListener, +} + +impl TcpListener { + pub fn bind<A: ToSocketAddrs>(addr: A) -> Result<Self, io::Error> { + let addr = addr + .to_socket_addrs()? + .next() + .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "empty address"))?; + + let domain = if addr.is_ipv6() { + Domain::IPV6 + } else { + Domain::IPV4 + }; + let sk = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?; + let addr = socket2::SockAddr::from(addr); + sk.set_reuse_address(true)?; + sk.bind(&addr)?; + sk.listen(1024)?; + + // add fd to reactor + let reactor = get_reactor(); + reactor.borrow_mut().add(sk.as_raw_fd()); + + println!("tcp bind with fd {}", sk.as_raw_fd()); + Ok(Self { + reactor: Rc::downgrade(&reactor), + listener: sk.into(), + }) + } +} + +impl Stream for TcpListener { + type Item = std::io::Result<(TcpStream, SocketAddr)>; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll<Option<Self::Item>> { + match self.listener.accept() { + Ok((stream, addr)) => Poll::Ready(Some(Ok((stream.into(), addr)))), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + // modify reactor to register interest + let reactor = self.reactor.upgrade().unwrap(); + reactor + .borrow_mut() + .modify_readable(self.listener.as_raw_fd(), cx); + Poll::Pending + } + Err(e) => std::task::Poll::Ready(Some(Err(e))), + } + } +} + +#[derive(Debug)] +pub struct TcpStream { + stream: StdTcpStream, +} + +impl From<StdTcpStream> for TcpStream { + fn from(stream: StdTcpStream) -> Self { + let reactor = get_reactor(); + reactor.borrow_mut().add(stream.as_raw_fd()); + Self { stream } + } +} + +impl Drop for TcpStream { + fn drop(&mut self) { + println!("drop"); + let reactor = get_reactor(); + reactor.borrow_mut().delete(self.stream.as_raw_fd()); + } +} + +impl tokio::io::AsyncRead for TcpStream { + fn poll_read( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll<io::Result<()>> { + let fd = self.stream.as_raw_fd(); + unsafe { + let b = &mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]); + println!("read for fd {}", fd); + match self.stream.read(b) { + Ok(n) => { + println!("read for fd {} done, {}", fd, n); + buf.assume_init(n); + buf.advance(n); + Poll::Ready(Ok(())) + } + Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { + println!("read for fd {} done WouldBlock", fd); + // modify reactor to register interest + let reactor = get_reactor(); + reactor + .borrow_mut() + .modify_readable(self.stream.as_raw_fd(), cx); + Poll::Pending + } + Err(e) => { + println!("read for fd {} done err", fd); + Poll::Ready(Err(e)) + } + } + } + } +} + +impl tokio::io::AsyncWrite for TcpStream { + fn poll_write( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> Poll<Result<usize, io::Error>> { + match self.stream.write(buf) { + Ok(n) => Poll::Ready(Ok(n)), + Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { + let reactor = get_reactor(); + reactor + .borrow_mut() + .modify_writable(self.stream.as_raw_fd(), cx); + Poll::Pending + } + Err(e) => Poll::Ready(Err(e)), + } + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll<Result<(), io::Error>> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll<Result<(), io::Error>> { + self.stream.shutdown(std::net::Shutdown::Write)?; + Poll::Ready(Ok(())) + } +} |
