When is tail recursion guaranteed in Rust? - recursion

C language
In the C programming language, it's easy to have tail recursion:
int foo(...) {
return foo(...);
}
Just return the value of the recursive call as is. This matters most when the recursion may repeat a thousand or even a million times: without tail-call elimination, every call would take more memory on the stack.
Rust
Now, I have a Rust function that might recursively call itself a million times:
fn read_all(input: &mut dyn std::io::Read) -> std::io::Result<()> {
    match input.read(&mut [0u8]) {
        Ok(0) => Ok(()),
        Ok(_) => read_all(input),
        Err(err) => Err(err),
    }
}
(this is a minimal example, the real one is more complex, but it captures the main idea)
Here, the return value of the recursive call is returned as is, but:
Does it guarantee that the Rust compiler will apply a tail recursion?
For instance, if we declare some variable that needs to be destroyed, like a Vec, will it be destroyed just before the recursive call (which allows for tail recursion) or after the recursive call returns (which forbids tail recursion)?

Shepmaster's answer explains that tail call elimination is merely an optimization, not a guarantee, in Rust. But "never guaranteed" doesn't mean "never happens". Let's take a look at what the compiler does with some real code.
Does it happen in this function?
As of right now, the latest release of Rust available on Compiler Explorer is 1.39, and it does not eliminate the tail call in read_all.
example::read_all:
push r15
push r14
push rbx
sub rsp, 32
mov r14, rdx
mov r15, rsi
mov rbx, rdi
mov byte ptr [rsp + 7], 0
lea rdi, [rsp + 8]
lea rdx, [rsp + 7]
mov ecx, 1
call qword ptr [r14 + 24]
cmp qword ptr [rsp + 8], 1
jne .LBB3_1
movups xmm0, xmmword ptr [rsp + 16]
movups xmmword ptr [rbx], xmm0
jmp .LBB3_3
.LBB3_1:
cmp qword ptr [rsp + 16], 0
je .LBB3_2
mov rdi, rbx
mov rsi, r15
mov rdx, r14
call qword ptr [rip + example::read_all@GOTPCREL]
jmp .LBB3_3
.LBB3_2:
mov byte ptr [rbx], 3
.LBB3_3:
mov rax, rbx
add rsp, 32
pop rbx
pop r14
pop r15
ret
mov rbx, rax
lea rdi, [rsp + 8]
call core::ptr::real_drop_in_place
mov rdi, rbx
call _Unwind_Resume@PLT
ud2
Notice this line: call qword ptr [rip + example::read_all@GOTPCREL]. That's the (tail) recursive call. As you can tell from its existence, it was not eliminated.
Compare this to an equivalent function with an explicit loop:
pub fn read_all(input: &mut dyn std::io::Read) -> std::io::Result<()> {
    loop {
        match input.read(&mut [0u8]) {
            Ok(0) => return Ok(()),
            Ok(_) => continue,
            Err(err) => return Err(err),
        }
    }
}
which has no tail call to eliminate, and therefore compiles to a function with only one call in it (to the computed address of input.read).
Oh well. Maybe Rust isn't as good as C. Or is it?
Does it happen in C?
Here's a tail-recursive function in C that performs a very similar task:
int read_all(FILE *input) {
char buf[] = {0, 0};
if (!fgets(buf, sizeof buf, input))
return feof(input);
return read_all(input);
}
This should be super easy for the compiler to eliminate. The recursive call is right at the bottom of the function and C doesn't have to worry about running destructors. But nevertheless, there's that recursive tail call, annoyingly not eliminated:
call read_all
Tail call optimization is not guaranteed to happen in C, either. No compiler I tried would be convinced to turn this into a loop on its own initiative.
Since version 13, clang supports a non-standard musttail attribute you can add to tail calls that should be eliminated. Adding this attribute to the C code successfully eliminates the tail call. However, rustc currently has no equivalent attribute (although the become keyword is reserved for this purpose).
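For comparison, the recursion in the C read_all above can be turned into a loop by hand, which is the transformation the optimizer was hoped to perform. This is only a sketch: the name read_all_loop is made up here, and it keeps the same one-character-at-a-time fgets calls as the recursive version.

```c
#include <stdio.h>

/* Hand-written loop form of the recursive read_all above.
   The name read_all_loop is hypothetical. */
int read_all_loop(FILE *input)
{
    char buf[2] = {0, 0};

    /* Each fgets call reads at most one character, like one level
       of the recursive version; the data itself is discarded. */
    while (fgets(buf, sizeof buf, input)) {
        /* nothing to do */
    }
    /* fgets failed: nonzero means clean end of file, 0 means an error. */
    return feof(input);
}
```

The stack usage of this version is constant no matter how long the input is, with or without optimization.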
Does it ever happen in Rust?
Okay, so it's not guaranteed. Can the compiler do it at all? Yes! Here's a function that computes Fibonacci numbers via a tail-recursive inner function:
pub fn fibonacci(n: u64) -> u64 {
    fn f(n: u64, a: u64, b: u64) -> u64 {
        match n {
            0 => a,
            _ => f(n - 1, a + b, a),
        }
    }
    // Start with a = 1, b = 0, matching the assembly below (rdx = 1, rcx = 0).
    f(n, 1, 0)
}
Not only is the tail call eliminated, the whole inner function f is inlined into fibonacci, yielding only 12 instructions (and not a call in sight):
example::fibonacci:
push 1
pop rdx
xor ecx, ecx
.LBB0_1:
mov rax, rdx
test rdi, rdi
je .LBB0_3
dec rdi
add rcx, rax
mov rdx, rcx
mov rcx, rax
jmp .LBB0_1
.LBB0_3:
ret
If you compare this to an equivalent while loop, the compiler generates almost the same assembly.
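To make the generated loop easier to read, here is a rough hand transliteration of the assembly listing into C. It is a sketch rather than compiler output: the name fibonacci_loop is made up, and the variable-to-register mapping in the comments is my reading of the listing.

```c
#include <stdint.h>

/* Hypothetical name; mirrors the registers in the listing above:
   a plays the role of rdx (starts at 1), b the role of rcx (starts at 0). */
uint64_t fibonacci_loop(uint64_t n)
{
    uint64_t a = 1, b = 0;

    while (n != 0) {            /* test rdi, rdi / je .LBB0_3 */
        uint64_t sum = a + b;   /* add rcx, rax               */
        b = a;                  /* mov rcx, rax               */
        a = sum;                /* mov rdx, rcx               */
        n--;                    /* dec rdi                    */
    }
    return a;                   /* mov rax, rdx / ret         */
}
```

With these starting values, fibonacci_loop(0) is 1 and fibonacci_loop(10) is 89.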
What's the point?
You probably shouldn't be relying on optimizations to eliminate tail calls, either in Rust or in C. It's nice when it happens, but if you need to be sure that a function compiles into a tight loop, the surest way, at least for now, is to use a loop.

Neither tail recursion (reusing a stack frame for a tail call to the same function) nor tail call optimization (reusing the stack frame for a tail call to any function) is ever guaranteed by Rust, although the optimizer may choose to perform them.
"if we declare some variable that needs to be destroyed"
It's my understanding that this is one of the sticking points, as changing the location of destroyed stack variables would be contentious.
See also:
Recursive function calculating factorials leads to stack overflow
RFC 81: guaranteed tail call elimination
RFC 1888: Proper tail calls

Related

Return a pointer at a specific position - Assembly

I am a beginner in Assembly and I have a simple question.
This is my code:
BITS 64 ; 64-bit mode
global strchr ; Export 'strchr'
SECTION .text ; Code section
strchr:
mov rcx, -1
.loop:
inc rcx
cmp byte [rdi+rcx], 0
je exit_null
cmp byte [rdi+rcx], sil
jne .loop
mov rax, [rdi+rcx]
ret
exit_null:
mov rax, 0
ret
This compiles but doesn't work. As you can see, I want to reproduce the strchr function. When I test my function with printf, it crashes (the problem isn't the test).
I know I can INC rdi directly to walk through the rdi argument and return it at the position I want.
But I just want to know whether there is a way to return rdi at position rcx, to fix my code and probably improve it.
Your function strchr seems to expect two parameters:
pointer to a string in RDI, and
pointer to a character in RSI.
Register rcx is used as an index inside the string? In that case you should use al instead of cl. Be aware that you don't limit the search size. When the character referred to by RSI is not found in the string, it will probably trigger an exception. Perhaps you should test al loaded from [rdi+rcx] and quit searching when al=0.
If you want it to return a pointer to the first occurrence of the character inside the string, just replace mov rax,[rdi+rcx] with lea rax,[rdi+rcx].
Your code (from edit Version 2) does the following:
char* strchr ( char *p, char x ) {
    int i = -1;
    do {
        if ( p[i] == '\0' ) return NULL;   /* first pass reads p[-1] */
        i++;
    } while ( p[i] != x );
    return (char*) * (long long*) &(p[i]); /* loads 8 bytes of string data, not the address */
}
As @vitsoft says, your intention is to return a pointer, but the first return (in the assembly) returns a single quadword loaded from the address of the found character: 8 characters instead of an address.
It is unusual to increment in the middle of the loop.  It is also odd to start the index at -1.  On the first iteration, the loop continue condition looks at p[-1], which is not a good idea, since that's not part of the string you're being asked to search.  If that byte happens to be the nul character, it'll stop the search right there.
If you waited to increment until both tests are performed, then you would not be referencing p[-1], and you could also start the index at 0, which would be more usual.
You might consider capturing the character into a register instead of using a complex addressing mode three times.
Further, you could advance the pointer in rdi and forgo the index variable altogether.
Here's that in C:
char* strchr ( char *p, char x ) {
    for(;;) {
        char c = *p;
        if ( c == '\0' )
            break;
        if ( c == x )
            return p;
        p++;
    }
    return NULL;
}
Thanks to your help, I finally did it!
Thanks to Erik's answer, I fixed a stupid mistake: I was comparing str[-1] to 0, which caused an error.
And following vitsoft's answer, I switched mov to lea and it worked!
Here is my code:
strchr:
mov rcx, -1
.loop:
inc rcx
cmp byte [rdi+rcx], 0
je exit_null
cmp byte [rdi+rcx], sil
jne .loop
lea rax, [rdi+rcx]
ret
exit_null:
mov rax, 0
ret
The only bug remaining in the current version is loading 8 bytes of char data as the return value instead of just doing pointer math, using mov instead of lea. (After various edits removed and added different bugs, as reflected in different answers talking about different code).
But this is over-complicated as well as inefficient (two loads, and indexed addressing modes, and of course extra instructions to set up RCX).
Just increment the pointer since that's what you want to return anyway.
If you're going to loop 1 byte at a time instead of using SSE2 to check 16 bytes at once, strchr can be as simple as:
;; BITS 64 is useless unless you're writing a kernel with a mix of 32 and 64-bit code
;; otherwise it only lets you shoot yourself in the foot by putting 64-bit machine code in a 32-bit object file by accident.
global mystrchr
mystrchr:
.loop: ; do {
movzx ecx, byte [rdi] ; c = *p;
cmp cl, sil ; if (c == needle) return p;
je .found
inc rdi ; p++
test cl, cl
jnz .loop ; }while(c != 0)
;; fell out of the loop on hitting the 0 terminator without finding a match
xor edi, edi ; p = NULL
; optionally an extra ret here, or just fall through
.found:
mov rax, rdi ; return p
ret
I checked for a match before end-of-string so I'd still have the un-incremented pointer, and not have to decrement it in the "found" return path. If I started the loop with inc, I could use an [rdi - 1] addressing mode, still avoiding a separate counter. That's why I switched up the order of which branch was at the bottom of the loop vs. your code in the question.
Since we want to compare the character twice, against SIL and against zero, I loaded it into a register. This might not run any faster on modern x86-64 which can run 2 loads per clock as well as 2 branches (as long as at most one of them is taken).
Some Intel CPUs can micro-fuse and macro-fuse cmp reg,mem / jcc into a single load+compare-and-branch uop for the front-end, at least when the memory addressing mode is simple, not indexed. But not cmp [mem], imm/jcc, so we're not costing any extra uops for the front-end on Intel CPUs by separately loading into a register. (With movzx to avoid a false dependency from writing a partial register like mov cl, [rdi])
Note that if your caller is also written in assembly, it's easy to return multiple values, e.g. a status and a pointer (in the not-found case, a pointer to the terminating 0 might be useful). Many C standard library string functions, notably strcpy, are badly designed in that they don't help the caller avoid redoing length-finding work.
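As a rough illustration of the "status plus pointer" idea, here is a C sketch of such an interface. The struct and function names (find_result, find_char) are invented for this example and are not part of any standard library.

```c
#include <stdbool.h>

/* Hypothetical result type: a status and a pointer returned together. */
struct find_result {
    bool found;       /* true if needle was found                      */
    const char *pos;  /* the match, or the terminating 0 if not found  */
};

/* Hypothetical strchr variant that never throws away the end pointer. */
struct find_result find_char(const char *p, char needle)
{
    for (;; p++) {
        if (*p == needle)
            return (struct find_result){ true, p };
        if (*p == '\0')
            return (struct find_result){ false, p };
    }
}
```

In the not-found case the caller gets the end of the string for free, so pos - start is the length and no separate strlen pass is needed. Under the x86-64 System V calling convention a small struct like this is returned in RAX and RDX, which is the same two-register return an assembly caller would use.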
Especially on modern CPUs with SIMD, explicit lengths are quite useful to have: a real-world strchr implementation would check alignment, or check that the given pointer isn't within 16 bytes of the end of a page. But memchr doesn't have to, if the size is >= 16: it could just do a movdqu load and pcmpeqb.
See Is it safe to read past the end of a buffer within the same page on x86 and x64? for details and a link to glibc strlen's hand-written asm. Also Find the first instance of a character using simd for real-world implementations like glibc's using pcmpeqb / pmovmskb. (And maybe pminub for the 0-terminator check to unroll over multiple vectors.)
SSE2 can go about 16x faster than the code in this answer for non-tiny strings. For very large strings, you might hit a memory bottleneck and "only" be about 8x faster.

Recursion in Assembly x86: Fibonacci

I am trying to code a recursive fibonacci sequence in assembly, but it is not working for some reason.
It does not give an error, but the output number is always wrong.
section .bss
_bss_start:
; store max iterations
max_iterations resb 4
_bss_end:
section .text
global _start
; initialise the function variables
_start:
mov dword [max_iterations], 11
mov edx, 0
push 0
push 1
jmp fib
fib:
; clear registers
mov eax, 0
mov ebx, 0
mov ecx, 0
; handle fibonacci math
pop ebx
pop eax
add ecx, eax
add ecx, ebx
push eax
push ebx
push ecx
; increment counter / exit conditions
inc edx
cmp edx, [max_iterations]
je print
; recursive step
call fib
ret
print:
mov eax, 1
pop ebx
int 0x80
For instance, the above code prints a value of 79 rather than 144 (the 12th Fibonacci number).
Alternatively, if I make
mov dword [max_iterations], 4
Then the above code prints 178 rather than 5 (5th fibonacci number).
Does anyone have an idea?
As an approach, you should try to debug it with the smallest possible input, like 1 iteration.  That will be most revealing, as you can watch it do the wrong thing in great detail without worrying about multiple levels of recursion.  When that works, go to 2 iterations.
When you use complex addressing modes, it is harder to debug as we cannot see what the processor is doing.  So, when an instruction using a complex addressing mode doesn't work, and you want to debug it, then split that instruction into 2 instructions as follows:
mov dword [fibonacci_seq + edx + 4], ecx
---
lea esi, [fibonacci_seq + edx + 4]
mov [esi], ecx
With the alternate code sequence, you can observe the value of the addressing mode computation, which will provide you with additional debugging insight.
As another example:
cmp edx, [max_iterations]
---
mov edi, [max_iterations]
cmp edx, edi
Using the 2 instruction version, you will be able to see what value the processor is comparing edx with.
Or better, do that mov load once before the loop, so you're keeping the loop bound in a register all the time. That's what you should normally do when you have enough registers, only using memory when you run out.
You are jmping to fib from one place in the code and calling it from another.  Though your logic should work because when you've reached the limit, you don't return to the main, this is really bad form: to mix main code with function.  More on that below...
mov dword [fibonacci_seq + edx + 4], ecx
Is this working for you? You're only incrementing edx by 1.  Perhaps you wanted:
mov dword [fibonacci_seq + edx * 4], ecx
I would argue that your code is not really recursive.
call fib ; jumps to fib, pushes a return address
ret ; never, ever reached, so, pointless
---
jmp fib ; iterate w/o pushing unwanted return address onto the stack
The 1-instruction jmp will be superior to the call as a mechanism to iterate, in part because it doesn't push an unnecessary return address onto the stack.
When you debug with 2 iterations, you'll probably see that the unused return address pushed by the call messes up your "parameter" passing, pops.
To expand on the "recursion", when the iteration stops and control transfers to print, there will be some 11 (depending on iteration count) unused return addresses on the stack (modulo the interference by the pop's and pushes).
The recursive call is only used for iteration, the recursion never unwinds.  Thus, I would argue it's not recursive (not even tail recursive) — it just erroneously pushes some unused return addresses onto the stack — that's not recursion.
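The effect of those stray return addresses can be reproduced with a small software model of the stack. The C sketch below replays the push/pop sequence of the stack-based fib routine from the question; simulate_fib and its parameters are invented names, and the return address 0x0804903B is an assumed value (only its low byte, 0x3B, affects the exit status).

```c
#include <stdint.h>

/* Models the question's fib routine on a software stack.
   ret_addr stands in for the address pushed by `call fib` (assumed value);
   use_call selects `call fib` (1) or `jmp fib` (0).
   Returns the exit status, i.e. the low 8 bits of ebx at `print`.
   Hypothetical helper: valid for 1 <= max_iterations <= 100. */
unsigned simulate_fib(unsigned max_iterations, uint32_t ret_addr, int use_call)
{
    uint32_t stack[256];
    unsigned sp = 0;
    uint32_t eax, ebx, ecx, edx = 0;

    if (max_iterations == 0 || max_iterations > 100)
        return 0;                   /* outside what this model supports */

    stack[sp++] = 0;                /* push 0 */
    stack[sp++] = 1;                /* push 1 */
    for (;;) {
        ebx = stack[--sp];          /* pop ebx */
        eax = stack[--sp];          /* pop eax */
        ecx = eax + ebx;            /* add ecx, eax / add ecx, ebx */
        stack[sp++] = eax;          /* push eax */
        stack[sp++] = ebx;          /* push ebx */
        stack[sp++] = ecx;          /* push ecx */
        if (++edx == max_iterations)
            break;                  /* je print */
        if (use_call)
            stack[sp++] = ret_addr; /* call fib pushes a return address */
    }
    ebx = stack[--sp];              /* print: pop ebx */
    return ebx & 0xff;              /* exit status keeps only 8 bits */
}
```

With use_call set to 0 the model returns 144 for 11 iterations and 5 for 4 iterations. With use_call set to 1 and the assumed return address, it returns 79 and 178, the same wrong values reported in the question: from the second iteration on, pop ebx picks up the return address instead of the previous Fibonacci number.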
This line is your main problem:
mov dword [fibonacci_seq + edx + 4], ecx
Because of the +4, you never write to the first entry of the "array". And because you only increment EDX by 1, each write to the array overwrites 3 bytes of the previous entry. Try this instead:
mov dword [fibonacci_seq + edx * 4], ecx
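The difference between the two addressing modes can be checked in C by doing the same byte-offset arithmetic on a plain buffer. This is a sketch with invented helper names (store_unscaled, store_scaled, load_entry); memcpy is used so the unaligned 4-byte stores are well defined.

```c
#include <stdint.h>
#include <string.h>

/* Like mov dword [base + edx + 4], ecx: offsets 4, 5, 6, ... */
void store_unscaled(uint8_t *base, const uint32_t *values, size_t n)
{
    for (size_t edx = 0; edx < n; edx++)
        memcpy(base + edx + 4, &values[edx], 4);
}

/* Like mov dword [base + edx * 4], ecx: offsets 0, 4, 8, ... */
void store_scaled(uint8_t *base, const uint32_t *values, size_t n)
{
    for (size_t edx = 0; edx < n; edx++)
        memcpy(base + edx * 4, &values[edx], 4);
}

/* Reads array entry i, i.e. the dword at byte offset i * 4. */
uint32_t load_entry(const uint8_t *base, size_t i)
{
    uint32_t v;
    memcpy(&v, base + i * 4, 4);
    return v;
}
```

After store_scaled, load_entry returns each value unchanged. After store_unscaled on a zeroed buffer, entry 0 is still 0 and entry 1 holds a mix of bytes from several values, because every store overlaps three bytes of the one before it.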
After a bit of redesign (I had not realised that the call instruction used the stack in this way), here is the solution:
section .bss
_bss_start:
; store max iterations and current iteration
max_iterations resb 4
iteration resb 4
; store arguments
n_0 resb 4
n_1 resb 4
_bss_end:
section .text
global _start
; initialise the function variables
_start:
mov dword [max_iterations], 11
mov dword [iteration], 0
mov dword [n_0], 0
mov dword [n_1], 1
jmp fib
fib:
mov ecx, 0
mov edx, 0
mov eax, [n_0]
mov ebx, [n_1]
add ecx, eax
add ecx, ebx
mov edx, [n_1]
mov dword [n_0], edx
mov dword [n_1], ecx
mov edx, [iteration]
inc edx
mov dword [iteration], edx
cmp edx, [max_iterations]
je print
call fib
ret
print:
mov eax, 1
mov ebx, [n_1]
int 0x80

How to print LOCAL byte with WinApi's WriteConsole

It's hard for me to clarify my question, but I'll try. I'm trying to learn MASM32 and I have a task to print some text to the console without using .data or .const. The problem is that LOCAL puts the variable on the stack, not in static memory, so I can't get its address with offset, and WriteConsole needs a pointer to the text in memory. Any thoughts on how to deal with this problem? Thanks!
I have this:
.data
string db 10 'somestring'
.code
WriteToConsole PROC
LOCAL handle :DWORD
invoke GetStdHandle, -11
mov handle, eax
mov edx, offset string
invoke WriteConsoleA, handle, edx, 10, 0, 0
xor eax, eax
ret
WriteToConsole ENDP
And I want something like that:
.code
WriteToConsole PROC
LOCAL string[10] :SBYTE
LOCAL handle :DWORD
invoke GetStdHandle, -11
mov handle, eax
mov edx, offset string ;impossible because of stack
invoke WriteConsoleA, handle, edx, 10, 0, 0 ;can't call without a pointer
xor eax, eax
ret
WriteToConsole ENDP
Well, I found an answer:
LOCAL string[10] :DWORD
lea edx, string
invoke WriteConsoleA, handle, edx, stringlength, 0, 0
Loading effective address instead of an offset helps!

Which value will be returned by calling func

I have some misunderstanding about the esp pointer.
Below is code that appeared in a previous exam.
The returned value is 1.
func: xor eax,eax
call L3
L1: call dword[esp]
inc eax
L2: ret
L3: call dword[esp]
L4: ret
Now, I will explain how I think and hope someone will correct me or approve.
This is how I reason once I already know the answer, so I'm not sure I'm thinking about it correctly at all.
eax = 0
We push to stack return address which is the next line, i.e label L1.
We jump to L3.
We push to stack return address which is the next line, i.e label L4.
We jump to L1.
We push to stack return address which is the next line, i.e inc eax.
We jump to L4.
We jump to line where is inc eax and stack is now empty.
eax = 1.
we end here(at label L2) and return 1.
I think eax = 2 and the caller of func is called once from the instruction at L1. In the following I will trace execution to show you what I mean.
I re-arranged your example a bit to make it more readable. This is NASM source. I think this should be equivalent to the original (assuming the D bit and B bit are set, i.e. you're running in normal 32-bit mode).
bits 32
func:
xor eax, eax
call L3
.returned:
L1:
call near [esp]
.returned:
inc eax
L2:
retn
L3:
call near [esp]
.returned:
L4:
retn
Now assume we start with some function that does this:
foo:
call func
.returned:
X
retn
This is what happens:
1. At foo we call func. The address of foo.returned is put on the stack (say, stack slot -1).
2. At func we set eax to zero.
3. Next we call L3. The address of func.returned = L1 is put on the stack (slot -2).
4. At L3 we call the dword on top of the stack. Here this is func.returned. The address of L3.returned = L4 is put on the stack (slot -3).
5. At L1 we call the dword on top of the stack. Here this is L3.returned. The address of L1.returned is put on the stack (slot -4).
6. At L4 we return. This pops L1.returned (from slot -4) into eip.
7. At L1.returned we do inc eax, setting eax to 1.
8. Then at L2 we return. This pops L3.returned (from slot -3) into eip.
9. At L4 we return. This pops func.returned (from slot -2) into eip.
10. At L1 we call the dword on top of the stack. Here this is foo.returned. The address of L1.returned is put on the stack (slot -2).
11. At foo.returned we execute whatever is there which I marked X. Assuming that the function returns using a retn eventually then...
12. ... we return. This pops L1.returned (from slot -2) into eip.
13. At L1.returned we do inc eax. Assuming that X did not alter eax then we now have eax = 2.
14. Then at L2 we return. This pops foo.returned (from slot -1) into eip.
If my assumptions are correct then eax is 2 in the end.
Note that it is really strange to call the return address on top of the stack. I cannot imagine a practical use for this.
Also note that if, in a debugger, you proceed-step the call to func in foo then at step 11 the debugger might return control to the user, with eax equal to 1. However, the stack is not balanced at this point.
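The trace above can also be replayed mechanically. The following C sketch models only the control flow: labels are enum values, call pushes a label, retn pops one, and call near [esp] jumps to the label currently on top of the stack. All names (run_func, struct trace, the label enum) are invented for this illustration, and X in foo is assumed to leave eax and the stack untouched.

```c
/* Labels from the NASM listing; L1_RETURNED is the inc eax after L1's call. */
enum label { FUNC, L1, L1_RETURNED, L2, L3, L4, FOO_RETURNED };

struct trace {
    int eax_at_first_return; /* eax when foo.returned is first reached */
    int eax_final;           /* eax when foo finally returns to its caller */
};

struct trace run_func(void)
{
    enum label stack[8];
    int sp = 0, eax = -1, visits = 0;
    struct trace t = { -1, -1 };
    enum label pc = FUNC;

    stack[sp++] = FOO_RETURNED;            /* foo: call func */
    for (;;) {
        switch (pc) {
        case FUNC:                         /* xor eax, eax / call L3 */
            eax = 0;
            stack[sp++] = L1;
            pc = L3;
            break;
        case L3: {                         /* call near [esp] */
            enum label target = stack[sp - 1];
            stack[sp++] = L4;
            pc = target;
            break;
        }
        case L1: {                         /* call near [esp] */
            enum label target = stack[sp - 1];
            stack[sp++] = L1_RETURNED;
            pc = target;
            break;
        }
        case L1_RETURNED:                  /* inc eax, then fall into L2 */
            eax++;
            pc = L2;
            break;
        case L2:                           /* retn */
        case L4:                           /* retn */
            pc = stack[--sp];
            break;
        case FOO_RETURNED:                 /* X, then foo's own retn */
            if (visits++ == 0)
                t.eax_at_first_return = eax;
            if (sp == 0) {                 /* nothing of ours left: foo is done */
                t.eax_final = eax;
                return t;
            }
            pc = stack[--sp];
            break;
        }
    }
}
```

Under these assumptions the model reaches foo.returned twice: the first time with eax equal to 1 and one stray return address still on the stack, the second time with eax equal to 2 and the stack balanced.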

Issue with strchr() function implementation

I've recently started looking into assembly code and I'm trying to recode some basic system functions to get a grip on it, I'm currently stuck on a segmentation fault at 0x0 on my strchr.
section .text
global strchr
strchr:
xor rax, rax
loop:
cmp BYTE [rdi + rax], 0
jz end
cmp sil, 0
jz end
cmp BYTE [rdi + rax], sil
jz good
inc rax
jmp loop
good:
mov rax, [rdi + rcx]
ret
end:
mov rax, 0
ret
I can't figure out how to debug it using GDB, and the documentation I've come across is pretty limited or hard to understand.
I'm using the following main in C to test
#include <stdio.h>

extern char *strchr(const char *s, int c);

int main () {
    const char str[] = "random.string";
    const char ch = '.';
    char *ret;
    ret = strchr(str, ch);
    printf("%s\n", ret);
    printf("String after |%c| is - |%s|\n", ch, ret);
    return(0);
}
The Problem
The instruction immediately following the good label:
mov rax, [rdi + rcx]
should actually be:
lea rax, [rdi + rax]
The index is in rax, not rcx (rcx is never set), and what you need is the address of that position, not the value stored there (i.e. lea instead of mov).
Some Advice
Note that the typical idiom for comparing sil against zero is actually test sil, sil instead of cmp sil, 0. It would be then:
test sil, sil
jz end
However, if we look at the strchr(3) man page, we can find the following:
char *strchr(const char *s, int c);
The terminating null byte is considered part of the string, so that if c is specified as '\0', these functions return a pointer to the terminator.
So, if we want this strchr() implementation to behave as described in the man page, the following code must be removed:
cmp sil, 0
jz end
The typical zeroing idiom for the rax register is neither mov rax, 0 nor xor rax, rax, but rather xor eax, eax: it doesn't have to encode an immediate zero, and it is one byte shorter than xor rax, rax because it needs no REX prefix.
With the correction and the advice above, the code would look like the following:
section .text
global strchr
strchr:
xor eax, eax
loop:
; Is matched? (checked first, so that c == '\0' finds the terminator)
cmp BYTE [rdi + rax], sil
jz good
; Is end of string?
cmp BYTE [rdi + rax], 0
jz end
inc rax
jmp loop
good:
lea rax, [rdi + rax]
ret
end:
xor eax, eax
ret
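To check an assembly strchr against the man page wording, it can help to have a tiny C reference model to compare with. The sketch below is such a model (ref_strchr is a made-up name, not the libc implementation); note that it tests for a match before testing for the terminator, which is what makes the c == '\0' case return a pointer to the terminator rather than NULL.

```c
#include <stddef.h>

/* Hypothetical reference model of strchr, for comparison in tests. */
char *ref_strchr(const char *s, int c)
{
    for (;; s++) {
        if (*s == (char)c)
            return (char *)s;   /* also hit when c == '\0' at the terminator */
        if (*s == '\0')
            return NULL;        /* end of string, no match */
    }
}
```

For the string "random.string" this returns a pointer to index 6 for '.', a pointer to index 13 (the terminator) for '\0', and NULL for a character that does not occur.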
