I am new to Rust, and as a learning project I want to create a TCP proxy. Tokio overwhelmed me and I could not find the documentation I needed to understand how it works, so I tried the std::net module instead, but I got stuck at properly piping the traffic. I bound multiple listeners, each in its own thread, and want to pipe the traffic back and forth like a proper TCP proxy would. Unfortunately I don't seem to understand how that has to work in Rust. Can someone give me an example without third-party dependencies?
Here is my code; handle_connection is called for every accepted connection.
use std::io::{BufReader, BufWriter, Read, Write};
use std::net::TcpStream;
use std::time::Duration;
use log::{debug, info};

fn forward(direction: &str, input: &mut BufReader<TcpStream>, output: &mut BufWriter<TcpStream>) {
    loop {
        let mut buffer = [0; 1024];
        debug!("{} Reading", direction);
        match input.read(&mut buffer) {
            Ok(bytes) => {
                debug!("Read {} bytes", bytes);
                if bytes < 1 {
                    break;
                }
                // only forward the bytes that were actually read, not the whole buffer
                match output.write_all(&buffer[..bytes]) {
                    Ok(_) => {
                        debug!("Forwarded {:#?} bytes", bytes);
                        output.flush().expect("Could not flush");
                        if bytes < 1024 {
                            break; // abort when everything is sent
                        }
                    }
                    Err(error) => panic!("Could not forward: {}", error),
                }
            }
            Err(error) => panic!("Could not read: {}", error),
        }
    }
}

// ProxyMapping is defined elsewhere in the project
fn handle_connection(mapping: ProxyMapping, incoming: TcpStream) -> Result<(), String> {
    info!("Incoming connection from {}", incoming.peer_addr().map_err(|e| format!("{}", e))?);
    // Try to connect to target
    let outgoing = TcpStream::connect_timeout(
        &mapping.get_target_addr()?,
        Duration::from_secs(1),
    )
    .map_err(|error| format!("Couldn't connect {}", error))?;
    // forward tcp stream
    debug!("Start sync");
    let input_clone = incoming.try_clone()
        .map_err(|error| format!("Couldn't clone {}", error))?;
    let mut input_read = BufReader::new(incoming);
    let mut input_write = BufWriter::new(input_clone);
    let output_clone = outgoing.try_clone()
        .map_err(|error| format!("Couldn't clone {}", error))?;
    let mut output_read = BufReader::new(outgoing);
    let mut output_write = BufWriter::new(output_clone);
    debug!("spawn sync");
    // this loop never exits, so the function never reaches a final Ok(())
    loop {
        debug!("Forward");
        forward("forward", &mut input_read, &mut output_write);
        debug!("Backward");
        forward("backward", &mut output_read, &mut input_write);
    }
}
So far it proxies some data, but for HTTP requests it gets stuck in the backward read, because I don't know yet when I should shut down a connection. After some time it just floods the log with "Read 0 bytes". I added a loop around the sync because multiple packets can be sent over one TCP connection. Any ideas how to improve this?
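A minimal std-only sketch of one common way to structure this, under simplifying assumptions (one thread per direction, no timeouts, errors simply returned): instead of alternating forward and backward on a single thread, which blocks as soon as one side has nothing to send, copy each direction independently with std::io::copy. io::copy keeps reading until read returns Ok(0) (end of stream), and shutdown(Shutdown::Write) then half-closes the other socket so the peer on that side sees end of stream too. pipe and proxy are hypothetical helper names, not part of std.

```rust
use std::io;
use std::net::{Shutdown, TcpStream};
use std::thread;

/// Copies bytes from `from` to `to` until `from` reports end of stream,
/// then half-closes `to` so the peer on that side also sees EOF.
fn pipe(mut from: TcpStream, mut to: TcpStream) -> io::Result<u64> {
    let copied = io::copy(&mut from, &mut to)?;
    // Ignore the error: the peer may already have closed its side.
    let _ = to.shutdown(Shutdown::Write);
    Ok(copied)
}

/// Proxies one accepted client connection to `target`, forwarding both
/// directions at the same time: client -> target on a helper thread,
/// target -> client on the current thread.
fn proxy(client: TcpStream, target: &str) -> io::Result<()> {
    let server = TcpStream::connect(target)?;
    let client_reader = client.try_clone()?;
    let server_reader = server.try_clone()?;
    let upstream = thread::spawn(move || pipe(client_reader, server));
    pipe(server_reader, client)?;
    upstream.join().expect("upstream thread panicked")?;
    Ok(())
}
```

Because the two directions no longer wait for each other, an HTTP response can flow back while the client's side is still open, and a 0-byte read ends that direction instead of being retried forever.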
I am trying to "connect" to multiple peers and "process" them in parallel. I have the following implementation
pub async fn process(addr: &str) {
// simulate processing by sleeping
tokio::time::sleep(core::time::Duration::from_millis(1)).await;
println!("processed {}", addr);
}
#[tokio::main]
async fn main() {
let mut peers = vec!["127.0.0.1", "139.48.123.146", "123.123.46.209"];
let conn_fut = futures::future::join_all(peers.iter().map(|peer| {
async move {
println!("connecting to {}", peer);
process(peer).await;
}
}));
// awaits all futures in parallel
conn_fut.await;
}
output:
connecting to 127.0.0.1
connecting to 139.48.123.146
connecting to 123.123.46.209
processed 127.0.0.1
processed 139.48.123.146
processed 123.123.46.209
This is done in parallel since connecting to e.g. 139.48.123.146 can be done before processing 127.0.0.1.
In addition to this I want to add new connections to conn_fut while I am awaiting all futures in parallel. I think something like this would work:
pub async fn process(addr: &str) {
// simulate processing by sleeping
tokio::time::sleep(core::time::Duration::from_millis(1)).await;
println!("processed {}", addr);
}
#[tokio::main]
async fn main() {
let peers = vec!["127.0.0.1", "139.48.123.146", "123.123.46.209"];
let conn_fut = futures::future::join_all(peers.iter().map(|peer| {
async move {
println!("connecting to {}", peer);
process(peer).await;
}
}));
let new_peers = vec!["123.0.0.1", "124.0.0.1"];
let new_conn_fut = futures::future::join_all(new_peers.iter().map(|peer| {
async move {
println!("connecting to {}", peer);
process(peer).await;
}
}));
// awaits all futures in parallel
futures::future::join(conn_fut, new_conn_fut).await;
}
output:
connecting to 127.0.0.1
connecting to 139.48.123.146
connecting to 123.123.46.209
connecting to 123.0.0.1
connecting to 124.0.0.1
processed 127.0.0.1
processed 139.48.123.146
processed 123.123.46.209
processed 123.0.0.1
processed 124.0.0.1
But since I am retrieving the addresses from another asynchronous process, I have this:
pub async fn process(addr: &str) {
// simulate processing by sleeping
tokio::time::sleep(core::time::Duration::from_millis(1)).await;
println!("processed {}", addr);
}
#[tokio::main]
async fn main() {
let peers = vec!["127.0.0.1", "139.48.123.146", "123.123.46.209"];
let conn_fut = futures::future::join_all(peers.iter().map(|peer| {
async move {
println!("connecting to {}", peer);
process(peer).await;
}
}));
let (tx, mut rx) = tokio::sync::mpsc::channel(100);
let handle_conn_fut = async move {
while let Some(peer) = rx.recv().await {
println!("connecting to {}", peer);
process(peer).await;
}
};
let create_new_conn_fut = async move {
for peer in ["123.0.0.1", "124.0.0.1"] {
tx.send(peer).await.unwrap();
}
};
// does not await futures in parallel
futures::future::join3(conn_fut, handle_conn_fut, create_new_conn_fut).await;
}
output:
connecting to 127.0.0.1
connecting to 139.48.123.146
connecting to 123.123.46.209
connecting to 123.0.0.1
processed 127.0.0.1
processed 139.48.123.146
processed 123.123.46.209
processed 123.0.0.1
connecting to 124.0.0.1
processed 124.0.0.1
This does not process the new peers in parallel. I am not sure how to write this. I have looked into using futures::stream, which a tokio::sync::mpsc::Receiver can (I think) be converted into with tokio_stream::wrappers::ReceiverStream. The futures::stream::StreamExt trait provides for_each_concurrent, which can process the elements in parallel.
I am not sure what the best way is to go about this. Intuitively, I want to write something like this:
async fn connect(peers: futures::stream::Stream<&str>) {
peers.for_each_concurrent(|p| process(p).await);
}
#[tokio::main]
async fn main() {
let mut initial_peers = vec!["127.0.0.1", "139.48.123.123"];
let peers = futures::stream::iter(initial_peers)
let conn_fut = connect(peers)
let new_peers_fut = async move {
for new_peer in ["123.0.0.1", "124.0.0.1"] {
peers.push_back(new_peer)
}
}
futures::future::join(conn_fut, new_peers_fut).await
}
Is it possible to process the elements in a futures::stream::Stream in parallel and simultaneously append new elements to the end of the stream which in turn are also processed together with the other elements?
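Solved this by converting the initial peers into a futures::stream, chaining it with rx (wrapped in a ReceiverStream), and processing the combined stream with for_each_concurrent: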
use futures::stream::StreamExt;
pub async fn process(addr: &str) {
// simulate processing by sleeping
tokio::time::sleep(core::time::Duration::from_millis(1)).await;
println!("processed {}", addr);
}
#[tokio::main]
async fn main() {
let peers = vec!["127.0.0.1", "139.48.123.146", "123.123.46.209"];
let peers = futures::stream::iter(peers);
let (tx, rx) = tokio::sync::mpsc::channel(100);
let rx = tokio_stream::wrappers::ReceiverStream::new(rx);
let rx = peers.chain(rx);
let handle_conn_fut = rx.for_each_concurrent(0,
|peer| async move {
println!("connecting to {}", peer);
process(peer).await;
}
);
let create_new_conn_fut = async move {
for peer in ["123.0.0.1", "124.0.0.1"] {
tx.send(peer).await.unwrap();
}
};
// awaits all futures in parallel
futures::future::join(handle_conn_fut, create_new_conn_fut).await;
}
output:
connecting to 127.0.0.1
connecting to 139.48.123.146
connecting to 123.123.46.209
connecting to 123.0.0.1
connecting to 124.0.0.1
processed 127.0.0.1
processed 139.48.123.146
processed 123.123.46.209
processed 123.0.0.1
processed 124.0.0.1
I'm working on a project to implement a distributed key-value store in Rust. I've written the server side using Tokio's asynchronous runtime. I'm running into an issue where my asynchronous code seems to block: when I have multiple connections to the server, only one TcpStream is processed. I'm new to writing async code, both in general and in Rust, but I thought that other streams would be accepted and processed whenever there was no activity on a given TCP stream.
Is my understanding of async wrong or am I using tokio incorrectly?
This is my entry point:
use std::error::Error;
use std::net::SocketAddr;
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::sync::{Arc, Mutex};
use env_logger;
use log::{debug, info};
use structopt::StructOpt;
use tokio::net::TcpListener;
extern crate blue;
use blue::ipc::message;
use blue::store::args;
use blue::store::cluster::{Cluster, NodeRole};
use blue::store::deserialize::deserialize_store;
use blue::store::handler::handle_stream;
use blue::store::wal::WriteAheadLog;
#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
let opt = args::Opt::from_args();
let addr = SocketAddr::from_str(format!("{}:{}", opt.host, opt.port).as_str())?;
let role = NodeRole::from_str(opt.role.as_str()).unwrap();
let leader_addr = match role {
NodeRole::Leader => addr,
NodeRole::Follower => SocketAddr::from_str(opt.follow.unwrap().as_str())?,
};
let wal_name = addr.to_string().replace(".", "").replace(":", "");
let wal_full_name = format!("wal{}.log", wal_name);
let wal_path = PathBuf::from(wal_full_name);
let mut wal = match wal_path.exists() {
true => {
info!("Existing WAL found");
WriteAheadLog::open(&wal_path)?
}
false => {
info!("Creating WAL");
WriteAheadLog::new(&wal_path)?
}
};
debug!("WAL: {:?}", wal);
let store_name = addr.to_string().replace(".", "").replace(":", "");
let store_pth = format!("{}.pb", store_name);
let store_path = Path::new(&store_pth);
let mut store = match store_path.exists() {
true => deserialize_store(store_path)?,
false => message::Store::default(),
};
let listener = TcpListener::bind(addr).await?;
let cluster = Cluster::new(addr, &role, leader_addr, &mut wal, &mut store).await?;
let store_path = Arc::new(store_path);
let store = Arc::new(Mutex::new(store));
let wal = Arc::new(Mutex::new(wal));
let cluster = Arc::new(Mutex::new(cluster));
info!("Blue launched. Waiting for incoming connection");
loop {
let (stream, addr) = listener.accept().await?;
info!("Incoming request from {}", addr);
let store = Arc::clone(&store);
let store_path = Arc::clone(&store_path);
let wal = Arc::clone(&wal);
let cluster = Arc::clone(&cluster);
handle_stream(stream, store, store_path, wal, cluster, &role).await?;
}
}
Below is my handler (handle_stream from the above). I excluded all the handlers in match input as I didn't think they were necessary to prove the point (full code for that section is here: https://github.com/matthewmturner/Bradfield-Distributed-Systems/blob/main/blue/src/store/handler.rs if it actually helps).
Specifically the point that is blocking is the line let input = async_read_message::<message::Request>(&mut stream).await;
This is where the server waits for communication from either a client or another server in the cluster. What I currently see is that once a client has connected, the server doesn't receive any of the requests to add other nodes to the cluster; it only handles the client stream.
use std::io;
use std::net::{SocketAddr, TcpStream};
use std::path::Path;
use std::str::FromStr;
use std::sync::{Arc, Mutex};
use log::{debug, error, info};
use serde_json::json;
use tokio::io::AsyncWriteExt;
use tokio::net::TcpStream as asyncTcpStream;
use super::super::ipc::message;
use super::super::ipc::message::request::Command;
use super::super::ipc::receiver::async_read_message;
use super::super::ipc::sender::{async_send_message, send_message};
use super::cluster::{Cluster, NodeRole};
use super::serialize::persist_store;
use super::wal::WriteAheadLog;
// TODO: Why isnt async working? I.e. connecting servers after client is connected stays on client stream.
pub async fn handle_stream<'a>(
mut stream: asyncTcpStream,
store: Arc<Mutex<message::Store>>,
store_path: Arc<&Path>,
wal: Arc<Mutex<WriteAheadLog<'a>>>,
cluster: Arc<Mutex<Cluster>>,
role: &NodeRole,
) -> io::Result<()> {
loop {
info!("Handling stream: {:?}", stream);
let input = async_read_message::<message::Request>(&mut stream).await;
debug!("Input: {:?}", input);
match input {
...
}
}
}
This is the code for async_read_message
pub async fn async_read_message<M: Message + Default>(
stream: &mut asyncTcpStream,
) -> io::Result<M> {
let mut len_buf = [0u8; 4];
debug!("Reading message length");
stream.read_exact(&mut len_buf).await?;
let len = i32::from_le_bytes(len_buf);
let mut buf = vec![0u8; len as usize];
debug!("Reading message");
stream.read_exact(&mut buf).await?;
let user_input = M::decode(&mut buf.as_slice())?;
debug!("Received message: {:?}", user_input);
Ok(user_input)
}
Your problem lies with how you're handling messages after clients have connected:
handle_stream(stream, store, store_path, wal, cluster, &role).await?;
This .await means your listening loop will wait for handle_stream to return, but (making some assumptions) this function won't return until the client has disconnected. What you want is to tokio::spawn a new task that can run independently:
tokio::spawn(handle_stream(stream, store, store_path, wal, cluster, &role));
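You may have to change some of your parameter types to avoid borrowed references: tokio::spawn requires a 'static future, since the task's lifetime is decoupled from the scope where it was spawned. Here that likely affects &role, the Arc<&Path> (which borrows the local store_pth string) and WriteAheadLog<'a>; passing owned values instead (for example an owned NodeRole, or a PathBuf / Arc<PathBuf>) avoids the problem. Also note that the io::Result returned by handle_stream is no longer propagated with ?; it ends up in the JoinHandle returned by tokio::spawn, so log errors inside the task if you need them.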
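The same accept-then-spawn structure can be sketched without tokio using std::thread::spawn, which has the same 'static bound as tokio::spawn. This is a std-only analog, not a drop-in fix for the tokio code above: NodeRole, handle_stream and the shared request counter are simplified, hypothetical stand-ins for the question's types. What it illustrates is that each spawned closure receives owned values (a cloned role, Arc clones) instead of &role, and that the accept loop moves on immediately instead of waiting for the handler.

```rust
use std::io::{BufRead, BufReader, Write};
use std::net::{TcpListener, TcpStream};
use std::sync::{Arc, Mutex};
use std::thread;

/// Hypothetical stand-in for the node role from the question.
#[derive(Clone, Debug)]
enum NodeRole {
    Leader,
    Follower,
}

/// Handles one client: reads newline-terminated requests and answers each
/// with the running request count. Every argument is owned (or an `Arc`),
/// so a closure that moves them into `thread::spawn` is `'static`.
fn handle_stream(stream: TcpStream, counter: Arc<Mutex<u64>>, role: NodeRole) -> std::io::Result<()> {
    let mut writer = stream.try_clone()?;
    for line in BufReader::new(stream).lines() {
        let _request = line?;
        let count = {
            let mut c = counter.lock().unwrap();
            *c += 1;
            *c
        };
        writeln!(writer, "{:?} handled request #{}", role, count)?;
    }
    Ok(())
}

/// Accept loop: hands every connection to its own thread instead of
/// running the handler inline, so one idle client cannot block the others.
fn serve(listener: TcpListener, role: NodeRole) {
    let counter = Arc::new(Mutex::new(0));
    for stream in listener.incoming() {
        let stream = match stream {
            Ok(s) => s,
            Err(_) => continue,
        };
        let counter = Arc::clone(&counter);
        let role = role.clone(); // owned copy instead of `&role`
        thread::spawn(move || {
            // Nobody joins this thread, so report errors here.
            if let Err(e) = handle_stream(stream, counter, role) {
                eprintln!("connection error: {}", e);
            }
        });
    }
}
```

An idle client now only occupies its own thread; with tokio the shape is the same, with tokio::spawn(async move { ... }) in place of thread::spawn.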
I'm new to network programming and threads in Rust, so I may be missing something obvious here. I've been following along with this, trying to build a simple chat application. Only, he does it with the standard library and I'm trying to do it with tokio. The functionality is very simple: the client sends a message to the server, and the server acknowledges it and sends it back to the client. Here's my code for the client and server, stripped down as much as I can:
server.rs
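use tokio::io::{split, AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
use tokio::sync::mpsc;

#[tokio::main]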
async fn main() {
let server = TcpListener::bind("127.0.0.1:7878").await.unwrap();
let mut clients = vec![];
let (tx, mut rx) = mpsc::channel(32);
loop {
if let Ok((socket, addr)) = server.accept().await {
let tx = tx.clone();
let (mut reader, writer) = split(socket);
clients.push(writer);
tokio::spawn(async move {
loop {
let mut buffer = vec![0; 1024];
reader.read(&mut buffer).await.unwrap();
//get message written by the client and print it
//then transmit it on the channel
let msg = buffer.into_iter().take_while(|&x| x != 0).collect::<Vec<_>>();
let msg = String::from_utf8(msg).expect("Invalid utf8 message");
println!("{}: {:?}", addr, msg);
match tx.send(msg).await {
Ok(_) => { ()}
Err(_) => { println!("Error");}
}
}
});
}
//write each message received back to its client
if let Some(msg) = rx.recv().await {
clients = clients.into_iter().filter_map(|mut x| {
println!("writing: {:?}", &msg);
x.write(&msg.clone().into_bytes());
Some(x)
}).collect::<Vec<_>>();
}
}
}
client.rs
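use std::io;
use tokio::net::TcpStream;
use tokio::sync::mpsc;
use tokio::sync::mpsc::error::TryRecvError;

#[tokio::main]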
async fn main() {
let client = TcpStream::connect("127.0.0.1:7878").await.unwrap();
let (tx, mut rx) = mpsc::channel::<String>(32);
tokio::spawn(async move {
loop {
let mut buffer = vec![0; 1024];
// get message sent by the server and print it
match client.try_read(&mut buffer) {
Ok(_) => {
let msg = buffer.into_iter().take_while(|&x| x != 0).collect::<Vec<_>>();
println!("Received from server: {:?}", msg);
}
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
()
}
Err(_) => {
println!("Connection with server was severed");
break;
}
}
// get message transmitted from user input loop
// then write it to the server
match rx.try_recv() {
Ok(message) => {
let mut buffer = message.clone().into_bytes();
buffer.resize(1024, 0);
match client.try_write(&buffer) {
Ok(_) => { println!("Write successful");}
Err(_) => { println!("Write error");}
}
}
Err(TryRecvError::Empty) => (),
_ => break
}
}
} );
// user input loop here
// takes user message and transmits it on the channel
}
Sending to the server works fine, and the server appears to be successfully writing as indicated by its output:
127.0.0.1:55346: "test message"
writing: "test message"
The issue is the client never reads back the message from the server, instead getting WouldBlock errors every time it hits the match client.try_read(&mut buffer) block.
If I stop the server while keeping the client running, the client is suddenly flooded with successful reads of empty messages:
Received from server: []
Received from server: []
Received from server: []
Received from server: []
Received from server: []
Received from server: []
Received from server: []
Received from server: []
...
Can anyone tell me what's going on?
Here's what happens in your server:
Wait for a client to connect.
When the client is connected, spawn a background task to receive from the client.
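Wait on the channel: rx.recv().await suspends the main task until a client sends something; it then tries to write that message to the clients.
Loop → wait for another client to connect.
While waiting for another client, the background task keeps receiving messages from the first client and sending them to the channel, but the main task is blocked in accept() and never reads from the channel again.
In addition, even the first echo never reaches the client: inside the synchronous filter_map closure, x.write(...) only creates a future. Tokio's AsyncWriteExt::write does nothing until that future is awaited (the compiler warns that the future is unused), so the server prints "writing: ..." but never actually sends the bytes.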
Easiest way to get it to work is to get rid of the channel in the server and simply echo the message from the spawned task.
Another solution is to spawn an independent task to process the channel and write to the clients.
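As for what happens when you kill the server: once the connection is closed, reading from the socket does not return an error; it returns Ok(0), which signals end of stream. The buffer is left untouched (all zeros), so the client prints an empty message on every iteration. Treat Ok(0) as "connection closed" and stop reading.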
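A std-only sketch combining the first suggestion (echo from the per-connection handler, no shared channel) with treating Ok(0) as end of stream. It uses blocking std::net I/O rather than tokio; in tokio the loop has the same shape, using AsyncReadExt::read and AsyncWriteExt::write_all with .await. echo_connection is a hypothetical name.

```rust
use std::io::{self, Read, Write};
use std::net::TcpStream;

/// Echoes everything received on `stream` back to the same peer, directly
/// from the per-connection handler. Returns the total number of bytes
/// echoed once the peer closes its side.
fn echo_connection(mut stream: TcpStream) -> io::Result<usize> {
    let mut buffer = [0u8; 1024];
    let mut total = 0;
    loop {
        let n = stream.read(&mut buffer)?;
        if n == 0 {
            // Ok(0) means end of stream: the peer has closed the connection.
            // Looping on here is what produces the flood of empty reads.
            return Ok(total);
        }
        // Write back only the bytes that were actually read.
        stream.write_all(&buffer[..n])?;
        total += n;
    }
}
```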
I'm sending large objects over a network and noticed that using a single network connection is significantly slower than using multiple.
Server code:
use async_std::{
io::{BufWriter, Write},
net::TcpListener,
prelude::*,
task,
};
use bench_utils::{end_timer, start_timer};
use futures::stream::{FuturesOrdered, StreamExt};
async fn send(buf: &[u8], writer: &mut (impl Write + Unpin)) {
// Send the message length
writer.write_all(&(buf.len() as u64).to_le_bytes()).await.unwrap();
// Send the rest of the message
writer.write_all(&buf).await.unwrap();
writer.flush().await.unwrap();
}
fn main() {
task::block_on(async move {
let listener = TcpListener::bind("0.0.0.0:8000").await.unwrap();
let mut incoming = listener.incoming();
let mut writers = Vec::with_capacity(16);
for _ in 0..16 {
let stream = incoming.next().await.unwrap().unwrap();
writers.push(BufWriter::new(stream))
};
let buf = vec![0u8; 1 << 30];
let send_time = start_timer!(|| "Sending buffer across 1 connection");
send(&buf, &mut writers[0]).await;
end_timer!(send_time);
let send_time = start_timer!(|| "Sending buffer across 16 connections");
writers
.iter_mut()
.zip(buf.chunks(buf.len() / 16))
.map(|(w, chunk)| {
send(chunk, w)
})
.collect::<FuturesOrdered<_>>()
.collect::<Vec<_>>()
.await;
end_timer!(send_time);
});
}
Client code:
use async_std::{
io::{BufReader, Read},
net::TcpStream,
prelude::*,
task,
};
use bench_utils::{end_timer, start_timer};
use futures::stream::{FuturesOrdered, StreamExt};
async fn recv(reader: &mut (impl Read + Unpin)) {
// Read the message length
let mut len_buf = [0u8; 8];
reader.read_exact(&mut len_buf).await.unwrap();
let len: u64 = u64::from_le_bytes(len_buf);
// Read the rest of the message
let mut buf = vec![0u8; usize::try_from(len).unwrap()];
reader.read_exact(&mut buf[..]).await.unwrap();
}
fn main() {
let host = &std::env::args().collect::<Vec<_>>()[1];
task::block_on(async move {
let mut readers = Vec::with_capacity(16);
for _ in 0..16 {
let stream = TcpStream::connect(host).await.unwrap();
readers.push(BufReader::new(stream));
}
let read_time = start_timer!(|| "Reading buffer from 1 connection");
recv(&mut readers[0]).await;
end_timer!(read_time);
let read_time = start_timer!(|| "Reading buffer from 16 connections");
readers
.iter_mut()
.map(|r| recv(r))
.collect::<FuturesOrdered<_>>()
.collect::<Vec<_>>()
.await;
end_timer!(read_time);
});
}
Server result:
Start: Sending buffer across 1 connection
End: Sending buffer across 1 connection....................................55.134s
Start: Sending buffer across 16 connections
End: Sending buffer across 16 connections..................................4.19s
Client result:
Start: Reading buffer from 1 connection
End: Reading buffer from 1 connection......................................55.396s
Start: Reading buffer from 16 connections
End: Reading buffer from 16 connections....................................3.914s
I am assuming that this difference is due to the sending side having to wait for ACKs when the TCP buffer fills up (both machines have TCP window scaling enabled). Rust's standard library doesn't seem to provide an API to modify the size of these buffers.
Is there any way to achieve similar throughput on a single connection? It seems annoying to have to pass around multiple connections, since all of this is going through a single network interface anyway.
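A quick worked calculation from the timings above (a sketch; the 20 ms round-trip time is an assumed value, not something measured in the question). 1 << 30 bytes is 1 GiB, so the single connection ran at about 18.6 MiB/s, the 16 connections together at about 244 MiB/s, and each of the 16 connections (64 MiB each) at about 15.3 MiB/s. That each connection gets a similar rate is consistent with a per-connection limit rather than a link limit: a connection can have at most one window of unacknowledged data in flight per round trip, so throughput <= window / RTT, and sustaining a rate needs a window of at least rate × RTT (the bandwidth-delay product).

```rust
/// Payload size used in the benchmark: `vec![0u8; 1 << 30]`, i.e. 1 GiB.
const PAYLOAD_BYTES: f64 = (1u64 << 30) as f64;

/// Throughput in MiB/s for `bytes` transferred in `secs` seconds.
fn mib_per_sec(bytes: f64, secs: f64) -> f64 {
    bytes / (1024.0 * 1024.0) / secs
}

/// Bandwidth-delay product in bytes: how much unacknowledged data must be
/// in flight to sustain `bytes_per_sec` over a path with round-trip time
/// `rtt_secs`. A connection whose effective window is smaller than this
/// cannot reach that rate, because throughput <= window / RTT.
fn bandwidth_delay_product(bytes_per_sec: f64, rtt_secs: f64) -> f64 {
    bytes_per_sec * rtt_secs
}
```

With the assumed 20 ms RTT, the single-connection rate corresponds to roughly 380 KiB in flight. If the real RTT (e.g. from ping) times the desired rate is much larger than the socket buffer sizes the OS grants (SO_SNDBUF/SO_RCVBUF, which std does not expose), that would fit the hypothesis in the question.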
From the Rust std::net library documentation:
let listener = TcpListener::bind(("127.0.0.1", port)).unwrap();
info!("Opened socket on localhost port {}", port);
// accept connections and process them serially
for stream in listener.incoming() {
break;
}
info!("closed socket");
How does one make the listener stop listening? It says in the API that when the listener is dropped, it stops. But how do we drop it if incoming() is a blocking call? Preferably without external crates like tokio/mio.
You'll want to put the TcpListener into non-blocking mode using the set_nonblocking() method, like so:
use std::io;
use std::net::{TcpListener, TcpStream};

fn handle_connection(_stream: TcpStream) {
    // do something with the TcpStream
}

fn main() {
    let listener = TcpListener::bind("127.0.0.1:7878").unwrap();
    listener.set_nonblocking(true).expect("Cannot set non-blocking");
    for stream in listener.incoming() {
        match stream {
            Ok(s) => {
                // do something with the TcpStream
                handle_connection(s);
            }
            Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
                // No connection pending right now. Decide whether to exit
                // the loop (`break;`) or try to accept again (`continue;`).
                continue;
            }
            Err(e) => panic!("encountered IO error: {}", e),
        }
    }
}
Instead of waiting for a connection, each iteration of incoming() immediately returns a Result. If it is Ok, a connection was accepted and you can process it. If it is Err with kind WouldBlock, this isn't actually an error; there just wasn't a connection pending at the exact moment the socket was checked.
Note that in the WouldBlock case you will probably want a short std::thread::sleep() before continuing; otherwise the loop will poll the socket as fast as it can, resulting in high CPU usage.
Code example adapted from here
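A runnable sketch of the non-blocking approach with an explicit stop signal, assuming the stop request comes from another thread via an AtomicBool (accept_until_stopped and the stop flag are hypothetical names). It calls accept() directly, which is what incoming() does on each iteration, sleeps briefly on WouldBlock, and switches each accepted stream back to blocking mode, since whether accepted sockets inherit the listener's non-blocking mode can differ between platforms.

```rust
use std::io;
use std::net::{TcpListener, TcpStream};
use std::sync::atomic::{AtomicBool, Ordering};
use std::thread;
use std::time::Duration;

/// Accepts connections until `stop` becomes true, handing each one to
/// `handle`. Returns the number of connections accepted.
fn accept_until_stopped<F: FnMut(TcpStream)>(
    listener: &TcpListener,
    stop: &AtomicBool,
    mut handle: F,
) -> io::Result<usize> {
    listener.set_nonblocking(true)?;
    let mut accepted = 0;
    while !stop.load(Ordering::Relaxed) {
        match listener.accept() {
            Ok((stream, _addr)) => {
                // Make the accepted stream blocking explicitly.
                stream.set_nonblocking(false)?;
                handle(stream);
                accepted += 1;
            }
            Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
                // Nothing pending: sleep briefly instead of busy-polling.
                thread::sleep(Duration::from_millis(50));
            }
            Err(e) => return Err(e),
        }
    }
    Ok(accepted)
}
```

The trade-off is latency: a stop request or a new connection may wait up to one sleep interval (50 ms here) before it is noticed.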
The standard library doesn't provide an API for this, but there are a few strategies you can use to work around it:
Shut down reads on the socket
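You can use platform-specific APIs to shut down reads on the socket, which will cause the incoming iterator to return an error. You can then break out of handling connections when the error is received. For example, on Linux, using the libc crate (calling shutdown() on a listening socket is not portable; other Unix systems may reject it instead of waking the blocked accept()):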
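use std::net::TcpListener;
use std::os::unix::io::AsRawFd;
use std::thread;

fn main() -> std::io::Result<()> {
    let listener = TcpListener::bind("localhost:0")?;
    let fd = listener.as_raw_fd();
    let handle = thread::spawn(move || {
        for connection in listener.incoming() {
            match connection {
                Ok(connection) => { /* handle connection */ }
                // accept() fails once reads are shut down, so stop listening
                Err(_) => break,
            }
        }
    });
    // libc::shutdown is an FFI call and must be wrapped in `unsafe`
    unsafe {
        libc::shutdown(fd, libc::SHUT_RD);
    }
    handle.join().unwrap();
    Ok(())
}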
Force the listener to wake up
Another (cross-platform) trick is to set a variable indicating that you want to stop listening, and then connect to the socket yourself to force the listening thread to wake up. When the listening thread wakes up, it checks the "stop listening" variable, and then exits cleanly if it's set.
use std::io;
use std::net::{TcpListener, TcpStream};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::thread;

fn main() -> io::Result<()> {
    let listener = TcpListener::bind("localhost:0")?;
    let local_addr = listener.local_addr()?;
    let shutdown = Arc::new(AtomicBool::new(false));
    let server_shutdown = shutdown.clone();
    let handle = thread::spawn(move || {
        for connection in listener.incoming() {
            // check the flag after every wake-up, including our own connect below
            if server_shutdown.load(Ordering::Relaxed) {
                return;
            }
            match connection {
                Ok(connection) => { /* handle connection */ }
                Err(_) => break,
            }
        }
    });
    shutdown.store(true, Ordering::Relaxed);
    // connect to ourselves so the blocked accept() returns and sees the flag
    let _ = TcpStream::connect(local_addr);
    handle.join().unwrap();
    Ok(())
}
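On Linux, you can poll the listening socket together with an eventfd, which is used for signaling.
I wrote a helper for this using the nix crate. (The error matching in next() follows older nix releases, where errors were wrapped in nix::Error::Sys; newer releases changed these APIs.)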
let shutdown = EventFd::new();
let listener = TcpListener::bind("0.0.0.0:12345")?;
let incoming = CancellableIncoming::new(&listener, &shutdown);
for stream in incoming {
// Your logic
}
// While in other thread
shutdown.add(1); // Light the shutdown signal, now your incoming loop exits gracefully.
use nix;
use nix::poll::{poll, PollFd, PollFlags};
use nix::sys::eventfd::{eventfd, EfdFlags};
use nix::unistd::{close, write};
use std;
use std::net::{TcpListener, TcpStream};
use std::os::unix::io::{AsRawFd, RawFd};
pub struct EventFd {
fd: RawFd,
}
impl EventFd {
pub fn new() -> Self {
EventFd {
fd: eventfd(0, EfdFlags::empty()).unwrap(),
}
}
pub fn add(&self, v: i64) -> nix::Result<usize> {
let b = v.to_le_bytes();
write(self.fd, &b)
}
}
impl AsRawFd for EventFd {
fn as_raw_fd(&self) -> RawFd {
self.fd
}
}
impl Drop for EventFd {
fn drop(&mut self) {
let _ = close(self.fd);
}
}
// -----
//
pub struct CancellableIncoming<'a> {
listener: &'a TcpListener,
eventfd: &'a EventFd,
}
impl<'a> CancellableIncoming<'a> {
pub fn new(listener: &'a TcpListener, eventfd: &'a EventFd) -> Self {
Self { listener, eventfd }
}
}
impl<'a> Iterator for CancellableIncoming<'a> {
type Item = std::io::Result<TcpStream>;
fn next(&mut self) -> Option<std::io::Result<TcpStream>> {
use nix::errno::Errno;
let fd = self.listener.as_raw_fd();
let evfd = self.eventfd.as_raw_fd();
let mut poll_fds = vec![
PollFd::new(fd, PollFlags::POLLIN),
PollFd::new(evfd, PollFlags::POLLIN),
];
loop {
match poll(&mut poll_fds, -1) {
Ok(_) => break,
Err(nix::Error::Sys(Errno::EINTR)) => continue,
_ => panic!("Error polling"),
}
}
if poll_fds[0].revents().unwrap() == PollFlags::POLLIN {
Some(self.listener.accept().map(|p| p.0))
} else if poll_fds[1].revents().unwrap() == PollFlags::POLLIN {
None
} else {
panic!("Can't be!");
}
}
}