1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
|
use std::{
cell::RefCell,
io::{self, Read, Write},
net::{SocketAddr, TcpListener as StdTcpListener, TcpStream as StdTcpStream, ToSocketAddrs},
os::unix::prelude::AsRawFd,
rc::{Rc, Weak},
task::Poll,
};
use futures::Stream;
use socket2::{Domain, Protocol, Socket, Type};
use crate::{reactor::get_reactor, reactor::Reactor};
/// A TCP listener driven by this crate's reactor.
///
/// Wraps a standard-library `TcpListener` and keeps a weak handle to the reactor
/// the listening fd was registered with, so readiness interest can be registered
/// again whenever `accept` would block.
#[derive(Debug)]
pub struct TcpListener {
    /// Weak handle to the reactor the listening fd was added to in `bind`.
    reactor: Weak<RefCell<Reactor>>,
    /// The wrapped standard-library listener.
    listener: StdTcpListener,
}
impl TcpListener {
/// 绑定地址并返回一个 `TcpListener` 实例
pub fn bind<A: ToSocketAddrs>(addr: A) -> Result<Self, io::Error> {
// 解析地址
let addr = addr
.to_socket_addrs()?
.next()
.ok_or_else(|| io::Error::new(io::ErrorKind::Other, "empty address"))?;
// 创建 socket
let domain = if addr.is_ipv6() {
Domain::IPV6
} else {
Domain::IPV4
};
let sk = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?;
// 绑定地址并监听
let addr = socket2::SockAddr::from(addr);
sk.set_reuse_address(true)?;
sk.bind(&addr)?;
sk.listen(1024)?;
// 将 fd 添加到 reactor 中
let reactor = get_reactor();
reactor.borrow_mut().add(sk.as_raw_fd());
println!("tcp bind with fd {}", sk.as_raw_fd());
Ok(Self {
reactor: Rc::downgrade(&reactor),
listener: sk.into(),
})
}
}
/// Exposes incoming connections on a `TcpListener` as a `futures::Stream`.
///
/// Each item is either an accepted `(TcpStream, SocketAddr)` pair or an I/O error.
/// The stream never yields `None`; errors are reported as items, and the caller may
/// keep polling afterwards.
impl Stream for TcpListener {
    type Item = std::io::Result<(TcpStream, SocketAddr)>;

    /// Tries to accept one connection.
    ///
    /// On success, the accepted std stream is converted into this crate's `TcpStream`.
    /// On `WouldBlock`, read interest on the listening fd is registered with the
    /// reactor (together with the current task context) and `Poll::Pending` is
    /// returned. If the reactor this listener was registered with no longer exists,
    /// an `ErrorKind::Other` error is yielded instead of panicking.
    fn poll_next(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        match self.listener.accept() {
            Ok((stream, addr)) => Poll::Ready(Some(Ok((stream.into(), addr)))),
            Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
                // The listener only holds a weak handle; upgrading fails once the
                // reactor has been dropped, which previously caused a bare panic.
                let reactor = match self.reactor.upgrade() {
                    Some(reactor) => reactor,
                    None => {
                        return Poll::Ready(Some(Err(io::Error::new(
                            io::ErrorKind::Other,
                            "reactor has been dropped",
                        ))))
                    }
                };
                // Ask the reactor to track readability of the listening fd for this task.
                reactor
                    .borrow_mut()
                    .modify_readable(self.listener.as_raw_fd(), cx);
                Poll::Pending
            }
            Err(e) => Poll::Ready(Some(Err(e))),
        }
    }
}
/// A TCP connection driven by this crate's reactor.
///
/// Thin wrapper around a standard-library `TcpStream`; reads and writes go straight
/// to the wrapped stream, and the reactor is consulted only when they would block.
#[derive(Debug)]
pub struct TcpStream {
    /// The wrapped standard-library stream.
    stream: StdTcpStream,
}
impl From<StdTcpStream> for TcpStream {
// 从标准库的 TcpStream 转换为自定义的 TcpStream
fn from(stream: StdTcpStream) -> Self {
let reactor: Rc<RefCell<Reactor>> = get_reactor(); // 获取 reactor
reactor.borrow_mut().add(stream.as_raw_fd()); // 将 fd 添加到 reactor
Self { stream } // 返回包装后的 TcpStream
}
}
impl Drop for TcpStream {
    /// Removes the stream's fd from the current reactor before the socket is closed.
    fn drop(&mut self) {
        println!("drop");
        let fd = self.stream.as_raw_fd();
        get_reactor().borrow_mut().delete(fd);
    }
}
// Implements tokio's `AsyncRead` for `TcpStream` on top of non-blocking reads.
impl tokio::io::AsyncRead for TcpStream {
    /// Performs one read from the wrapped stream into the unfilled part of `buf`.
    ///
    /// On success, `n` bytes are marked as filled (`n == 0` signals EOF, or a `buf`
    /// with no remaining capacity). On `WouldBlock`, read interest on the fd is
    /// registered with the reactor for this task and `Poll::Pending` is returned. Any
    /// other error is returned as `Poll::Ready(Err(_))`.
    fn poll_read(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let fd = self.stream.as_raw_fd();
        println!("read for fd {}", fd);
        // Previously the `MaybeUninit<u8>` region was cast to `&mut [u8]`, which is
        // unsound for uninitialized memory. `initialize_unfilled` zero-fills only the
        // part of the buffer that has never been initialized (a no-op on reuse), so a
        // valid `&mut [u8]` can be handed to `Read::read` without `unsafe`.
        let dst = buf.initialize_unfilled();
        match self.stream.read(dst) {
            Ok(n) => {
                println!("read for fd {} done, {}", fd, n);
                // The bytes are already initialized, so only the filled cursor moves.
                buf.advance(n);
                Poll::Ready(Ok(()))
            }
            Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                println!("read for fd {} done WouldBlock", fd);
                // Ask the reactor to track readability of this fd for the current task.
                let reactor = get_reactor();
                reactor.borrow_mut().modify_readable(fd, cx);
                Poll::Pending
            }
            Err(e) => {
                println!("read for fd {} done err", fd);
                Poll::Ready(Err(e))
            }
        }
    }
}
impl tokio::io::AsyncWrite for TcpStream {
    /// Performs one write of `buf` to the wrapped stream.
    ///
    /// Returns the number of bytes written. On `WouldBlock`, write interest on the fd
    /// is registered with the reactor for this task and `Poll::Pending` is returned;
    /// any other error is returned as `Poll::Ready(Err(_))`.
    fn poll_write(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, io::Error>> {
        let outcome = self.stream.write(buf);
        if let Err(err) = &outcome {
            if err.kind() == std::io::ErrorKind::WouldBlock {
                let fd = self.stream.as_raw_fd();
                get_reactor().borrow_mut().modify_writable(fd, cx);
                return Poll::Pending;
            }
        }
        Poll::Ready(outcome)
    }

    /// Always ready: writes go directly to the socket, and this type keeps no buffer.
    fn poll_flush(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> Poll<Result<(), io::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Shuts down the write half of the connection, reporting any error immediately.
    fn poll_shutdown(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> Poll<Result<(), io::Error>> {
        Poll::Ready(self.stream.shutdown(std::net::Shutdown::Write))
    }
}
|