I just discovered that a recursive call of a suspend function takes more time than calling the same function without the suspend modifier. Please consider the code snippet below (a basic Fibonacci series calculation):
suspend fun asyncFibonacci(n: Int): Long = when {
n <= -2 -> asyncFibonacci(n + 2) - asyncFibonacci(n + 1)
n == -1 -> 1
n == 0 -> 0
n == 1 -> 1
n >= 2 -> asyncFibonacci(n - 1) + asyncFibonacci(n - 2)
else -> throw IllegalArgumentException()
}
If I call this function and measure its execution time with the code below:
fun main(args: Array<String>) {
val totalElapsedTime = measureTimeMillis {
val nFibonacci = 40
val deferredFirstResult: Deferred<Long> = async {
asyncProfile("fibonacci") { asyncFibonacci(nFibonacci) } as Long
}
val deferredSecondResult: Deferred<Long> = async {
asyncProfile("fibonacci") { asyncFibonacci(nFibonacci) } as Long
}
val firstResult: Long = runBlocking { deferredFirstResult.await() }
val secondResult: Long = runBlocking { deferredSecondResult.await() }
val superSum = secondResult + firstResult
println("${thread()} - Sum of two $nFibonacci'th fibonacci numbers: $superSum")
}
println("${thread()} - Total elapsed time: $totalElapsedTime millis")
}
I observe the following results:
commonPool-worker-2:fibonacci - Start calculation...
commonPool-worker-1:fibonacci - Start calculation...
commonPool-worker-2:fibonacci - Finish calculation...
commonPool-worker-2:fibonacci - Elapsed time: 7704 millis
commonPool-worker-1:fibonacci - Finish calculation...
commonPool-worker-1:fibonacci - Elapsed time: 7741 millis
main - Sum of two 40'th fibonacci numbers: 204668310
main - Total elapsed time: 7816 millis
But if I remove the suspend modifier from the asyncFibonacci function, I get this result:
commonPool-worker-2:fibonacci - Start calculation...
commonPool-worker-1:fibonacci - Start calculation...
commonPool-worker-1:fibonacci - Finish calculation...
commonPool-worker-1:fibonacci - Elapsed time: 1179 millis
commonPool-worker-2:fibonacci - Finish calculation...
commonPool-worker-2:fibonacci - Elapsed time: 1201 millis
main - Sum of two 40'th fibonacci numbers: 204668310
main - Total elapsed time: 1250 millis
I know it's better to rewrite such a function with tailrec, which would cut its execution time by a factor of roughly 100. But still, what does the suspend keyword do that slows execution from 1 second to 8 seconds?
Is it a totally stupid idea to mark recursive functions with suspend?
As an introductory comment, your testing code setup is too complex. This much simpler code achieves the same in terms of stressing suspend fun recursion:
fun main(args: Array<String>) {
launch(Unconfined) {
val nFibonacci = 37
var sum = 0L
(1..1_000).forEach {
val took = measureTimeMillis {
sum += suspendFibonacci(nFibonacci)
}
println("Sum is $sum, took $took ms")
}
}
}
suspend fun suspendFibonacci(n: Int): Long {
return when {
n >= 2 -> suspendFibonacci(n - 1) + suspendFibonacci(n - 2)
n == 0 -> 0
n == 1 -> 1
else -> throw IllegalArgumentException()
}
}
I tried to reproduce its performance by writing a plain function that approximates the kinds of things the suspend function must do to achieve suspendability:
val COROUTINE_SUSPENDED = Any()
fun fakeSuspendFibonacci(n: Int, inCont: Continuation<Unit>): Any? {
val cont = if (inCont is MyCont && inCont.label and Integer.MIN_VALUE != 0) {
inCont.label -= Integer.MIN_VALUE
inCont
} else MyCont(inCont)
val suspended = COROUTINE_SUSPENDED
loop@ while (true) {
when (cont.label) {
0 -> {
when {
n >= 2 -> {
cont.n = n
cont.label = 1
val f1 = fakeSuspendFibonacci(n - 1, cont)!!
if (f1 === suspended) {
return f1
}
cont.data = f1
continue@loop
}
n == 1 || n == 0 -> return n.toLong()
else -> throw IllegalArgumentException("Negative input not allowed")
}
}
1 -> {
cont.label = 2
cont.f1 = cont.data as Long
val f2 = fakeSuspendFibonacci(cont.n - 2, cont)!!
if (f2 === suspended) {
return f2
}
cont.data = f2
continue@loop
}
2 -> {
val f2 = cont.data as Long
return cont.f1 + f2
}
else -> throw AssertionError("Invalid continuation label ${cont.label}")
}
}
}
class MyCont(val completion: Continuation<Unit>) : Continuation<Unit> {
var label = 0
var data: Any? = null
var n: Int = 0
var f1: Long = 0
override val context: CoroutineContext get() = TODO("not implemented")
override fun resumeWithException(exception: Throwable) = TODO("not implemented")
override fun resume(value: Unit) = TODO("not implemented")
}
You have to invoke this one with
sum += fakeSuspendFibonacci(nFibonacci, InitialCont()) as Long
where InitialCont is
class InitialCont : Continuation<Unit> {
override val context: CoroutineContext get() = TODO("not implemented")
override fun resumeWithException(exception: Throwable) = TODO("not implemented")
override fun resume(value: Unit) = TODO("not implemented")
}
Basically, to compile a suspend fun the compiler has to turn its body into a state machine. Each invocation must also create an object to hold the machine's state. When you resume, the state object tells which state handler to go to. The above still isn't all there is to it; the real code is even more complex.
In interpreted mode (java -Xint), I get almost the same performance as the actual suspend fun, and with JIT enabled it is less than twice as fast as the real one. By comparison, the "direct" function implementation is about 10 times as fast. That means the code shown explains a good part of the overhead of suspendability.
The problem lies in the Java bytecode generated from the suspend function. A non-suspend function generates just the bytecode we'd expect:
public static final long asyncFibonacci(int n) {
long var10000;
if (n <= -2) {
var10000 = asyncFibonacci(n + 2) - asyncFibonacci(n + 1);
} else if (n == -1) {
var10000 = 1L;
} else if (n == 0) {
var10000 = 0L;
} else if (n == 1) {
var10000 = 1L;
} else {
if (n < 2) {
throw (Throwable)(new IllegalArgumentException());
}
var10000 = asyncFibonacci(n - 1) + asyncFibonacci(n - 2);
}
return var10000;
}
When you add the suspend keyword, the decompiled Java source code is 165 lines - so a lot larger. You can view the bytecode and the decompiled Java code in IntelliJ by going to Tools -> Kotlin -> Show Kotlin bytecode (and then click Decompile on top of the page). While it's not easy to tell what exactly the Kotlin compiler is doing in the function, it looks like it's doing a whole lot of coroutine status checking - which kind of makes sense given that a coroutine can be suspended at any time.
So, in conclusion, I'd say that every suspend function call is a lot heavier than a non-suspend call. This applies not only to recursive functions, but they probably suffer the most from it.
Is it a totally stupid idea to mark recursive functions with suspend?
Unless you have a very good reason to do so - Yes
Related
I'm trying to write a function to find the lowest number that all integers between 1 and 20 divide. (Let's call this Condition D)
Here's my solution, which is somehow exceeding the call stack size limit.
function findSmallest(num){
var count = 2
while (count<21){
count++
if (num % count !== 0){
// exit the loop
return findSmallest(num++)
}
}
return num
}
console.log(findSmallest(20))
Somewhere my reasoning on this is faulty but here's how I see it (please correct me where I'm wrong):
Calling this function with a number N that doesn't meet Condition D will result in the function being called again with N + 1. Eventually, when it reaches a number M that should satisfy Condition D, the while loop runs all the way through and the number M is returned by the function and there are no more recursive calls.
But I get this error on running it:
function findSmallest(num){
^
RangeError: Maximum call stack size exceeded
I know errors like this are almost always due to recursive functions not reaching a base case. Is this the problem here, and if so, where's the problem?
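I found two bugs:
In your while loop, count takes the values 3 to 21 (it is incremented before the check), so 2 is never tested and 21 is.
num++ is a post-increment, so the recursive call receives the old value of num and the function calls itself with the same argument forever. num++ should be num + 1.
However, even if these bugs are fixed, the error will not go away.
The answer is 232792560.
That recursion depth is too large, so the stack memory is exhausted.
For example, this code causes the same error: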
function foo (num) {
if (num === 0) return
else foo(num - 1)
}
foo(232792560)
Writing the function without recursion avoids the error.
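As a rough illustration of the non-recursive approach, here is a minimal sketch in Rust (not the original JavaScript). The name find_smallest and the trick of stepping by limit are my own assumptions, not part of the question's code:

```rust
/// Smallest positive number divisible by every integer in 1..=limit.
/// Hypothetical helper; assumes limit >= 1 and that the result fits in u64.
fn find_smallest(limit: u64) -> u64 {
    let mut candidate = limit;
    loop {
        // A plain loop replaces the recursive call, so the call stack stays flat.
        if (2..=limit).all(|d| candidate % d == 0) {
            return candidate;
        }
        // The answer must be a multiple of limit, so step by limit.
        candidate += limit;
    }
}
```

Calling find_smallest(20) walks through roughly 11.6 million candidates in one loop instead of stacking up that many calls.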
Your problem is that you enter the recursion more than 200 million times (plus the bugs spotted in the previous answer). The number you are looking for is the product of all primes in the range, each raised to the highest power with which it occurs in any number of the range. So here is your solution:
function findSmallestDivisible(n) {
if(n < 2 || n > 100) {
throw "Numbers between 2 and 100 please";
}
var arr = new Array(n), res = 2;
arr[0] = 1;
arr[1] = 2;
for(var i = 2; i < arr.length; i++) {
arr[i] = fix(i, arr);
res *= arr[i];
}
return res;
}
function fix(idx, arr) {
var res = idx + 1;
for(var i = 1; i < idx; i++) {
if((res % arr[i]) == 0) {
res /= arr[i];
}
}
return res;
}
https://jsfiddle.net/7ewkeamL/
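The same number is the least common multiple of 1..n, so another way to sketch it is with a gcd helper. This is in Rust rather than JavaScript, and gcd and smallest_multiple are names I made up for the illustration:

```rust
/// Greatest common divisor (Euclid's algorithm).
fn gcd(a: u64, b: u64) -> u64 {
    if b == 0 { a } else { gcd(b, a % b) }
}

/// Least common multiple of all integers in 1..=n.
/// Only valid while the result fits in u64, which holds for small n such as 20.
fn smallest_multiple(n: u64) -> u64 {
    // lcm(acc, k) = acc / gcd(acc, k) * k; divide first to keep intermediates small.
    (1..=n).fold(1, |acc, k| acc / gcd(acc, k) * k)
}
```

For n = 20 this needs only 20 gcd computations, so neither deep recursion nor a long search is involved.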
I am working on the third Project Euler problem:
fn main() {
println!("{}", p3());
}
fn p3() -> u64 {
let divs = divisors(1, 600851475143, vec![]);
let mut max = 0;
for x in divs {
if prime(x, 0, false) && x > max {
max = x
}
}
max
}
fn divisors(i: u64, n: u64, div: Vec<u64>) -> Vec<u64> {
let mut temp = div;
if i * i > n {
temp
} else {
if n % i == 0 {
temp.push(i);
temp.push(n / i);
}
divisors(i + 2, n, temp)
}
}
fn prime(n: u64, i: u64, skip: bool) -> bool {
if !skip {
if n == 2 || n == 3 {
true
} else if n % 3 == 0 || n % 2 == 0 {
false
} else {
prime(n, 5, true)
}
} else {
if i * i > n {
true
} else if n % i == 0 || n % (i + 2) == 0 {
false
} else {
prime(n, i + 6, true)
}
}
}
The value 600851475143 is the one that at some point causes the stack to overflow. If I replace it with any value on the order of 10^10 or less, it returns an answer. While keeping it as a recursive solution, is there any way to either:
Increase the stack size?
Optimize my code so it doesn't return a fatal runtime: stack overflow error?
I know this can be done iteratively, but I'd prefer to not do that.
A vector containing 600 * 10^9 u64s means you'll need 4.8 terabytes of RAM or swap space.
I'm sure you don't need that for this problem, you're missing some knowledge of math here: scanning till the square root of the 600851475143 will be sufficient. You may also speed up the program by using the Sieve of Eratosthenes.
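To illustrate the square-root bound, here is a minimal sketch that divides out each factor as it is found, so the loop never runs past the square root of what remains. The function name largest_prime_factor is an assumption of mine, and this is plain trial division rather than a sieve:

```rust
/// Largest prime factor of n by trial division.
/// Hypothetical helper; returns 1 for n < 2.
fn largest_prime_factor(mut n: u64) -> u64 {
    let mut largest = 1;
    let mut d = 2;
    while d * d <= n {
        // Strip every copy of d, so only prime d can ever divide n here.
        while n % d == 0 {
            largest = d;
            n /= d;
        }
        d += 1;
    }
    // Whatever is left above 1 is a prime larger than all factors found so far.
    if n > 1 {
        largest = n;
    }
    largest
}
```

Because the remaining n shrinks with every factor removed, the loop bound shrinks with it, and no vector of candidates is needed at all.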
Project Euler is nice to sharpen your math skills, but it doesn't help you with any programming language in particular. For learning Rust I started with Exercism.
Performing some optimizations, such as going just up to the square root of the number when checking for its factors and for whether it's a prime, I've got:
fn is_prime(n: i64) -> bool {
let float_input = n as f64;
let upper_bound = float_input.sqrt() as i64;
for x in 2..upper_bound + 1 {
if n % x == 0 {
return false;
}
}
return true;
}
fn get_factors(n: i64) -> Vec<i64> {
let mut factors: Vec<i64> = Vec::new();
let float_input = n as f64;
let upper_bound = float_input.sqrt() as i64;
for x in 1..upper_bound + 1 {
if n % x == 0 {
factors.push(x);
factors.push(n / x);
}
}
factors
}
fn get_prime_factors(n: i64) -> Vec<i64> {
get_factors(n)
.into_iter()
.filter(|&x| is_prime(x))
.collect::<Vec<i64>>()
}
fn main() {
if let Some(max) = get_prime_factors(600851475143).iter().max() {
println!("{:?}", max);
}
}
On my machine, this code runs very fast with no overflow.
./problem003 0.03s user 0.00s system 90% cpu 0.037 total
If you really don't want the iterative version:
First, make sure that you compile with optimizations (rustc -O or cargo build --release). Without them there's no chance of TCO in Rust. Your divisors function is tail-recursive, but it seems that moving this Vec up and down the recursion stack is confusing enough for LLVM to miss that fact. We can help the compiler a little by passing just a reference here:
fn divisors(i: u64, n: u64, mut div: Vec<u64>) -> Vec<u64> {
divisors_(i, n, &mut div);
div
}
fn divisors_(i: u64, n: u64, div: &mut Vec<u64>) {
if i * i > n {
} else {
if n % i == 0 {
div.push(i);
div.push(n / i);
}
divisors_(i + 2, n, div)
}
}
On my machine, that change makes the code no longer segfault.
If you want to increase the stack size anyway, you should run your function in a separate thread with increased stack size (using std::thread::Builder::stack_size)
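A minimal sketch of the separate-thread approach, assuming a 256 MiB stack is enough (the size is a guess, tune it for your input). It reuses the reference-passing divisors_ helper from this answer; divisors_on_big_stack is a hypothetical wrapper name:

```rust
use std::thread;

// Same recursive helper as in the answer, with the empty branch folded away.
fn divisors_(i: u64, n: u64, div: &mut Vec<u64>) {
    if i * i <= n {
        if n % i == 0 {
            div.push(i);
            div.push(n / i);
        }
        divisors_(i + 2, n, div)
    }
}

/// Hypothetical wrapper: runs the recursion on a thread with a larger stack.
fn divisors_on_big_stack(n: u64) -> Vec<u64> {
    let handle = thread::Builder::new()
        .stack_size(256 * 1024 * 1024) // assumed size, not a measured requirement
        .spawn(move || {
            let mut div = Vec::new();
            divisors_(1, n, &mut div);
            div
        })
        .expect("failed to spawn thread");
    handle.join().expect("worker thread panicked")
}
```

The recursion depth is unchanged; the thread simply has more room for it, so this works even when the compiler does not turn the tail call into a loop.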
Rust has reserved the become keyword for guaranteed tail recursion,
so maybe in the future you'll just need to add one keyword to your code to make it work.
I have several points of confusion about tail recursion:
First, some recursive functions are void functions, for example:
// Prints the given number of stars on the console.
// Assumes n >= 1.
void printStars(int n) {
if (n == 1) {
// n == 1, base case
cout << "*";
} else {
// n > 1, recursive case
cout << "*"; // print one star myself
printStars(n - 1); // recursion to do the rest
}
}
and another example:
// Prints the given integer's binary representation.
// Precondition: n >= 0
void printBinary(int n) {
if (n < 2) {
// base case; same as base 10
cout << n;
} else {
// recursive case; break number apart
printBinary(n / 2);
printBinary(n % 2);
}
}
As we know, by definition a tail-recursive function should return the value of its tail call. But void functions do not return any value. By instinct I think they are tail recursive, but I am not confident about it.
Another question: if a recursive function has several logical ends, must the tail call appear at all of them, or is one enough? I saw someone argue that one logical end is enough, but I am not sure about that. Here's my example:
// Returns base ^ exp.
// Precondition: exp >= 0
int power(int base, int exp) {
if (exp < 0) {
throw "illegal negative exponent";
} else if (exp == 0) {
// base case; any number to 0th power is 1
return 1;
} else if (exp % 2 == 0) {
// recursive case 1: x^y = (x^2)^(y/2)
return power(base * base, exp / 2);
} else {
// recursive case 2: x^y = x * x^(y-1)
return base * power(base, exp - 1);
}
}
Here one logical end is a tail call and another is not. Do you think this function is tail recursive or not? Why?
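For comparison, here is a sketch of how power could be written so that every recursive call is in tail position, using an accumulator. It is in Rust rather than C++, power_acc is a made-up helper name, and note that Rust itself does not guarantee tail-call elimination:

```rust
/// Accumulator version: the pending multiplication is carried in acc,
/// so nothing is left to do after either recursive call returns.
fn power_acc(base: i64, exp: u32, acc: i64) -> i64 {
    if exp == 0 {
        acc
    } else if exp % 2 == 0 {
        // x^y = (x^2)^(y/2); this was already a tail call in the original
        power_acc(base * base, exp / 2, acc)
    } else {
        // x^y = x * x^(y-1); the multiplication moves into the accumulator
        power_acc(base, exp - 1, acc * base)
    }
}

/// Returns base ^ exp (assumes the result fits in i64).
fn power(base: i64, exp: u32) -> i64 {
    power_acc(base, exp, 1)
}
```

In the original, the odd case still has to multiply by base after the call returns, which is what keeps that branch from being a tail call.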
This is a bit more intricate than a simple left-recursion or tail-call recursion, so I'm wondering how I can eliminate this kind of recursion. I'm already keeping my own stack, as you can see below, so the function needs no params or return values. However, it's still calling itself up (or down) to a certain level, and I want to turn this into a loop, but I've been scratching my head over this for some time now.
Here's the simplified test case, replacing all "real logic" with printf("dostuff at level #n") messages. This is in Go, but the problem is applicable to most languages. Use of loops and gotos would be perfectly acceptable (but I played with this and it gets convoluted, out of hand and seemingly unworkable to begin with); however, additional helper functions should be avoided. I guess I should turn this into some kind of simple state machine, but... which? ;)
As for the practicality, this is to run at about 20 million times per second (stack depth can range from 1 through 25 max later on). This is a case where maintaining my own stack is bound to be more stable / faster than the function call stack. (There are no other function calls in this function, only calculations.) Also, no garbage generated = no garbage collected.
So here goes:
func testRecursion () {
var root *TMyTreeNode = makeSomeDeepTreeStructure()
// rl: current recursion level
// ml: max recursion level
var rl, ml = 0, root.MaxDepth
// node: "the stack"
var node = make([]*TMyTreeNode, ml + 1)
// the recursive and the non-recursive / iterative test functions:
var walkNodeRec, walkNodeIt func ();
walkNodeIt = func () {
log.Panicf("YOUR ITERATIVE / NON-RECURSIVE IDEAS HERE")
}
walkNodeRec = func () {
log.Printf("ENTER LEVEL %v", rl)
if (node[rl].Level == ml) || (node[rl].ChildNodes == nil) {
log.Printf("EXIT LEVEL %v", rl)
return
}
log.Printf("PRE-STUFF LEVEL %v", rl)
for i := 0; i < 3; i++ {
switch i {
case 0:
log.Printf("PRECASE %v.%v", rl, i)
node[rl + 1] = node[rl].ChildNodes[rl + i]; rl++; walkNodeRec(); rl--
log.Printf("POSTCASE %v.%v", rl, i)
case 1:
log.Printf("PRECASE %v.%v", rl, i)
node[rl + 1] = node[rl].ChildNodes[rl + i]; rl++; walkNodeRec(); rl--
log.Printf("POSTCASE %v.%v", rl, i)
case 2:
log.Printf("PRECASE %v.%v", rl, i)
node[rl + 1] = node[rl].ChildNodes[rl + i]; rl++; walkNodeRec(); rl--
log.Printf("POSTCASE %v.%v", rl, i)
}
}
}
// test recursion for reference:
if true {
rl, node[0] = 0, root
log.Printf("\n\n=========>RECURSIVE ML=%v:", ml)
walkNodeRec()
}
// test non-recursion, output should be identical
if true {
rl, node[0] = 0, root
log.Printf("\n\n=========>ITERATIVE ML=%v:", ml)
walkNodeIt()
}
}
UPDATE -- after some discussion here, and further thinking:
I just made up the following pseudo-code which in theory should do what I need:
curLevel = 0
for {
cn = nextsibling(curLevel, coords)
lastnode[curlevel] = cn
if cn < 8 {
if isleaf {
process()
} else {
curLevel++
}
} else if curLevel == 0 {
break
} else {
curLevel--
}
}
Of course the tricky part will be filling out nextsibling() for my custom use-case. But just as a general solution to eliminating inner recursion while maintaining the depth-first traversal order I need, this rough outline should do so in some form or another.
I'm not really sure I understand what it is you want to do since your recursion code looks a little strange. However if I understand the structure of your TMyTreeNode then this is what I would do for a non recursive version.
// root is our root node
q := []*TMyTreeNode{root}
processed := make(map[*TMyTreeNode]bool)
for {
l := len(q)
if l < 1 {
break // our queue is empty
}
curr := q[l - 1]
if !processed[curr] && len(curr.childNodes) > 0 {
// do something with curr
processed[curr] = true
q = append(q, curr.childNodes...)
continue // continue on down the tree.
} else {
// do something with curr
processed[curr] = true
q = q[:l-1] // pop current off the queue
}
}
NOTE: This will go arbitrarily deep into the structure. If that's not what you want it will need some modifications.
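The explicit-stack idea can also keep the ENTER/EXIT ordering of the recursive version. Below is a minimal sketch in Rust, not Go; Node, Step and walk_iterative are hypothetical names, and the real TMyTreeNode layout is assumed to be a node with a list of children:

```rust
// Hypothetical stand-in for the tree node; only a label and children.
struct Node {
    label: u32,
    children: Vec<Node>,
}

// Work items: either visit a node for the first time or finish it.
enum Step<'a> {
    Enter(&'a Node),
    Exit(&'a Node),
}

/// Depth-first walk without recursion; returns the log lines in visit order.
fn walk_iterative(root: &Node) -> Vec<String> {
    let mut log = Vec::new();
    let mut stack = vec![Step::Enter(root)];
    while let Some(step) = stack.pop() {
        match step {
            Step::Enter(node) => {
                log.push(format!("enter {}", node.label));
                // Schedule the exit first so it runs after all children.
                stack.push(Step::Exit(node));
                // Push children in reverse so the first child is popped first.
                for child in node.children.iter().rev() {
                    stack.push(Step::Enter(child));
                }
            }
            Step::Exit(node) => log.push(format!("exit {}", node.label)),
        }
    }
    log
}
```

The Exit item plays the role of the code after the recursive call, so pre- and post-processing happen in the same order as in the recursive walk.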
Project Euler problem 14:
The following iterative sequence is defined for the set of positive integers:
n → n/2 (n is even)
n → 3n + 1 (n is odd)
Using the rule above and starting with 13, we generate the following sequence:
13 → 40 → 20 → 10 → 5 → 16 → 8 → 4 → 2 → 1
It can be seen that this sequence (starting at 13 and finishing at 1) contains 10 terms. Although it has not been proved yet (Collatz Problem), it is thought that all starting numbers finish at 1.
Which starting number, under one million, produces the longest chain?
My first instinct is to create a function to calculate the chains, and run it with every number between 1 and 1 million. Obviously, that takes a long time. Way longer than solving this should take, according to Project Euler's "About" page. I've found several problems on Project Euler that involve large groups of numbers that a program running for hours didn't finish. Clearly, I'm doing something wrong.
How can I handle large groups of numbers quickly?
What am I missing here?
Have a read about memoization. The key insight is that if you've already found that the sequence starting at A has length 1001, and a later sequence starting at B reaches A, you don't need to repeat all that work.
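A minimal sketch of that idea in Rust, assuming a flat array as the cache for starting values below the limit. The name longest_collatz is hypothetical, and intermediate values are assumed to fit in u64:

```rust
/// Returns (start, number_of_terms) for the longest Collatz chain with start < limit.
/// Assumes limit >= 2.
fn longest_collatz(limit: usize) -> (usize, u32) {
    // cache[n] = number of terms in the chain starting at n (0 = not computed yet)
    let mut cache = vec![0u32; limit];
    cache[1] = 1;
    let mut best = (1usize, 1u32);
    for start in 2..limit {
        let mut n = start as u64;
        let mut steps = 0u32;
        // Walk the chain until it reaches a value whose length is already cached.
        while n >= limit as u64 || cache[n as usize] == 0 {
            n = if n % 2 == 0 { n / 2 } else { 3 * n + 1 };
            steps += 1;
        }
        let len = steps + cache[n as usize];
        cache[start] = len;
        if len > best.1 {
            best = (start, len);
        }
    }
    best
}
```

Because starts are processed in increasing order, every chain stops as soon as it drops below its own starting value, which is where the saved work comes from.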
This is the code in Mathematica, using memoization and recursion. Just four lines :)
f[x_] := f[x] = If[x == 1, 1, 1 + f[If[EvenQ[x], x/2, (3 x + 1)]]];
Block[{$RecursionLimit = 1000, a = 0, j},
Do[If[a < f[i], a = f[i]; j = i], {i, Reverse[Range[10^6]]}];
Print[a]; Print[j];
]
Output .... chain length´525´ and the number is ... ohhhh ... font too small ! :)
BTW, here you can see a plot of the frequency for each chain length
Starting with 1,000,000, generate the chain. Keep track of each number that was generated in the chain, as you know for sure that their chains are shorter than the chain for the starting number. Once you reach 1, store the starting number along with its chain length. Take the next biggest number that has not been generated before, and repeat the process.
This will give you the list of numbers and chain length. Take the greatest chain length, and that's your answer.
I'll make some code to clarify.
public static long nextInChain(long n) {
if (n==1) return 1;
if (n%2==0) {
return n/2;
} else {
return (3 * n) + 1;
}
}
public static void main(String[] args) {
long iniTime=System.currentTimeMillis();
HashSet<Long> numbers=new HashSet<Long>();
HashMap<Long,Long> lenghts=new HashMap<Long, Long>();
long currentTry=1000000l;
int i=0;
do {
doTry(currentTry,numbers, lenghts);
currentTry=findNext(currentTry,numbers);
i++;
} while (currentTry!=0);
Set<Long> longs = lenghts.keySet();
long max=0;
long key=0;
for (Long aLong : longs) {
if (max < lenghts.get(aLong)) {
key = aLong;
max = lenghts.get(aLong);
}
}
System.out.println("number = " + key);
System.out.println("chain lenght = " + max);
System.out.println("Elapsed = " + ((System.currentTimeMillis()-iniTime)/1000));
}
private static long findNext(long currentTry, HashSet<Long> numbers) {
for(currentTry=currentTry-1;currentTry>=0;currentTry--) {
if (!numbers.contains(currentTry)) return currentTry;
}
return 0;
}
private static void doTry(Long tryNumber,HashSet<Long> numbers, HashMap<Long, Long> lenghts) {
long i=1;
long n=tryNumber;
do {
numbers.add(n);
n=nextInChain(n);
i++;
} while (n!=1);
lenghts.put(tryNumber,i);
}
Suppose you have a function CalcDistance(i) that calculates the "distance" to 1. For instance, CalcDistance(1) == 0 and CalcDistance(13) == 9. Here is a naive recursive implementation of this function (in C#):
public static int CalcDistance(long i)
{
if (i == 1)
return 0;
return (i % 2 == 0) ? CalcDistance(i / 2) + 1 : CalcDistance(3 * i + 1) + 1;
}
The problem is that this function has to calculate the distance of many numbers over and over again. You can make it a little bit smarter (and a lot faster) by giving it a memory. For instance, let's create a static array that can store the distance for the first million numbers:
static int[] list = new int[1000000];
We prefill each value in the list with -1 to indicate that the value for that position is not yet calculated. After this, we can optimize the CalcDistance() function:
public static int CalcDistance(long i)
{
if (i == 1)
return 0;
if (i >= 1000000)
return (i % 2 == 0) ? CalcDistance(i / 2) + 1 : CalcDistance(3 * i + 1) + 1;
if (list[i] == -1)
list[i] = (i % 2 == 0) ? CalcDistance(i / 2) + 1: CalcDistance(3 * i + 1) + 1;
return list[i];
}
If i >= 1000000, then we cannot use our list, so we must always calculate it. If i < 1000000, then we check if the value is in the list. If not, we calculate it first and store it in the list. Otherwise we just return the value from the list. With this code, it took about 120 ms to process all million numbers.
This is a very simple example of memoization. I use a simple list to store intermediate values in this example. You can use more advanced data structures like hashtables, vectors or graphs when appropriate.
Minimize how many levels deep your loops are, and use an efficient data structure such as IList or IDictionary, that can auto-resize itself when it needs to expand. If you use plain arrays they need to be copied to larger arrays as they expand - not nearly as efficient.
This variant doesn't use a HashMap; it only tries not to repeat the first 1000000 numbers. I don't use a hash map because the biggest number found is around 56 billion, and a hash map could crash.
I have already done some premature optimization. Instead of / I use >>, instead of % I use &. Instead of * I use some +.
void Main()
{
var elements = new bool[1000000];
int longestStart = -1;
int longestRun = -1;
long biggest = 0;
for (int i = elements.Length - 1; i >= 1; i--) {
if (elements[i]) {
continue;
}
elements[i] = true;
int currentStart = i;
int currentRun = 1;
long current = i;
while (current != 1) {
if (current > biggest) {
biggest = current;
}
if ((current & 1) == 0) {
current = current >> 1;
} else {
current = current + current + current + 1;
}
currentRun++;
if (current < elements.Length) {
elements[current] = true;
}
}
if (currentRun > longestRun) {
longestStart = i;
longestRun = currentRun;
}
}
Console.WriteLine("Longest Start: {0}, Run {1}", longestStart, longestRun);
Console.WriteLine("Biggest number: {0}", biggest);
}