How to compose two async functions in Rust

I am trying to write a higher-order function that composes two async functions.
I am basically looking for the async version of this:
fn compose<A, B, C, G, F>(f: F, g: G) -> impl Fn(A) -> C
where
F: Fn(A) -> B,
G: Fn(B) -> C,
{
move |x| g(f(x))
}
This is my attempt so far:
fn compose_future<A, B, C, G, F>(f: F, g: G) -> (impl Fn(A) -> impl Future<C>)
where
F: Fn(A) -> impl Future<B>,
G: Fn(B) -> impl Future<C>,
{
move |x| async { g(f(x).await).await }
}
and I get the following error:
error[E0562]: `impl Trait` not allowed outside of function and inherent method return types
--> src\channel.rs:13:17
|
13 | F: Fn(A) -> impl Future<B>,
| ^^^^^^^^^^^^^^
Is it possible to accomplish this?

I'm not sure this can be done that simply with impl Trait. The one solution I can come up with is to write the future type by hand, without the async-await feature. Async-await compiles to a generator that internally holds a state machine, so we need to define that state machine manually:
enum State<In, F, FutOutF, G, FutOutG> {
Initial(In, F, G), // Out composed type created
FirstAwait(FutOutF, G), // Composed type waits for the first future
SecondAwait(FutOutG), // and for the second
// here can be a `Completed` state, but it simpler
// to handle it with `Option<..>` in our future itself
}
Then define the composed type itself:
use std::marker::PhantomData;
struct Compose<In, Out, F, FutOutF, G, FutOutG> {
state: Option<State<In, F, FutOutF, G, FutOutG>>,
_t: PhantomData<Out>,
}
// And the entry point would be something like this:
fn compose_fut<In, Out, F, FutOutF, G, FutOutG>(
i: In,
f: F,
g: G,
) -> Compose<In, Out, F, FutOutF, G, FutOutG> {
Compose {
state: Some(State::Initial(i, f, g)),
_t: PhantomData,
}
}
Then comes the most complex part: the Future impl itself. Here is the impl declaration without the body of poll:
use std::{future::Future, pin::Pin, task::{Context, Poll}};
impl<In, Mid, Out, F, FutOutF, G, FutOutG> Future for Compose<In, Out, F, FutOutF, G, FutOutG>
where
FutOutF: Future<Output = Mid>,
F: FnOnce(In) -> FutOutF,
FutOutG: Future<Output = Out>,
G: FnOnce(Mid) -> FutOutG,
{
type Output = Out;
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
// here comes the magic
}
}
Values are transformed as follows: In -> Mid -> Out, where F and G are the composed functions and FutOutF and FutOutG are the futures they return. And finally, the body of Future::poll:
let this = unsafe { self.get_unchecked_mut() };
let state = this.state.take();
match state {
None => Poll::Pending, // invalid state
Some(State::Initial(i, f, g)) => {
let fut = f(i);
this.state = Some(State::FirstAwait(fut, g));
cx.waker().wake_by_ref();
Poll::Pending
}
Some(State::FirstAwait(mut fut, g)) => {
let val = match unsafe { Pin::new_unchecked(&mut fut) }.poll(cx) {
Poll::Ready(v) => v,
Poll::Pending => {
this.state = Some(State::FirstAwait(fut, g));
return Poll::Pending;
}
};
let fut = g(val);
this.state = Some(State::SecondAwait(fut));
cx.waker().wake_by_ref();
Poll::Pending
}
Some(State::SecondAwait(mut fut)) => {
match unsafe { Pin::new_unchecked(&mut fut) }.poll(cx) {
Poll::Ready(v) => Poll::Ready(v),
Poll::Pending => {
this.state = Some(State::SecondAwait(fut));
Poll::Pending
}
}
}
}
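I avoided any libraries to keep this "plain"; the unsafe parts are usually handled with pin-project or futures::pin_mut. The state management is fairly complex, so I suggest re-checking the implementation, as there might be mistakes. One thing to check in particular: this.state.take() moves the stored future out of the pinned struct, including a future that has already been polled, and the Pin contract forbids moving a future that is not Unpin once it has been polled.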
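If boxing the returned future is acceptable, the composition itself can stay short. The following is a minimal sketch and not the approach from the answer above: compose_boxed, BoxFuture, block_on, add_one and to_text are hypothetical names introduced for illustration, the 'static bounds are deliberately conservative, and block_on is a busy-polling stand-in for a real executor.

```rust
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

/// Boxed future type returned by the composed function (hypothetical alias).
type BoxFuture<T> = Pin<Box<dyn Future<Output = T>>>;

/// Composes two async functions: the result runs `f`, then feeds its output to `g`.
/// The `'static` bounds are conservative; they exist because the future is boxed
/// as a `dyn Future` with the default `'static` lifetime.
fn compose_boxed<A, B, C, F, G, FutF, FutG>(f: F, g: G) -> impl Fn(A) -> BoxFuture<C>
where
    A: 'static,
    B: 'static,
    C: 'static,
    F: Fn(A) -> FutF + 'static,
    G: Fn(B) -> FutG + Clone + 'static,
    FutF: Future<Output = B> + 'static,
    FutG: Future<Output = C> + 'static,
{
    move |x| {
        let first = f(x); // create the first future outside the async block
        let g = g.clone(); // every call needs its own handle to `g`
        let fut: BoxFuture<C> = Box::pin(async move { g(first.await).await });
        fut
    }
}

/// Waker that does nothing; enough for futures that never really suspend.
struct NoopWaker;

impl Wake for NoopWaker {
    fn wake(self: Arc<Self>) {}
}

/// Demo-only executor: polls in a busy loop until the future is ready.
fn block_on<Fut: Future>(fut: Fut) -> Fut::Output {
    let mut fut = Box::pin(fut);
    let waker = Waker::from(Arc::new(NoopWaker));
    let mut cx = Context::from_waker(&waker);
    loop {
        if let Poll::Ready(value) = fut.as_mut().poll(&mut cx) {
            return value;
        }
    }
}

// Two small async functions to compose (illustrative only).
async fn add_one(x: u32) -> u32 {
    x + 1
}

async fn to_text(x: u32) -> String {
    format!("value = {}", x)
}
```

With these definitions, compose_boxed(add_one, to_text) yields a function from u32 to a boxed future of String, which can be driven with block_on or awaited inside any async context. The cost compared to the hand-written state machine is one heap allocation per call.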

Related

Borrow error when attempting recursion on HashMap, where each value needs a reference to the map

I currently have an issue with Rust's borrowing rules.
I have a HashMap of 'Value' structs, each of which contains a list of keys to other Values in the HashMap. I am attempting to recursively call a function on these Values, which requires a reference to the HashMap.
use std::collections::HashMap;
struct Value {
val: f64,
prevs: Vec<usize>,
sum: f64,
}
impl Value {
pub fn new(i: usize) -> Value {
let mut res = Value {
val: 0.1,
prevs: Vec::new(),
sum: 0.0,
};
for j in 0..i {
res.prevs.push(j);
}
res
}
pub fn evaluate(&mut self, pool: &mut HashMap<usize, Value>) -> f64 {
self.sum = self.val;
for i in &self.prevs {
let prev = pool.get_mut(i).unwrap();
self.sum += prev.evaluate(pool);
}
self.sum
}
}
fn main() {
let mut hm: HashMap<usize, Value> = HashMap::new();
for i in 0..10 {
hm.insert(i, Value::new(i));
}
println!("{}", hm.get(&9).unwrap().evaluate(&mut hm));
}
Error:
error[E0499]: cannot borrow `*pool` as mutable more than once at a time
--> src/lib.rs:25:39
|
24 | let prev = pool.get_mut(i).unwrap();
| --------------- first mutable borrow occurs here
25 | self.sum += prev.evaluate(pool);
| -------- ^^^^ second mutable borrow occurs here
| |
| first borrow later used by call
Context
I'm attempting to calculate the output of a neural network (usually done via feedforward) by starting from the output, and recursively evaluating each node, as a weighted sum of the nodes connected to it, with an unpredictable topology. This requires each node having a list of input_nodes, which are keys to a node pool HashMap.
Below is a sample with a few variants:
Non-performant and probably deadlock-prone but compiling version using Arc<Mutex>
High-performance version using Vec and split_at_mut
Highly unsafe, UB and "against-all-good-practices" version using Vec and pointers. At least evaluates to the same number, wanted to add for performance comparison.
#![feature(test)]
extern crate test;
use std::{collections::HashMap, sync::{Arc, Mutex}};
#[derive(Debug)]
struct Value {
val: f64,
prevs: Vec<usize>,
sum: f64,
}
impl Value {
pub fn new(i: usize) -> Value {
let mut res = Value {
val: 0.1,
prevs: Vec::new(),
sum: 0.0,
};
for j in 0..i {
res.prevs.push(j);
}
res
}
pub fn evaluate(&mut self, pool: &mut HashMap<usize, Arc<Mutex<Value>>>) -> f64 {
self.sum = self.val;
for i in &self.prevs {
let val = pool.get_mut(i).unwrap().clone();
self.sum += val.lock().unwrap().evaluate(pool);
}
self.sum
}
pub fn evaluate_split(&mut self, pool: &mut [Value]) -> f64 {
self.sum = self.val;
for i in &self.prevs {
let (hm, val) = pool.split_at_mut(*i);
self.sum += val[0].evaluate_split(hm);
}
self.sum
}
// Warning: don't do this, horribly unsafe and wrong
pub unsafe fn evaluate_unsafe(&mut self, pool: *const Value, pool_len: usize) -> f64 {
let pool = std::slice::from_raw_parts_mut(pool as *mut Value, pool_len);
self.sum = self.val;
for i in &self.prevs {
let (pool_ptr, pool_len) = (pool.as_ptr(), pool.len());
self.sum += pool[*i].evaluate_unsafe(pool_ptr, pool_len);
}
self.sum
}
}
fn main() {
// arcmutex
let mut hm: HashMap<usize, Arc<Mutex<Value>>> = HashMap::new();
for i in 0..10 {
hm.insert(i, Arc::new(Mutex::new(Value::new(i))));
}
let val = hm.get(&9).unwrap().clone();
assert_eq!(val.lock().unwrap().evaluate(&mut hm), 51.2);
// split vec
let mut hm = (0..10).map(|v| {
Value::new(v)
}).collect::<Vec<_>>();
let (hm, val) = hm.split_at_mut(9);
assert_eq!((hm.len(), val.len()), (9, 1));
assert_eq!(val[0].evaluate_split(hm), 51.2);
}
#[cfg(test)]
mod tests {
use test::bench;
use super::*;
#[bench]
fn bench_arc_mutex(b: &mut bench::Bencher) {
let mut hm: HashMap<usize, Arc<Mutex<Value>>> = HashMap::new();
for i in 0..10 {
hm.insert(i, Arc::new(Mutex::new(Value::new(i))));
}
b.iter(|| {
let val = hm.get(&9).unwrap().clone();
assert_eq!(val.lock().unwrap().evaluate(&mut hm), 51.2);
});
}
#[bench]
fn bench_split(b: &mut bench::Bencher) {
let mut hm = (0..10).map(|v| {
Value::new(v)
}).collect::<Vec<_>>();
b.iter(|| {
let (hm, val) = hm.split_at_mut(9);
assert_eq!(val[0].evaluate_split(hm), 51.2);
});
}
#[bench]
fn bench_unsafe(b: &mut bench::Bencher) {
let mut hm = (0..10).map(|v| {
Value::new(v)
}).collect::<Vec<_>>();
b.iter(|| {
// Warning: don't do this, horribly unsafe and wrong
let (hm_ptr, hm_len) = (hm.as_ptr(), hm.len());
let val = &mut hm[9];
assert_eq!(unsafe { val.evaluate_unsafe(hm_ptr, hm_len) }, 51.2);
});
}
}
cargo bench gives the following results:
running 3 tests
test tests::bench_arc_mutex ... bench: 13,249 ns/iter (+/- 367)
test tests::bench_split ... bench: 1,974 ns/iter (+/- 70)
test tests::bench_unsafe ... bench: 1,989 ns/iter (+/- 62)
Also, have a look at https://rust-unofficial.github.io/too-many-lists/index.html
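Another way around the borrow conflict is to avoid mutating the nodes at all: keep the pool behind a shared reference and store the computed sums in a separate map. The sketch below is an illustration under assumptions, not one of the benchmarked variants above: Node, build_pool and the cache parameter are hypothetical names, the graph is assumed to be acyclic, and caching each node's sum is a design choice that the original code does not make.

```rust
use std::collections::HashMap;

/// Immutable node description; computed sums live in a separate cache.
struct Node {
    val: f64,
    prevs: Vec<usize>,
}

/// Same topology as `Value::new` in the question: node `i` depends on nodes `0..i`.
fn build_pool(n: usize) -> HashMap<usize, Node> {
    (0..n)
        .map(|i| (i, Node { val: 0.1, prevs: (0..i).collect() }))
        .collect()
}

/// Recursively evaluates node `id`, memoizing every node's sum in `cache`.
/// Assumes the graph has no cycles; a cycle would recurse without end.
fn evaluate(id: usize, pool: &HashMap<usize, Node>, cache: &mut HashMap<usize, f64>) -> f64 {
    if let Some(&sum) = cache.get(&id) {
        return sum;
    }
    let node = &pool[&id]; // panics on a missing key, like `unwrap()` in the question
    let mut sum = node.val;
    for &prev in &node.prevs {
        // `pool` is only borrowed immutably, so the recursive call is allowed.
        sum += evaluate(prev, pool, cache);
    }
    cache.insert(id, sum);
    sum
}
```

Because the pool is never borrowed mutably, the recursion compiles without Arc<Mutex>, split_at_mut or unsafe. The cache also means each node is summed once instead of once per path that reaches it.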

How do I calculate a multiple factorial using num_bigint in Rust?

I am trying to calculate the factorial of a factorial in Rust using the num-bigint crate. I've gotten to the point where I can calculate a factorial:
use num_bigint::BigUint;
use num_traits::{One, Zero, FromPrimitive};
fn factorial(n: usize) -> BigUint {
let mut f: BigUint = One::one();
for i in 1..(n+1) {
let bu: BigUint = FromPrimitive::from_usize(i).unwrap();
f = f * bu;
}
f
}
pub fn main() {
println!("Starting calculation...");
println!("{}", factorial(5));
}
I want to take the factorial of a factorial, like:
pub fn main() {
println!("Starting calculation...");
println!("{}", factorial(factorial(5)));
}
However, this throws the following error because the data types are different:
error[E0308]: mismatched types
--> src/main.rs:16:30
|
16 | println!("{}", factorial(factorial(5)));
| ^^^^^^^^^^^^ expected `usize`, found struct `BigUint`
How can I repeat this function using BigUint instead of usize?
The problem is that you first want to use a usize and then a BigUint as your function parameter, while your parameter is set to be a usize.
To fix this you should make your factorial function generic and then only allow those types that make sense for your particular method.
I took your example and extended it to allow for all unsigned integer types:
use num_bigint::BigUint;
use num_traits::{One, FromPrimitive, Unsigned};
use num::{Integer, NumCast, ToPrimitive};
fn factorial<N: Unsigned + Integer + ToPrimitive>(n: N) -> BigUint {
let mut f: BigUint = One::one();
let end: usize = NumCast::from(n).unwrap();
for i in 1..(end + 1) {
let bu: BigUint = FromPrimitive::from_usize(i).unwrap();
f = f * bu;
}
f
}
pub fn main() {
println!("Starting calculation...");
println!("{}", factorial(factorial(5 as u32)));
}
However, this no longer accepts factorial(5), because an unsuffixed 5 defaults to i32, which is signed. If you want to accept signed types and fail at runtime instead, you can do something like this:
use num_bigint::BigUint;
use num_traits::{One, FromPrimitive};
use num::{Integer, NumCast, ToPrimitive};
fn factorial<N: Integer + ToPrimitive>(n: N) -> BigUint {
let mut f: BigUint = One::one();
let end: usize = NumCast::from(n).expect("Number too big or negative number used.");
for i in 1..(end + 1) {
let bu: BigUint = FromPrimitive::from_usize(i).unwrap();
f = f * bu;
}
f
}
pub fn main() {
println!("Starting calculation...");
println!("{}", factorial(factorial(5)));
}
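The "accept any integer type, fail at runtime" idea can also be tried without the num crates. The sketch below is a standard-library analogue only: it swaps NumCast for std's TryInto and, as a loudly stated simplification, swaps BigUint for u128, so it only works for results up to 34! and cannot compute factorial(factorial(5)).

```rust
use std::convert::TryInto;

/// Std-only stand-in for the generic factorial above.
/// Assumption: `u128` replaces `BigUint`, so anything above 34! overflows.
fn factorial<N: TryInto<usize>>(n: N) -> u128 {
    // `ok()` discards the conversion error so that no `Debug` bound is needed.
    let end: usize = n
        .try_into()
        .ok()
        .expect("Number too big or negative number used.");
    let mut f: u128 = 1;
    for i in 1..=end {
        f = f
            .checked_mul(i as u128)
            .expect("Result does not fit in u128.");
    }
    f
}
```

Here factorial(5) compiles with the default i32 literal, a negative argument panics at runtime, and factorial(factorial(3u8)) type-checks because u128 also converts to usize through TryInto.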

How can I add functions with different arguments and return types to a vector?

I'm trying to add functions with different arguments to a vector.
fn f1() {
println!("Hello, World!");
}
fn f2(s: &str) -> String {
String::from(s)
}
fn f3(i: i32) {
println!("{}", i);
}
fn main() {
let v = vec![f1, f3, f2];
}
But this gives the error:
error[E0308]: mismatched types
--> src/main.rs:12:22
|
12 | let v = vec![f1, f3, f2];
| ^^ incorrect number of function parameters
|
= note: expected type `fn() {f1}`
found fn item `fn(i32) {f3}`
Is there any way I can make this work?
The error happens because vectors are only meant to hold homogeneous data, i.e. every element of a vector has to be the same type. To solve this, you could for example use a vector of enums:
enum E {
F1(fn()),
F2(fn(&str) -> String),
F3(fn(i: i32)),
}
fn f1() {
println!("Hello, World!");
}
fn f2(s: &str) -> String {
String::from(s)
}
fn f3(i: i32) {
println!("{}", i);
}
fn main() {
let v = vec![E::F1(f1), E::F3(f3), E::F2(f2)];
for func in v {
match func {
E::F1(f) => f(),
E::F2(f) => println!("{}", f("foo")),
E::F3(f) => f(2),
}
}
}
Output
Hello, World!
2
foo
Or you could use a container made specifically for storing heterogeneous data, aka a tuple:
fn main() {
let v = (f1, f3, f2);
v.0();
v.1(2);
println!("{}", v.2("foo"));
}
Output
Hello, World!
2
foo
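A third option, when the arguments are known at the time the vector is built, is to wrap every function in a closure with one common signature and store boxed trait objects. This is a sketch of that alternative rather than part of the answer above; Task, build_tasks and run_all are hypothetical names, and the choice of Option<String> as the common return type is an assumption made for this example.

```rust
fn f1() {
    println!("Hello, World!");
}

fn f2(s: &str) -> String {
    String::from(s)
}

fn f3(i: i32) {
    println!("{}", i);
}

/// Common signature: no arguments, optional `String` result.
type Task = Box<dyn Fn() -> Option<String>>;

/// Wraps each function in a closure that supplies its arguments up front.
fn build_tasks() -> Vec<Task> {
    let mut tasks: Vec<Task> = Vec::new();
    tasks.push(Box::new(|| -> Option<String> { f1(); None }));
    tasks.push(Box::new(|| -> Option<String> { f3(2); None }));
    tasks.push(Box::new(|| -> Option<String> { Some(f2("foo")) }));
    tasks
}

/// Runs every task in order and collects the results.
fn run_all(tasks: &[Task]) -> Vec<Option<String>> {
    tasks.iter().map(|task| task()).collect()
}
```

Unlike the enum, the caller no longer needs to match on each variant, but the arguments are fixed when the closure is created and every call goes through dynamic dispatch.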

How do I create a macro that takes a function with multiple parameters and supplies the first argument for that function?

I want to be able to create a higher-order function (called g) that takes in a function (called f). g should pass in the first parameter to f and return a new function.
The use case is that I want to initiate a database connection in g and pass it functions that accept a database connection.
fn f1(a: i32, b: String) -> String {
b
}
fn f2(a: i32, c: i64, d: i16) -> i32 {
1000
}
fn g<T>(f: fn(a: i32, arbitrary_arguments_type) -> T) -> fn(arbitrary_arguments_type) -> T {
move |arbitrary_arguments| f(1, arbitrary_arguments)
}
fn main() {
g(f1)("hello".to_string());
g(f2)(10, 11);
}
The specific question I have is: how do I create a macro that takes as an argument a function with more than one parameter, where the first parameter is of a certain type, and supplies that first argument to the function?
Macros (even procedural macros) operate on syntax trees, so they can't change their behaviour based on semantics, including types and function arity. That means you'd have to have a different macro for each possible number of arguments. For example:
// `+` requires at least one bound argument; with `*`, an empty list
// would expand to the invalid call `f(, a)`.
macro_rules! curry1 {
($func: ident, $($arg: expr),+) => {
|a| $func($($arg),+, a)
}
}
macro_rules! curry2 {
($func: ident, $($arg: expr),+) => {
|a, b| $func($($arg),+, a, b)
}
}
macro_rules! curry3 {
($func: ident, $($arg: expr),+) => {
|a, b, c| $func($($arg),+, a, b, c)
}
}
Which would be used like this:
fn f(a: i32, b: i32, c: i32) -> i32 {
a + b + c
}
fn main() {
// requires 2 extra args
let f_2 = curry2!(f, 2);
// requires 1 extra arg
let f_2_1 = curry1!(f, 2, 1);
println!("{}", f(2, 1, 3)); // 6
println!("{}", f_2(1, 3)); // 6
println!("{}", f_2_1(3)); // 6
}
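Since a macro cannot look up the arity of the function, one way to get by with a single macro is to have the caller name the remaining parameters explicitly. The following is a sketch of that idea; partial! is a hypothetical macro introduced here, and f and f1 are repeated from the examples above so that the block is self-contained.

```rust
/// Hypothetical helper: binds the leading arguments of `$func` and takes the
/// names of the remaining parameters, so one macro covers every arity.
/// The bound expressions are re-evaluated on every call of the closure.
macro_rules! partial {
    ($func:ident, $($bound:expr),+ => $($free:ident),+) => {
        move |$($free),+| $func($($bound),+, $($free),+)
    };
}

fn f(a: i32, b: i32, c: i32) -> i32 {
    a + b + c
}

fn f1(_a: i32, b: String) -> String {
    b
}
```

For example, partial!(f, 2 => b, c) expands to move |b, c| f(2, b, c), and partial!(f1, 1 => s) expands to move |s| f1(1, s). The macro still knows nothing about types; the compiler checks the expanded call as usual.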

Most efficient way to fill a vector from back to front

I am trying to populate a vector with a sequence of values. In order to calculate the first value I need to calculate the second value, which depends on the third value etc etc.
let mut bxs = Vec::with_capacity(n);
for x in info {
let b = match bxs.last() {
Some(bx) => union(&bx, &x.bbox),
None => x.bbox.clone(),
};
bxs.push(b);
}
bxs.reverse();
Currently I just fill the vector front to back using v.push(x) and then reverse the vector using v.reverse(). Is there a way to do this in a single pass?
Is there a way to do this in a single pass?
If you don't mind adapting the vector, it's relatively easy.
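use std::ops::{Index, IndexMut};
struct RevVec<T> {
data: Vec<T>,
}
impl<T> RevVec<T> {
fn push_front(&mut self, t: T) { self.data.push(t); }
// Needed by the Index and IndexMut impls below.
fn len(&self) -> usize { self.data.len() }
}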
impl<T> Index<usize> for RevVec<T> {
type Output = T;
fn index(&self, index: usize) -> &T {
&self.data[self.len() - index - 1]
}
}
impl<T> IndexMut<usize> for RevVec<T> {
fn index_mut(&mut self, index: usize) -> &mut T {
let len = self.len();
&mut self.data[len - index - 1]
}
}
The solution using unsafe is below. The unsafe version is slightly more than 2x as fast as the safe version using reverse(). The idea is to use Vec::with_capacity(usize) to allocate the vector, then use ptr::write(dst: *mut T, src: T) to write the elements into the vector back to front. offset(self, count: isize) -> *const T is used to calculate the offset into the vector.
extern crate time;
use std::fmt::Debug;
use std::ptr;
use time::PreciseTime;
fn scanl<T, F>(u : &Vec<T>, f : F) -> Vec<T>
where T : Clone,
F : Fn(&T, &T) -> T {
let mut v = Vec::with_capacity(u.len());
for x in u.iter().rev() {
let b = match v.last() {
None => (*x).clone(),
Some(y) => f(x, &y),
};
v.push(b);
}
v.reverse();
return v;
}
fn unsafe_scanl<T, F>(u : &Vec<T> , f : F) -> Vec<T>
where T : Clone + Debug,
F : Fn(&T, &T) -> T {
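unsafe {
let n = u.len();
let mut v : Vec<T> = Vec::with_capacity(n);
let p = v.as_mut_ptr();
match u.last() {
None => return v,
Some(x) => ptr::write(p.offset((n - 1) as isize), x.clone()),
};
for i in (0..n - 1).rev() {
// Read the slot written in the previous iteration through the raw pointer:
// `v` still has length 0, so `v.get_unchecked(i + 1)` would be out of bounds.
// The argument order matches the safe `scanl` above.
let next = &*p.offset((i + 1) as isize);
ptr::write(p.offset(i as isize), f(u.get_unchecked(i), next));
}
// Only `n` elements are initialized; the capacity may be larger than `n`.
v.set_len(n);
return v;
}
}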
pub fn bench_scanl() {
let lo : u64 = 0;
let hi : u64 = 1000000;
let v : Vec<u64> = (lo..hi).collect();
let start = PreciseTime::now();
let u = scanl(&v, |x, y| x + y);
let end= PreciseTime::now();
println!("{:?}\n in {}", u.len(), start.to(end));
let start2 = PreciseTime::now();
let u = unsafe_scanl(&v, |x, y| x + y);
let end2 = PreciseTime::now();
println!("2){:?}\n in {}", u.len(), start2.to(end2));
}
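A safe single-pass alternative is to build the sequence in a VecDeque, which supports pushing at the front, and convert it to a Vec at the end. This is only a sketch: scan_from_back is a hypothetical name, and no claim is made here about how its speed compares with the two versions benchmarked above.

```rust
use std::collections::VecDeque;

/// Builds the result back to front without `reverse()` and without `unsafe`.
/// `f` receives the current input element and the previously computed value,
/// in the same order as the safe `scanl` above.
fn scan_from_back<T, F>(u: &[T], f: F) -> Vec<T>
where
    T: Clone,
    F: Fn(&T, &T) -> T,
{
    let mut out: VecDeque<T> = VecDeque::with_capacity(u.len());
    for x in u.iter().rev() {
        let b = match out.front() {
            None => x.clone(),   // last input element starts the sequence
            Some(y) => f(x, y),  // combine with the value computed just before
        };
        out.push_front(b);
    }
    Vec::from(out)
}
```

For the input [1, 2, 3, 4] with addition this yields the suffix sums [10, 9, 7, 4]. Whether the final conversion is cheaper than reverse() depends on how the deque's contents are laid out in its buffer, so it is worth measuring before relying on it.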

Resources