Nested functions


A “recent” article about the struggle to get nested functions and lambdas into C got me curious about how they are implemented in Ada, where nested functions can be called like any other function as well as passed to functions expecting function pointers. Specifically, we'll look at how GNAT does it on x86-64.

The TL;DR is that nested functions are implemented as a kind of closure: the function expects a pointer to its enclosing function's stack frame in the r10 register. If it needs to reach variables multiple scopes out, the pointer to the next outer frame is stored at the start (offset 0) of the frame it is given, effectively creating a linked list of frames we can traverse. To let nested functions be passed as function pointers, we create a structure on the stack containing the actual function pointer and the necessary frame pointer, take a pointer to that structure, set its lowest bit, and pass that pointer as the function pointer. Any call through a function pointer now needs to check whether the lowest bit is set. If it is, the caller dereferences the pointer to find the frame pointer, puts that in r10, and then calls the actual function pointer.

Stacky closures

The first example we will look at is rather simple:

function Id (Num : Integer) return Integer is
  function Inner return Integer is (Num);
begin
  return Inner;
end Id;

Note that Inner on its own is implicitly a function call, unlike in C. If we wanted the address of the function, we'd write Inner'Access.

Let's look at the generated assembly for the above function. The code for Inner looks something like[1]:

id__inner:
        push    rbp
        mov     rbp, rsp
        mov     [rbp-8], r10    # save the frame
        mov     eax, [r10]      # return frame->Num
        pop     rbp
        ret     

The calling convention expects the return value as a 32-bit integer in the eax register, so the important line is mov eax, [r10]. Nested functions expect what we might call a frame pointer in r10, which points to some place in the enclosing function's stack frame. Further, Id.Inner expects the Num variable to be the first value stored at that frame pointer. The stack should therefore look something like this when we're executing Inner:

,---------------------,
| (Inner stack frame) |
|---------------------|
| ...                 |
| ...                 |
| ...                 |
|---------------------| <-- end of Id stack frame
| Num                 |
|---------------------| <-- r10 (the frame pointer)
| ...                 |

If we had more stack-allocated variables, they'd be just above Num, such that we could access them by a simple indexing operation.

In the generated assembly, there's also a line mov [rbp-8], r10. This just spills that frame pointer to the stack in case we need to modify the r10 register (such as if we need to call a nested function that's in a wider scope).

The generated assembly for the Id function itself is slightly more complicated. It needs to put the Num parameter on the stack (since it is passed in a register in this calling convention), then create the appropriate frame pointer to pass to Id.Inner before it can actually call that function.

id:
        push    rbp
        mov     rbp, rsp
        sub     rsp, 32           # allocate some stack space
        lea     rax, [rbp+16]
        mov     [rbp-8], rax      # save the frame base parent
        mov     [rbp-16], edi     # spill 'Num' to the stack
        lea     rax, [rbp-16]     # create the frame pointer
        mov     r10, rax
        call    id__inner
        leave
        ret

The parameter Num is passed in the edi register. Note that we store the address rbp+16 in a stack variable: this is the frame base parent (sometimes called a dynamic link) and is mostly useful for dynamically sized stack-allocated variables[2].

Following that, we spill Num onto the stack, and create a pointer to that stack data. If we had more stack variables, we'd put them next to Num here. This leaves us with a stack that looks something like this just before the call:

,---------------------, <-- rsp
| Num                 |
|---------------------| <-- r10  (frame pointer)
| (frame base parent) | -->------+
|=====================| <-- rbp  |
| prev. rbp           |          |
|---------------------|          |
| return pointer      |          |
|---------------------| <--------+
| ...                 |

With this setup, we have access to the outer variable Num from a nested function via the r10 register. This matches the setup that Inner expects above, so everything works out nicely.
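To make the convention concrete, here is a rough C model of the code above. It assumes we make the hidden r10 argument explicit: the part of Id's frame that nested functions can see becomes a struct (the hypothetical `struct id_frame`), and Inner receives a pointer to it as an ordinary first parameter. This is a sketch of the idea, not what GNAT actually emits.

```c
/* Hypothetical C model of GNAT's static link; not actual GNAT output. */

/* The part of Id's stack frame that nested functions can see.
   Num is the first (and only) field, so it sits at frame[0]. */
struct id_frame {
    int num;
};

/* Id.Inner: the `frame` parameter plays the role of r10. */
static int id_inner(const struct id_frame *frame) {
    return frame->num;                 /* mov eax, [r10] */
}

/* Id: spill Num into the frame, then pass the frame's address. */
static int id(int num) {
    struct id_frame frame = { num };   /* mov [rbp-16], edi */
    return id_inner(&frame);           /* lea rax, [rbp-16]; mov r10, rax */
}
```

Passing the frame as an explicit parameter is essentially what GNAT does with r10. The only difference is that GNAT uses a register outside the normal argument registers, so the visible parameters of Inner are unaffected.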

Linked lists

The next problem we face is that of accessing variables that are more than one scope out. Consider a slightly modified version of the above example.

function Id (Num : Integer) return Integer is
  function Inner return Integer is
    function Extra_Inner return Integer is (Num);
  begin
    return Extra_Inner;
  end Inner;
begin
  return Inner;
end Id;

Now, Inner itself contains a nested function Extra_Inner, which references Num two scopes up. As we saw, when we call a nested function, we build up a frame with the data and then pass a pointer to that frame to the function. Inner could copy Num onto its own stack frame and pass a pointer to that copy. However, if we change the example so that Extra_Inner modifies Num instead of just returning it, we'll see that it has access to the variable itself, so a copy won't do.
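The difference only shows up once a nested function writes to an outer variable. The following hypothetical C model assumes Num is a variable that Extra_Inner increments. With a copy, Id never sees the change. With a pointer back into Id's frame, it does. The `_copy` and `_link` names are made up for the comparison.

```c
/* Hypothetical model contrasting two ways Inner could hand Num to
   Extra_Inner. Only the second corresponds to what GNAT does. */

struct id_frame { int num; };

/* Option 1: Inner copies Num into its own frame. */
struct inner_frame_copy { int num; };

static void extra_inner_copy(struct inner_frame_copy *frame) {
    frame->num += 1;                 /* modifies Inner's copy only */
}

static void inner_copy(const struct id_frame *up) {
    struct inner_frame_copy self = { up->num };
    extra_inner_copy(&self);
}

static int id_copy(int num) {
    struct id_frame frame = { num };
    inner_copy(&frame);
    return frame.num;                /* still the original value */
}

/* Option 2: Inner's frame links back to Id's frame. */
struct inner_frame_link { struct id_frame *prev; };

static void extra_inner_link(struct inner_frame_link *frame) {
    frame->prev->num += 1;           /* writes to Id's own Num */
}

static void inner_link(struct id_frame *up) {
    struct inner_frame_link self = { up };
    extra_inner_link(&self);
}

static int id_link(int num) {
    struct id_frame frame = { num };
    inner_link(&frame);
    return frame.num;                /* sees the update */
}
```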

Instead of copying parameters onto its own stack, GNAT does something slightly more sophisticated. The idea is that we build up a linked list of stack frames, which we can then traverse the required number of times to find the necessary variable. Note that this linked list corresponds to the static lexical nesting of the functions, not the dynamic call chain. Even if we add recursion to Inner, that linked list stays the same length.

Effectively, we want something like this:

,---------------------------,
| (Extra_Inner stack frame) |
|===========================| <== Extra_Inner
| (stack data)              |
|---------------------------|
| previous frame pointer    | -->--------,
|---------------------------| <-- r10    |
| (stack data)              |            |
|===========================| <== Inner  |
| (stack data)              |            |
|---------------------------|            |
| Num                       |            |
|---------------------------| <----------'
| (stack data)              |
|===========================| <== Id
| ...                       |

Now, Extra_Inner can find that outer Num in two steps. First, it dereferences r10, which points into Inner's frame, to get the saved frame pointer for Id. Then it dereferences that to read Num. We get a linked list of frames:

r10 -> Inner frame -> Id frame
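The claim that the list follows lexical nesting rather than the call stack can also be sketched in C. In this hypothetical model, Inner recurses `depth` times (a parameter invented for the demo) before calling Extra_Inner. Every recursive call passes along the frame pointer it received, which is Id's frame, rather than its own. As a result, Extra_Inner always reaches Num with exactly two loads, however deep the recursion goes.

```c
/* Hypothetical model of a two-level static chain; not GNAT output. */

struct id_frame { int num; };

/* The frame Inner builds for Extra_Inner starts with the link to
   Id's frame (frame[0]). */
struct inner_frame {
    const struct id_frame *prev;
};

/* Extra_Inner: frame = frame[0]; return frame->Num */
static int extra_inner(const struct inner_frame *frame) {
    return frame->prev->num;
}

/* Inner, made recursive for the demo. `up` plays the role of the r10
   that Inner receives: it always points to Id's frame. */
static int inner(const struct id_frame *up, int depth) {
    if (depth > 0)
        return inner(up, depth - 1);   /* recursion passes Id's frame again */
    struct inner_frame self = { up };  /* link stored at offset 0 */
    return extra_inner(&self);
}

static int id(int num, int depth) {
    struct id_frame frame = { num };
    return inner(&frame, depth);
}
```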

If we look at the generated code for Extra_Inner, we'll see it in action:

id__inner__extra_inner:
        push    rbp
        mov     rbp, rsp
        mov     [rbp-8], r10    # save the frame
        mov     rax, [r10]      # frame = frame[0]
        mov     eax, [rax]      # return frame->Num
        pop     rbp
        ret

As before, we spill the frame to the stack, and then traverse it twice to get Num. The body of Inner requires a bit more care now, since we expect that previous frame pointer to be stored at the frame pointer we are given in r10. To illustrate, let's check out what happens to Extra_Inner if we rewrite Inner like so:

function Id (Num : Integer) return Integer is
  function Inner (Y : Integer) return Integer is
    function Extra_Inner return Integer is (Num + Y);
  begin
    return Extra_Inner;
  end Inner;
begin
  return Inner(10);
end Id;

Now, Extra_Inner references variables from both its immediate outer scope (the Y from Inner's frame) and the one outside that (the Num from Id's frame). The generated assembly is a bit heavier, though that is largely due to overflow checking. The relevant section looks like this[3]:

        mov     rdx, [r10]      # prev = frame[0]
        mov     edx, [rdx]      # num = prev[0]
        mov     eax, [r10+8]    # y = frame[8]
        add     eax, edx
        leave
        ret

As mentioned, Inner needs to do a bit more work to call Extra_Inner, since it needs to save the previous frame like a stack local, but it is largely similar to before:

id__inner:
        push    rbp
        mov     rbp, rsp
        sub     rsp, 48         # allocate some stack space
        mov     [rbp-48], r10   # save the frame
        lea     rax, [rbp+16]
        mov     [rbp-16], rax   # save the frame base parent
        mov     [rbp-24], edi   # frame[8] = y
        mov     [rbp-32], r10   # frame[0] = prev
        lea     rax, [rbp-32]   # create a new frame pointer
        mov     r10, rax
        call    id__inner__extra_inner
        leave
        ret

At this point, it might be obvious why we need to save the frame with mov [rbp-48], r10: we overwrite r10 before passing the new frame to Extra_Inner. Also note that the code for the version of Inner without the Y parameter is largely the same. It just doesn't put anything but the previous frame pointer (and the frame base parent) on the stack.

Just before the call, the stack will look something like this:

,--------------------,
| Y                  |
|--------------------|
| prev frame pointer | -->--...
|--------------------| <-- r10
| frame base parent  | -->--...
|====================| <== Inner
| ...                |

Which is what Extra_Inner expects.

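Note that the stack grows downwards on x86, so this picture is only partly faithful to the real addresses. The diagram draws the data Extra_Inner sees with addresses increasing upwards from r10, and that part is accurate: the previous frame pointer is at r10 (rbp-32) and Y is at r10+8 (rbp-24). However, both of them actually sit at lower addresses than the frame base parent (rbp-16), not above it as drawn.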
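The frame Inner builds for Extra_Inner can be mirrored by a C struct whose field offsets match the loads above: the link to Id's frame at frame[0] and Y at frame[8]. This is a hypothetical model that assumes an LP64 target such as x86-64 with 8-byte pointers. The `_Static_assert` checks that Y lands right after the link, as the assembly expects.

```c
#include <stddef.h>

/* Hypothetical model of the frames in the Num + Y example. */

struct id_frame { int num; };

struct inner_frame {
    const struct id_frame *prev;   /* frame[0]: mov [rbp-32], r10 */
    int y;                         /* frame[8]: mov [rbp-24], edi */
};

/* On LP64, Y sits at offset 8, right after the 8-byte link. */
_Static_assert(offsetof(struct inner_frame, y) == sizeof(void *),
               "Y should directly follow the static link");

/* Extra_Inner: one hop up the chain for Num, none for Y. */
static int extra_inner(const struct inner_frame *frame) {
    int num = frame->prev->num;    /* mov rdx, [r10]; mov edx, [rdx] */
    int y = frame->y;              /* mov eax, [r10+8] */
    return num + y;                /* Ada's overflow check omitted */
}

static int inner(const struct id_frame *up, int y) {
    struct inner_frame self = { up, y };
    return extra_inner(&self);
}

static int id(int num) {
    struct id_frame frame = { num };
    return inner(&frame, 10);
}
```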

Function descriptors

The last piece of the puzzle is how we handle pointers to nested functions. There are many approaches one could take here (some worse than others), including Rust-style monomorphisation. GNAT takes an approach which lets nested functions (and closures in general) be represented as single pointers.

This is possible for one important reason: function entry points are aligned (on x86-64, typically to 16-byte boundaries). This means that the lowest bits of a pointer to a function are always unset. GNAT exploits this by using the lowest bit to mark function descriptors. If it is set, then the pointer isn't actually a function pointer but a pointer to a nested function's descriptor, which needs special treatment. If it is unset, it is a bog-standard function pointer for which we can do a simple call.

The main idea is that we create a structure on the stack which contains the actual function pointer as well as the frame pointer it expects. This is pretty similar to a typical closure structure, with the frame pointer being the data that's closed over. As such, function descriptors are a technique that can be used to compile arbitrary closures into a single function pointer[4].

Let's look at an example:

function Id (Num : Integer) return Integer is
  type F is not null access function return Integer;
  function Inner return Integer is (Num);
  function Run (Func : F) return Integer is (Func.all);
begin
  return Run (Inner'Access);
end Id;

not null access function return Integer is Ada's verbose way of saying “non-null pointer to a function returning an integer”, Func.all dereferences the pointer (and implicitly calls the function, since it takes no arguments), and Inner'Access gets the pointer to the Inner function.

Inner is compiled just like before:

id__inner:
        push    rbp
        mov     rbp, rsp
        mov     [rbp-8], r10    # save the frame
        mov     eax, [r10]      # return frame->Num
        pop     rbp
        ret

Run is a bit more involved. One thing to notice here is that although there is some special logic involved in calling the function pointer, this is not because of the particulars of the example program but because any function pointer could potentially be a pointer to a nested function.

id__run:
        push    rbp
        mov     rbp, rsp
        sub     rsp, 16         # allocate some stack space
        mov     [rbp-8], rdi    # spill Func to the stack
        mov     [rbp-16], r10   # save the frame given to Run
        mov     rax, rdi
        mov     rdx, rax
        and     edx, 1
        test    rdx, rdx        # is the lowest bit set?
        je      .L7             # jump to .L7 if not
        mov     r10, [rax-1]    # frame = Func[-1]
        mov     rax, [rax+7]    # Func  = Func[7]
.L7:
        call    rax             # call Func
        leave
        ret

In Run, we check whether the lowest bit is set, and if so, call the nested function appropriately. In that case, Run expects Func to point to a structure that looks like this:

,------------------, <-- Func
| frame pointer    |
|------------------|
| function pointer |
'------------------'

This lets us access the frame pointer as Func[0] and the function pointer as Func[8]. However, since the lowest bit is set, the pointer we are given is actually Func + 1, which means the frame pointer is at offset -1 and the function pointer at offset 7.
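The offset arithmetic can be checked with a small C sketch. It assumes an LP64 target, where the hypothetical `struct descriptor` has the frame pointer at offset 0 and the function pointer at offset 8. The helpers read raw bytes through the tagged pointer at -1 and +7, mirroring Run's two loads, rather than going through the struct fields. `read_num` is an invented stand-in for a nested function body.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

/* Hypothetical model: every callable takes its frame pointer explicitly. */
typedef int (*code_fn)(void *frame);

/* Assumed descriptor layout: frame pointer at offset 0,
   code pointer at offset 8 on LP64 targets. */
struct descriptor {
    void *frame;
    code_fn code;
};

/* mov r10, [rax-1]: subtracting 1 cancels the tag bit. */
static void *frame_from_tagged(uintptr_t tagged) {
    void *frame;
    assert((tagged & 1) == 1);
    memcpy(&frame, (const void *)(tagged - 1), sizeof frame);
    return frame;
}

/* mov rax, [rax+7]: offset 8 of the descriptor, minus the tag bit. */
static code_fn code_from_tagged(uintptr_t tagged) {
    code_fn code;
    assert((tagged & 1) == 1);
    memcpy(&code,
           (const void *)(tagged + (offsetof(struct descriptor, code) - 1)),
           sizeof code);
    return code;
}

/* Example nested-function body: reads an int at frame[0], like Num. */
static int read_num(void *frame) {
    return *(const int *)frame;
}
```

A usage sketch: build `struct descriptor d = { &num, read_num };`, tag it with `(uintptr_t)&d + 1`, and the two helpers recover `&num` and `read_num`.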

Finally, Id looks like this:

id:
        push    rbp
        mov     rbp, rsp
        sub     rsp, 48             # allocate some stack space
        lea     rax, [rbp+16]
        mov     [rbp-8], rax        # save the frame base parent
        mov     [rbp-32], edi       # spill 'Num' to the stack

        # create the frame pointer 'Inner' expects
        lea     rax, [rbp-32]
        mov     rdx, rax
        add     rax, 8

        # create the nested function structure
        mov     [rax], rdx          # store the frame pointer
        mov     [rax+8], id__inner  # store the function pointer

        # create the pointer to that structure
        lea     rax, [rbp-32]
        add     rax, 8              # point in-between frame & function
        add     rax, 1              # set the lowest bit

        # call 'Run'
        mov     r10, rdx
        mov     rdi, rax
        call    id__run
        leave
        ret

Things are a bit more convoluted, but the gist is that we store Num on the stack at rbp-32, to which the frame pointer points. Additionally, we create a closure-like structure containing the function pointer (at rbp-16) and the frame pointer (at rbp-24). Finally, the “function pointer” we pass to Run actually points to that structure, so we set its lowest bit such that we don't accidentally try to call the stack.

That gives us this stack layout just before the call:

,------------------,
| function pointer |
|------------------|
| frame pointer    | -->------+
|------------------| <-- rdi  |
| Num              |          |
|------------------| <-- r10 -+
| ...              |

Of course, rdi is off by its lowest bit, and there's a frame base parent and other things that complicate this picture, but this is the general idea.

Both r10 and the frame pointer Inner expects point to Num. This is because Run is also a nested function which happens to require the same frame pointer as Inner, since they're declared in the same scope.
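Putting the pieces together, here is an end-to-end C model of this example. It is a sketch under several assumptions:

- Every function takes its static link as an explicit first parameter instead of in r10.
- `struct descriptor` and the other names are invented.
- The GCC/Clang `aligned` function attribute stands in for the alignment GNAT relies on, so ordinary function pointers never have their lowest bit set.

Converting between function pointers and `uintptr_t` is implementation-defined in C, but it works with GCC and Clang on x86-64.

```c
#include <assert.h>
#include <stdint.h>

/* Every target of a "function pointer" takes the static link
   explicitly; plain (non-nested) functions simply ignore it. */
typedef int (*code_fn)(void *frame);

/* Closure-like record built on Id's stack: frame at [0], code at [8]. */
struct descriptor {
    void *frame;
    code_fn code;
};

struct id_frame { int num; };

/* Id.Inner: reads Num through its static link. */
__attribute__((aligned(16)))
static int id_inner(void *frame) {
    return ((const struct id_frame *)frame)->num;
}

/* A non-nested function, callable through the same pointer type. */
__attribute__((aligned(16)))
static int forty_two(void *frame) {
    (void)frame;
    return 42;
}

/* Id.Run: calls Func, checking the lowest bit first.
   `run_frame` is the static link Run itself received (r10). */
static int id_run(void *run_frame, uintptr_t func) {
    void *frame = run_frame;           /* plain pointers: r10 left as is */
    code_fn code;
    if (func & 1) {                    /* and edx, 1; test; je .L7 */
        const struct descriptor *d = (const struct descriptor *)(func - 1);
        frame = d->frame;              /* mov r10, [rax-1] */
        code = d->code;                /* mov rax, [rax+7] */
    } else {
        code = (code_fn)func;
    }
    return code(frame);                /* call rax */
}

/* Id: builds the descriptor on its stack and passes a tagged pointer. */
static int id(int num) {
    struct id_frame frame = { num };
    struct descriptor desc = { &frame, id_inner };
    uintptr_t base = (uintptr_t)&desc;
    assert((base & 1) == 0);           /* descriptors are pointer-aligned */
    return id_run(&frame, base + 1);   /* set the lowest bit */
}
```

The same `id_run` handles both cases: `id_run(&frame, base + 1)` takes the descriptor path, while passing `(uintptr_t)forty_two` takes the plain-call path. This mirrors the point above that Run's extra logic exists because any function pointer might be a nested one.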

Why does this work?

All of this relies on two things being true:

  1. Nested functions are inaccessible from the outside
  2. Pointers to nested functions cannot escape their scope

This is perhaps the major difference between nested functions and closures, with the former largely being a special case of the latter. As long as both the points above hold, however, a nested function doesn't actually need to capture the specific outer variables it uses, since it can access them by reference via the stack frames. And because nested functions are stuck within their scope, that protects against use-after-frees.

One could take the C approach to these problems and call it undefined behaviour if a nested function was used outside its scope, but GNAT is in fact able to prevent this from happening entirely at compile time[5].

GNAT, and Ada in general, is able to do this first of all by only allowing pointers to nested functions to be assigned to access types declared in the same scope or a nested one. This means that this is legal:

procedure Outer is
  type P is access procedure;
  procedure Inner is
  begin
    null;
  end Inner;

  Ptr : P := Inner'Access;
begin
  null;
end Outer;

while this is not:

type P is access procedure;
procedure Outer is
  procedure Inner is
  begin
    null;
  end Inner;

  Ptr : P := Inner'Access;
begin
  null;
end Outer;

Since all functions and such must have their types fully annotated, and a type declared within a function is inaccessible outside it, a pointer to a nested function cannot escape that way. It does also mean that a higher-order function declared outside a given scope cannot accept nested functions from within that scope if the access type it expects is a named one.

Things get a bit more hairy with anonymous access types. Functions taking such types can take nested functions from narrower scopes as arguments. In this case, GNAT simply disallows assigning such function pointers or converting them to a named access type, which prevents them from being stored.