Rust Async doesn't execute in parallel for sockets

I'm trying to send and receive simultaneously to a multicast IP with Rust.
use futures::executor::block_on;
use async_std::task;
use std::{net::{UdpSocket, Ipv4Addr}, time::{Duration, Instant}};

fn main() {
    let future = async_main();
    block_on(future);
}

async fn async_main() {
    let mut socket = UdpSocket::bind("0.0.0.0:8888").unwrap();
    let multi_addr = Ipv4Addr::new(234, 2, 2, 2);
    let inter = Ipv4Addr::new(0, 0, 0, 0);
    socket.join_multicast_v4(&multi_addr, &inter);
    let async_one = first(&socket);
    let async_two = second(&socket);
    futures::join!(async_one, async_two);
}

async fn first(socket: &std::net::UdpSocket) {
    let mut buf = [0u8; 65535];
    let now = Instant::now();
    loop {
        if now.elapsed().as_secs() > 10 { break; }
        let (amt, src) = socket.recv_from(&mut buf).unwrap();
        println!("received {} bytes from {:?}", amt, src);
    }
}

async fn second(socket: &std::net::UdpSocket) {
    let now = Instant::now();
    loop {
        if now.elapsed().as_secs() > 10 { break; }
        socket.send_to(String::from("h").as_bytes(), "234.2.2.2:8888").unwrap();
    }
}
The issue with this is that it first runs the receive function and then runs the send function; it never sends and receives simultaneously. With Golang I can do this with goroutines, but I'm finding it quite difficult in Rust.

I'm not very experienced with async in Rust, but your first() and second() functions don't appear to have any asynchronous calls in them -- in other words, there are no calls that use .await. My understanding is that if nothing is awaited, the functions run synchronously to completion, and I believe you get a compiler warning about it as well.
It doesn't look like std::net::UdpSocket provides any async methods that can be awaited, so you need to use async_std::net::UdpSocket instead.
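For illustration, here is a minimal sketch of what the two loops might look like on top of async_std's socket (assuming async_std 1.x, where recv_from and send_to are async and must be .awaited, which gives the executor a chance to switch between the two tasks):

use async_std::net::UdpSocket;
use std::time::Instant;

async fn first(socket: &UdpSocket) {
    let mut buf = [0u8; 65535];
    let now = Instant::now();
    while now.elapsed().as_secs() <= 10 {
        // awaiting here lets `second` make progress while we wait for a packet
        let (amt, src) = socket.recv_from(&mut buf).await.unwrap();
        println!("received {} bytes from {:?}", amt, src);
    }
}

async fn second(socket: &UdpSocket) {
    let now = Instant::now();
    while now.elapsed().as_secs() <= 10 {
        socket.send_to(b"h", "234.2.2.2:8888").await.unwrap();
    }
}

Binding and joining the multicast group would move over to async_std::net::UdpSocket as well, and futures::join! on these two futures should then actually interleave sending and receiving.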

Related

How to await `JoinHandle`s and update `JoinHandle`s at the same time?

Is it possible to both read a stream of Futures from a set of JoinHandle<()> tasks and update that set of tasks with new tasks at the same time?
I currently have a Service that runs some long tasks. The only thing is, I would actually like to (if possible) add new tasks at the same time -- via a flag sent on some type of Receiver channel (not shown below to keep things simple).
Given that in Service::run handles becomes owned by that function, I would lean towards "no", this is not possible. Is this true? If this isn't possible given my setup, is there some way I could tweak the code below to make this possible?
I read in this answer that wrapping the HashMap in an Option allows me to use .take() in Service::run, since the value needs to be owned in order to call .into_values(). However, the problem with this is that .take() consumes the value in the Mutex, leaving None in its wake.
Here is my minimal reproducible example (I did not compile this, but it should give the idea):
use std::collections::HashMap;
use tokio::{task::JoinHandle, time::{sleep, Duration}};
use async_std::sync::{Arc, Mutex};
use futures::{
    stream::{FuturesUnordered, StreamExt},
    Future,
};

type Handles = Arc<Mutex<Option<HashMap<String, JoinHandle<()>>>>>;

fn a_task() -> impl Future<Output = ()> {
    async move {
        sleep(Duration::from_secs(3)).await;
    }
}

fn the_update_task(handles: Handles) -> impl Future<Output = ()> {
    async move {
        // would like to update `handles` here as I get new data from a channel
        // calling .take() in Service::run nukes my handles here :(
    }
}

struct Service {
    handles: Handles,
}

impl Service {
    fn new() -> Self {
        let handles = Arc::new(Mutex::new(Some(HashMap::default())));
        let _handle = tokio::spawn(the_update_task(handles.clone()));
        Self { handles }
    }

    async fn add_a_task(&mut self, id: String) {
        let handle = tokio::spawn(a_task());
        self.handles.lock().await.as_mut().unwrap().insert(id, handle);
    }

    async fn run(self) {
        let Service { handles, .. } = self;
        let mut futs = FuturesUnordered::from_iter(
            handles.lock().await.take().unwrap().into_values()
        );
        while let Some(fut) = futs.next().await {
            info!("I completed a task! {fut:?}");
        }
    }
}

#[tokio::main]
async fn main() {
    let mut srvc = Service::new();
    srvc.add_a_task("1".to_string()).await;
    srvc.add_a_task("2".to_string()).await;
    let handle = tokio::spawn(srvc.run());
    handle.await;
}
I have tried:
Using Arc<Mutex<HashMap>>
Using Arc<Mutex<Option<HashMap>>>
I always seem to arrive at the same conclusion:
I cannot both own handles in Service::run and update handles (even through a copy/reference) from another part of the code.
Just answering my own question here with the help of #user1937198's comment.
The solution was to update a reference to the FuturesUnordered directly with new tasks, as opposed to being concerned with handles. This simplifies things quite a bit.
use std::collections::HashMap;
use tokio::{task::JoinHandle, time::{sleep, Duration}};
use async_std::sync::{Arc, Mutex};
use futures::{
    stream::{FuturesUnordered, StreamExt},
    Future,
};

fn a_task() -> impl Future<Output = ()> {
    async move {
        sleep(Duration::from_secs(3)).await;
    }
}

fn the_update_task(futs: Arc<Mutex<FuturesUnordered<JoinHandle<()>>>>) -> impl Future<Output = ()> {
    async move {
        // Just push another task
        let fut = tokio::spawn(a_task());
        futs.lock().await.push(fut);
    }
}

struct Service {
    handles: HashMap<String, JoinHandle<()>>,
}

impl Service {
    fn new() -> Self {
        let handles = HashMap::default();
        Self { handles }
    }

    async fn add_a_task(&mut self, id: String) {
        let handle = tokio::spawn(a_task());
        self.handles.insert(id, handle);
    }

    async fn run(self) {
        let Service { handles, .. } = self;
        let futs = Arc::new(Mutex::new(FuturesUnordered::from_iter(handles.into_values())));
        tokio::spawn(the_update_task(futs.clone())).await.unwrap();
        while let Some(fut) = futs.lock().await.next().await {
            info!("I completed a task! {fut:?}");
        }
    }
}

#[tokio::main]
async fn main() {
    let mut srvc = Service::new();
    srvc.add_a_task("1".to_string()).await;
    srvc.add_a_task("2".to_string()).await;
    let handle = tokio::spawn(srvc.run());
    handle.await;
}

Parallel work stealing in arbitrary order in Rust

I'm trying to write a parallel data loader for deep learning in Rust. The task is to write an iterator that under the hood does the following:
1. Reads files from disk and applies some compute-heavy preprocessing to them; the result is generally a numeric array (or multiple)
2. Groups the results of the previous step into batches of size B and "collates" them - this generally means just concatenating the arrays - moderately compute heavy
3. Yields the results from step 2.
Step 1 can be both IO and compute bound, depending on network latency, size of files and complexity of preprocessing. It has to be run in parallel by many workers. Step 2 should be off the main thread but likely doesn't need a pool of workers. Step 3 happens on main thread (exposed to Python).
The reason I write it in Rust is that Python offers two options: a pure Python implementation shipped with PyTorch, based on multiprocessing, which is somewhat slow but very flexible (arbitrary user-defined data preprocessing and batching), and a C++ implementation shipped with Tensorflow, which is assembled by the user from a set of predefined primitives. The latter is substantially faster but too restrictive for the kinds of data processing I wish to do. I expect that Rust will give me the speed of Tensorflow with the flexibility of arbitrary code, as in PyTorch.
My question is purely about the way to implement parallelism. The ideal setup is to have N workers for step 1) -> channel -> worker for step 2) -> channel -> step 3. Because the iterator object may be dropped at any time, there is a strict requirement to be able to terminate the whole scheme after Drop. On the other hand, there is the flexibility of loading the files in an arbitrary order: for example if the batch size B == 16 and max_n_threads == 32, it is perfectly fine to start 32 workers and yield the first batch containing the 16 examples which happen to return first. This can be exploited for speed.
My naive implementation creates the DataLoader in 3 steps:
1. Create an n_working: Arc<AtomicUsize> to control the number of active worker threads and a should_shutdown: Arc<AtomicBool> to signal shutdown (when Drop is called)
2. Create a thread responsible for maintaining the pool. It spins on n_working < max_n_threads and keeps spawning worker threads, which terminate on should_shutdown and otherwise fetch a single example, send it down the worker->batcher channel and decrement n_working
3. Create a batching thread which polls the worker->batcher channel and, upon receiving B objects, concatenates them into a batch and sends it down the batcher->yielder channel
#[pyclass]
struct DataLoader {
    collate_worker: Option<thread::JoinHandle<()>>,
    example_worker: Option<thread::JoinHandle<()>>,
    should_shut_down: Arc<AtomicBool>,
    receiver: Receiver<Batch>,
    length: usize,
}

impl DataLoader {
    fn new(
        dataset: Dataset,
        batch_size: usize,
        capacity: usize,
    ) -> Self {
        let n_batches = dataset.len() / batch_size;
        let max_n_threads = capacity * batch_size;
        let (example_sender, collate_receiver) = bounded((batch_size - 1) * capacity);
        let should_shut_down = Arc::new(AtomicBool::new(false));

        let shutdown_flag = should_shut_down.clone();
        let example_worker = thread::spawn(move || {
            rayon::scope_fifo(|s| {
                let dataset = &dataset;
                let n_working = Arc::new(AtomicUsize::new(0));
                let mut current_index = 0;
                while current_index < n_batches * batch_size {
                    if n_working.load(Ordering::Relaxed) == max_n_threads {
                        continue;
                    }
                    if shutdown_flag.load(Ordering::Relaxed) {
                        break;
                    }
                    let index = current_index.clone();
                    let sender = example_sender.clone();
                    let counter = n_working.clone();
                    let shutdown_flag = shutdown_flag.clone();
                    s.spawn_fifo(move |_s| {
                        let example = dataset.get_example(index);
                        if !shutdown_flag.load(Ordering::Relaxed) {
                            _ = sender.send(example);
                        } // if we should shut down, skip sending
                        counter.fetch_sub(1, Ordering::Relaxed);
                    });
                    current_index += 1;
                    n_working.fetch_add(1, Ordering::Relaxed);
                }
            });
        });

        let (batch_sender, final_receiver) = bounded(capacity);
        let shutdown_flag = should_shut_down.clone();
        let collate_worker = thread::spawn(move || {
            'outer: loop {
                let mut batch = vec![];
                for _ in 0..batch_size {
                    if let Ok(example) = collate_receiver.recv() {
                        batch.push(example);
                    } else {
                        break 'outer;
                    }
                }
                let collated = collate(batch);
                if shutdown_flag.load(Ordering::Relaxed) {
                    break; // skip sending
                }
                _ = batch_sender.send(collated);
            }
        });

        Self {
            collate_worker: Some(collate_worker),
            example_worker: Some(example_worker),
            should_shut_down,
            receiver: final_receiver,
            length: n_batches,
        }
    }
}

#[pymethods]
impl DataLoader {
    fn __iter__(slf: PyRef<Self>) -> PyRef<Self> { slf }

    fn __next__(&mut self) -> Option<Batch> {
        self.receiver.recv().ok()
    }

    fn __len__(&self) -> usize {
        self.length
    }
}

impl Drop for DataLoader {
    fn drop(&mut self) {
        self.should_shut_down.store(true, Ordering::Relaxed);
        if self.collate_worker.take().unwrap().join().is_err() {
            println!("Panic in collate worker");
        }
        if self.example_worker.take().unwrap().join().is_err() {
            println!("Panic in example_worker");
        }
        println!("dropped the dataloader");
    }
}
This implementation works and roughly matches the performance of PyTorch but provides no significant speedup. I don't know where to look for improvements, but I imagine it would help to have the thing load-balance automatically in a work-stealing way and to flexibly spawn workers depending on the proportion of IO and compute time. I am also expecting performance issues due to the spinning pool manager and likely corner cases in my handling of Drop.
My question is how best to approach the problem. I am generally unsure whether this should be tackled with parallel crates like rayon, async crates like tokio, or a mix of both. I also have the hunch that my implementation could be much simpler with the correct use of their combinators/higher-order APIs. I tried with rayon, but I couldn't get a solution that avoids wastefully enforcing the original sequential return order while still respecting the Drop requirement.
Okay I think I've figured out a solution for you that uses rayon parallel iterators.
The trick is to use Results in the rayon iterators, and return Err if the cancellation flag is set.
I first created a utility type to create a cancellable thread in which you can execute rayon iterators. You use it by passing in the thread closure which takes the atomic cancellation token as a parameter. Then you have to check if the cancellation token is true, and if so, exit early.
use std::sync::Arc;
use std::sync::atomic::{Ordering, AtomicBool};
use std::thread::JoinHandle;

fn collate(batch: &[Computed]) -> Batch {
    batch.iter().map(|&x| i128::from(x)).sum()
}

#[derive(Debug)]
struct Cancelled;

struct CancellableThread<Output: Send + 'static> {
    cancel_token: Arc<AtomicBool>,
    thread: Option<JoinHandle<Result<Output, Cancelled>>>,
}

impl<Output: Send + 'static> CancellableThread<Output> {
    fn new<F: FnOnce(Arc<AtomicBool>) -> Result<Output, Cancelled> + Send + 'static>(init: F) -> Self {
        let cancel_token = Arc::new(AtomicBool::new(false));
        let thread_cancel_token = Arc::clone(&cancel_token);
        CancellableThread {
            thread: Some(std::thread::spawn(move || init(thread_cancel_token))),
            cancel_token,
        }
    }

    fn output(mut self) -> Output {
        self.thread.take().unwrap().join().unwrap().unwrap()
    }
}

impl<Output: Send + 'static> Drop for CancellableThread<Output> {
    fn drop(&mut self) {
        self.cancel_token.store(true, Ordering::Relaxed);
        if let Some(thread) = self.thread.take() {
            let _ = thread.join().unwrap();
        }
    }
}
I found it useful to create a closure that returns a Result<(), Cancelled> so I could use the try operator (?) to exit early.
CancellableThread::new(move |cancel_token| {
    let cancelled = || if cancel_token.load(Ordering::Relaxed) {
        Err(Cancelled)
    } else {
        Ok(())
    };

    loop {
        // was the thread dropped?
        // if so, stop what we're doing
        cancelled()?;

        // do stuff and
        // eventually return a result
    }
});
I then used that CancellableThread abstraction in the DataLoader. No need to create a special Drop impl for it, because by default, it will call drop on each field anyways, which will handle the cancellation.
type Data = Vec<u8>;
type Dataset = Vec<Data>;
type Computed = u64;
type Batch = i128;

use rayon::prelude::*;
use crossbeam::channel::{unbounded, Receiver};

struct DataLoader {
    example_worker: CancellableThread<()>,
    collate_worker: CancellableThread<()>,
    receiver: Receiver<Batch>,
    length: usize,
}
I used unbounded channels, as it was one less thing to bother about. It shouldn't be hard to switch to bounded ones instead.
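For instance, the two unbounded() calls in new below could be swapped for something like this (just a sketch; capacity here is a hypothetical constructor parameter along the lines of the one in the question's version):

// Bounded sketch: cap both queues so the workers apply back-pressure
// instead of running arbitrarily far ahead of the consumer.
let (example_sender, collate_receiver) = crossbeam::channel::bounded((batch_size - 1) * capacity);
let (batch_sender, final_receiver) = crossbeam::channel::bounded(capacity);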
impl DataLoader {
    fn new(dataset: Dataset, batch_size: usize) -> Self {
        let (example_sender, collate_receiver) = unbounded();
        let (batch_sender, final_receiver) = unbounded();
I'm not sure if you can always guarantee that the number of items in your dataset will be a multiple of the batch_size, so I decided to handle that explicitly.
        let length = if dataset.len() % batch_size == 0 {
            dataset.len() / batch_size
        } else {
            dataset.len() / batch_size + 1
        };
I created the collating worker first, though that may not be necessary. As you can see, I had to duplicate a little bit to handle partial batches.
        let collate_worker = CancellableThread::new(move |cancel_token| {
            let cancelled = || if cancel_token.load(Ordering::Relaxed) {
                Err(Cancelled)
            } else {
                Ok(())
            };

            'outer: loop {
                let mut batch = Vec::with_capacity(batch_size);
                for _ in 0..batch_size {
                    cancelled()?;
                    if let Ok(data) = collate_receiver.recv() {
                        batch.push(data);
                    } else {
                        if !batch.is_empty() {
                            // handle the last batch, if there
                            // weren't enough items to fill it
                            let collated = collate(&batch);
                            cancelled()?;
                            batch_sender.send(collated).unwrap();
                        }
                        break 'outer;
                    }
                }
                let collated = collate(&batch);
                cancelled()?;
                batch_sender.send(collated).unwrap();
            }

            Ok(())
        });
The example worker is where things are really made much simpler, because we can just use rayon parallel iterators. As you can see, we check for cancellation before each heavy computation.
        let example_worker = CancellableThread::new(move |cancel_token| {
            let cancelled = || if cancel_token.load(Ordering::Relaxed) {
                Err(Cancelled)
            } else {
                Ok(())
            };

            let heavy_compute = |data: Data| -> Result<Computed, Cancelled> {
                cancelled()?;
                Ok(data.iter().map(|&x| u64::from(x)).product())
            };

            dataset
                .into_par_iter()
                .map(heavy_compute)
                .try_for_each(|computed| {
                    example_sender.send(computed?).unwrap();
                    Ok(())
                })
        });
Then we just construct the DataLoader. You can see the Python impl is identical:
        DataLoader {
            example_worker,
            collate_worker,
            receiver: final_receiver,
            length,
        }
    }
}
// #[pymethods]
impl DataLoader {
    fn __iter__(this: Self /* PyRef<Self> */) -> Self /* PyRef<Self> */ { this }

    fn __next__(&mut self) -> Option<Batch> {
        self.receiver.recv().ok()
    }

    fn __len__(&self) -> usize {
        self.length
    }
}
playground

Async recursive function that takes a mutex

How do you create an async recursive function that takes a mutex? Rust claims that this code holds a mutex guard across an await point. However, the guard is dropped before the .await.
use std::sync::{Arc, Mutex};
use async_recursion::async_recursion;

#[async_recursion]
async fn f(mutex: &Arc<Mutex<u128>>) {
    let mut unwrapped = mutex.lock().unwrap();
    *unwrapped += 1;
    let value = *unwrapped;
    drop(unwrapped);
    tokio::time::sleep(tokio::time::Duration::from_millis(1000)).await;
    if value < 100 {
        f(mutex).await;
    }
}
Error
future cannot be sent between threads safely
within `impl futures::Future<Output = ()>`, the trait `std::marker::Send` is not implemented for `std::sync::MutexGuard<'_, u128>`
required for the cast to the object type `dyn futures::Future<Output = ()> + std::marker::Send`
lib.rs(251, 65): future is not `Send` as this value is used across an await
In this case, you can restructure the code to make it so unwrapped can't be used across an await:
let value = {
    let mut unwrapped = mutex.lock().unwrap();
    *unwrapped += 1;
    *unwrapped
};
tokio::time::sleep(tokio::time::Duration::from_millis(1000)).await;
if value < 100 {
    f(mutex).await;
}
If you weren't able to do this, then you'd need to make it so you don't return a Future that implements Send. The async_recursion docs specify an option you can pass to the macro to disable the Send bound it adds:
#[async_recursion(?Send)]
async fn f(mutex: &Arc<Mutex<u128>>) {
...
(playground)
You wouldn't be able to send such a Future across threads though.
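If you do go the ?Send route, such a future can still be driven on a single thread. Here is a minimal sketch (assuming tokio 1.x and the ?Send version of f above; the setup itself is hypothetical):

use std::sync::{Arc, Mutex};
use tokio::task::LocalSet;

#[tokio::main]
async fn main() {
    let mutex = Arc::new(Mutex::new(0u128));
    let local = LocalSet::new();
    // spawn_local, unlike tokio::spawn, accepts futures that are not Send,
    // but the task then stays on the current thread
    local.spawn_local(async move {
        f(&mutex).await;
    });
    // drive the local tasks to completion on this thread
    local.await;
}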

BufReader::lines() over TcpStream stops iteration

I have some reasonably simple code using a TcpStream with an SslStream around it, reading the socket line by line with a BufReader. Sometimes the iterator just stops returning any data with Ok(0):
let mut stream = TcpStream::connect((self.config.host.as_ref(), self.config.port)).unwrap();
if self.config.ssl {
    let context = ssl::SslContext::new(ssl::SslMethod::Tlsv1_2).unwrap();
    let mut stream = ssl::SslStream::connect(&context, stream).unwrap();
    self.stream = Some(ssl::MaybeSslStream::Ssl(stream));
} else {
    self.stream = Some(ssl::MaybeSslStream::Normal(stream));
}
...
let read_stream = clone_stream(&self.stream);
let line_reader = BufReader::new(read_stream);
for line in line_reader.lines() {
    match line {
        Ok(line) => {
            ...
        }
        Err(e) => panic!("line read failed: {}", e),
    }
}
println!("lines out, {:?}", self.stream);
The loop just stops randomly as far as I can see, and there's no reason to believe the socket was closed server-side. Calling self.stream.as_mut().unwrap().read_to_end(&mut buf) after the loop has ended returns Ok(0).
Any advice on how this is expected to be handled? I don't get any Err, so I assume the socket is still alive, but then I can't read anything from it. What is the current state of the socket, and how should I proceed?
PS: I'm providing the implementation of clone_stream for reference, as advised by a commenter.
fn clone_stream(stream: &Option<ssl::MaybeSslStream<TcpStream>>) -> ssl::MaybeSslStream<TcpStream> {
    if let &Some(ref s) = stream {
        match s {
            &ssl::MaybeSslStream::Ssl(ref s) => ssl::MaybeSslStream::Ssl(s.try_clone().unwrap()),
            &ssl::MaybeSslStream::Normal(ref s) => ssl::MaybeSslStream::Normal(s.try_clone().unwrap()),
        }
    } else {
        panic!();
    }
}
Surprisingly, this was a "default timeout" on the client, I guess (the socket was moving into the CLOSE_WAIT state).
I fixed it by first adding:
stream.set_read_timeout(Some(Duration::new(60*5, 0)));
stream.set_write_timeout(Some(Duration::new(60*5, 0)));
That made the iterator fail with ErrorKind::WouldBlock on timeout, at which point I added code to send a ping packet over the wire; the next iteration then worked exactly as expected.
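For reference, here is a minimal sketch of that timeout-plus-ping pattern on a plain TcpStream (my own reconstruction; the SSL wrapper is omitted, and the PING payload stands in for whatever keep-alive the actual protocol uses). read_line is used instead of the lines() iterator so the loop can simply continue after a timeout error:

use std::io::{BufRead, BufReader, ErrorKind, Write};
use std::net::TcpStream;
use std::time::Duration;

fn read_loop(stream: TcpStream) -> std::io::Result<()> {
    stream.set_read_timeout(Some(Duration::new(60 * 5, 0)))?;
    stream.set_write_timeout(Some(Duration::new(60 * 5, 0)))?;
    let mut writer = stream.try_clone()?;
    let mut reader = BufReader::new(stream);
    let mut line = String::new();
    loop {
        line.clear();
        match reader.read_line(&mut line) {
            Ok(0) => break, // the peer really closed the connection
            Ok(_) => { /* handle the line */ }
            // the read timeout surfaces as WouldBlock (Unix) or TimedOut (Windows)
            Err(e) if e.kind() == ErrorKind::WouldBlock || e.kind() == ErrorKind::TimedOut => {
                // send a keep-alive so the server doesn't drop the idle connection
                writer.write_all(b"PING\r\n")?;
            }
            Err(e) => return Err(e),
        }
    }
    Ok(())
}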

Owned pointer in graph structure

With the generous help of the Rust community I managed to get the base of a topological data structure assembled using managed pointers. This came together rather nicely and I was pretty excited about Rust in general. Then I read this post (which seems like a reasonable plan) and it inspired me to backtrack and try to re-assemble it using only owned pointers if possible.
This is the working version using managed pointers:
struct Dart<T> {
    alpha: ~[@mut Dart<T>],
    embed: ~[@mut T],
    tagged: bool
}

impl<T> Dart<T> {
    pub fn new(dim: uint) -> @mut Dart<T> {
        let mut dart = @mut Dart{alpha: ~[], embed: ~[], tagged: false};
        dart.alpha = vec::from_elem(dim, dart);
        return dart;
    }

    pub fn get_dim(&self) -> uint {
        return self.alpha.len();
    }

    pub fn traverse(@mut self, invs: &[uint], f: &fn(&Dart<T>)) {
        let dim = self.get_dim();
        for invs.each |i| { if *i >= dim { return } }; // test bounds on invs vec
        if invs.len() == 2 {
            let spread: int = int::abs(invs[1] as int - invs[0] as int);
            if spread == 1 { // simple loop
                let mut dart = self;
                let mut i = invs[0];
                while !dart.tagged {
                    dart.tagged = true;
                    f(dart);
                    dart = dart.alpha[i];
                    if i == invs[0] { i = invs[1]; }
                    else { i = invs[0]; }
                }
            }
            // else if spread == 2 { // max 4 cells traversed
            // }
        } else {
            let mut stack = ~[self];
            self.tagged = true;
            while !stack.is_empty() {
                let mut dart = stack.pop();
                f(dart);
                for invs.each |i| {
                    if !dart.alpha[*i].tagged {
                        dart.alpha[*i].tagged = true;
                        stack.push(dart);
                    }
                }
            }
        }
    }
}
After a few hours of chasing lifetime errors I have come to the conclusion that this may not even be possible with owned pointers due to the cyclic nature (without tying the knot as I was warned). My feeble attempt at this is below. My question, is this structure possible to implement without resorting to managed pointers? And if not, is the code above considered reasonably "rusty"? (idiomatic rust). Thanks.
struct GMap<'self, T> {
    dim: uint,
    darts: ~[~Dart<'self, T>]
}

struct Dart<'self, T> {
    alpha: ~[&'self mut Dart<'self, T>],
    embed: ~[&'self mut T],
    tagged: bool
}

impl<'self, T> GMap<'self, T> {
    pub fn new_dart(&'self mut self) {
        let mut dart = ~Dart{alpha: ~[], embed: ~[], tagged: false};
        let dartRef: &'self mut Dart<'self, T> = dart;
        dartRef.alpha = vec::from_elem(self.dim, copy dartRef);
        self.darts.push(dart);
    }
}
I'm pretty sure that using &mut pointers is impossible, since one can only have one such pointer in existence at a time, e.g.:
fn main() {
    let mut i = 0;
    let a = &mut i;
    let b = &mut i;
}
and-mut.rs:4:12: 4:18 error: cannot borrow `i` as mutable more than once at a time
and-mut.rs:4 let b = &mut i;
^~~~~~
and-mut.rs:3:12: 3:18 note: second borrow of `i` as mutable occurs here
and-mut.rs:3 let a = &mut i;
^~~~~~
error: aborting due to previous error
One could get around the borrow checker unsafely, by either storing unsafe pointers to the memory (ptr::to_mut_unsafe_ptr), or storing indices into the darts member of GMap. Essentially, a single reference to the memory is stored (in self.darts) and all operations have to go through it.
This might look like:
impl<'self, T> GMap<'self, T> {
    pub fn new_dart(&'self mut self) {
        let ind = self.darts.len();
        self.darts.push(~Dart{alpha: vec::from_elem(self.dim, ind), embed: ~[], tagged: false});
    }
}
traverse would need to change to either be a method on GMap (e.g. fn(&mut self, node_ind: uint, invs: &[uint], f: &fn(&Dart<T>))), or at least take a GMap type.
(On an entirely different note, there is library support for external iterators, which are far more composable than the internal iterators (the ones that take a closure). So defining one of these for traverse may (or may not) make using it nicer.)
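For what it's worth, here is a rough sketch of the shape an external iterator takes (written in modern Rust syntax purely for illustration, with a hypothetical stand-in for Dart): the caller pulls darts one at a time with next() instead of handing traverse a closure.

// Minimal stand-in for the post's Dart type, just enough for the sketch.
struct Dart<T> {
    embed: Vec<T>,
}

// The traversal state lives in the iterator instead of on the call stack.
struct DartTraversal<'a, T> {
    stack: Vec<&'a Dart<T>>,
}

impl<'a, T> Iterator for DartTraversal<'a, T> {
    type Item = &'a Dart<T>;

    fn next(&mut self) -> Option<Self::Item> {
        // A real implementation would also push the not-yet-visited
        // neighbours of the popped dart before returning it.
        self.stack.pop()
    }
}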
