Wednesday, March 5, 2008

A problem with monads?

It took me a while, but I've finally gotten to the point I wanted to make.

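Take a look at the bind operator of the state monad:
;; get-state and get-value are accessors (not shown) for the
;; state and value that a state computation returns.
(define (bind M f)
  (lambda (ma)                           ; ma: the incoming state
    (let* ((intermediate (M ma))         ; run M; NOT a tail call
           (i-state (get-state intermediate))
           (i-value (get-value intermediate))
           (next    (f i-value)))        ; f picks the next computation
      (next i-state))))                  ; tail call: run it on M's new state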

We would use this monad to chain together a bunch of procedures that will pass a `state variable' around. The procedures aren't actually `chained' though, they are `nested'. In order to extend the chain, we put a wrapper around the existing chain, catch its return value and feed the appropriate elements through to the next function in the chain. The structure is like a set of Russian dolls where the initial function is the smallest and the subsequent functions wrap around it.

When we run this monad, the first thing we have to do is invoke the next inner monad. Of course this monad has to invoke its next inner monad, etc. etc. until we get to the innermost one. The problem is this: each invocation is consuming resources that cannot be released until the inner invocations finish. So before we even start executing the first chained procedure we are allocating resources for every chained procedure that comes after it. A sequence of chained procedures will allocate the stack frame for the last procedure first, and this frame will be retained until the entire monadic computation is complete. If a sequence has, say, twenty chained procedures, we'll push twenty stack frames as the very first operation.

The state monad is not tail recursive.
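The effect can be measured with a small instrumented sketch. It uses OCaml, which is strict like Scheme. The names `tick`, `left_chain`, `right_chain`, `measure`, and the `depth` counters are hypothetical and exist only for illustration. `bind` follows the Scheme definition above, except that it also counts how many `bind` calls are still waiting for their inner computation to return.

```ocaml
(* A state computation maps a state to a (new state, value) pair. *)
type ('s, 'a) state = 's -> 's * 'a

(* Hypothetical instrumentation: how many binds are waiting right now,
   and the largest number seen during a run. *)
let depth = ref 0
let max_depth = ref 0

(* Same shape as the Scheme bind: run m, hand its value to f,
   then run the resulting computation on m's new state. *)
let bind (m : ('s, 'a) state) (f : 'a -> ('s, 'b) state) : ('s, 'b) state =
  fun s ->
    incr depth;
    if !depth > !max_depth then max_depth := !depth;
    let (s', v) = m s in   (* not a tail call: this frame waits for m *)
    decr depth;
    f v s'                 (* tail call: this frame is no longer needed *)

(* One primitive step: return the old state and increment it. *)
let tick : (int, int) state = fun s -> (s + 1, s)

(* Extend the chain by wrapping it, as in the post: Russian dolls. *)
let left_chain n =
  let rec go m i = if i = 0 then m else go (bind m (fun _ -> tick)) (i - 1) in
  go tick n

(* Put each new step inside the function passed to bind instead. *)
let rec right_chain n =
  if n = 0 then tick else bind tick (fun _ -> right_chain (n - 1))

(* Run a computation from state 0; report the final state and peak depth. *)
let measure comp =
  depth := 0;
  max_depth := 0;
  let (final_state, _) = comp 0 in
  (final_state, !max_depth)
```

With these definitions, `measure (left_chain 20)` evaluates to `(21, 20)`: a chain extended by wrapping the previous chain twenty times has twenty frames waiting at its deepest point. For contrast, `measure (right_chain 20)` evaluates to `(21, 1)`, because each new step sits inside the function passed to `bind`, and the waiting frame is released by the tail call before that step runs.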

I suppose this is a trivial point for some people, but I find this to be remarkable. The lack of tail recursion would place limits on the number of procedures that could be chained together in this way. It is unlikely a person would write a program that exceeds the limit and overflows the stack, but it is not at all impossible for a meta-computation to do so. (Imagine an x86 emulator that used the state monad and chained together the procedures that implemented the instruction set.) What I find remarkable is that there appears to be no underlying logical reason why this isn't tail recursive. The chain of procedures is obviously linear and the resources for the final element in the chain aren't really used until that element is computed, yet they are the first to be allocated.

I have some notions about why this isn't tail recursive, but I'm not enough of an expert in monads or category theory to be sure, yet. I'm hoping that someone who is more of an expert can validate or refute the problem. I think the answer is one of these:
  1. Yes, it is a problem and
    • tail-recursion is just incompatible with monads, or
    • monads are so cool that we're willing to live with the problem, or
    • only Haskell people use monads in real life, and in the theoretical realm you can have as much stack as you can imagine
  2. No, it isn't a problem because you can simply rewrite the state monad like this ...
  3. No, it isn't a problem, just wrap your code in the `tail recursion' monad and magic will happen.
  4. You're an idiot because this obvious thing you clearly didn't understand fixes it.
Any opinions or ideas?

8 comments:

Edward Kmett said...

It's a combination of 2 and 3. ;)

In short, State is fine if you're lazy. In a strict setting you want a CPS transformed version of it.

CPS transforming it yields the 'codensity monad' of the state monad in question which is more or less equivalent to forall r. ContT r (State s) a in Haskell.

The reason for using the codensity monad rather than ContT is to keep you from abusing its call/cc.

There is probably a far more idiomatic construction in lisp.
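A minimal OCaml sketch of the CPS-transformed state idea, under a representation chosen here for illustration: each computation takes the current state and a continuation that expects the value and the next state. The `'r.` annotation on the record field plays the role of Haskell's `forall r`. This is a plain CPS state monad rather than a literal encoding of `Codensity (State s)`, and `tick`, `run_state`, and `left_chain` are hypothetical names.

```ocaml
(* A CPS state computation takes the current state and a continuation,
   and passes the continuation its value and the new state. The answer
   type 'r is universally quantified, playing the role of forall r. *)
type ('s, 'a) cstate = { run : 'r. 's -> ('a -> 's -> 'r) -> 'r }

let return_ v = { run = fun s k -> k v s }

(* m.run and the continuation's .run are both tail calls, so no frame
   stays on the stack waiting for m to finish. *)
let bind m f = { run = fun s k -> m.run s (fun v s' -> (f v).run s' k) }

(* Hypothetical primitive: return the old state and increment it. *)
let tick = { run = fun s k -> k s (s + 1) }

(* Leave CPS by supplying a final continuation that builds the pair. *)
let run_state m s = m.run s (fun v s' -> (s', v))

(* Left-nested chain, extended by wrapping the existing chain. *)
let left_chain n =
  let rec go m i = if i = 0 then m else go (bind m (fun _ -> tick)) (i - 1) in
  go tick n
```

Running a long left-nested chain still allocates one continuation closure per `bind`, but those closures live on the heap instead of as pending stack frames, so the nesting depth no longer turns into stack depth.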

jrm said...

Hmmm. I understand the different words you are using, but the way you are combining them leaves me baffled. I'm a newbie when it comes to categories, so you'll have to help me out.

> In short, State is fine if you're lazy. In a strict setting you want a CPS transformed version of it.

I assume you mean lazy/strict in the technical sense. So you are saying that tail recursion doesn't matter when you are using a lazy language (because it isn't yet a well-defined concept), but in a strict language you'd perform a CPS transform to restore the tail recursion. Did I get this right?

The `codensity monad' stuff has me completely baffled, but I'm guessing it is something like this: CPS transformation is the programmer's version of Yoneda embedding. Since all control structure has a CPS equivalent, the State monad itself must have a CPS equivalent. When you use Yoneda embedding, you use the `codensity monad' of the thing you are embedding. In this case, you are embedding the State monad, so you get the codensity monad of the State monad. How does this sound?

> The reason for using the codensity monad rather than ContT is to keep you from abusing its call/cc.

This is because the codensity monad `encapsulates' the image of the State monad and hides the continuation. If you expose the continuation to the State monad, you could, for example, use it to backtrack instead of for the purpose of implementing State. Right?

Cale Gibbard said...

Just to let you know, you're using the word "monad" strangely in the latter part of this article. The word "monad" refers primarily to a type constructor, and secondarily to the implementations of return and bind (or fmap, return and join if you prefer) that come along with it, but not to the individual actions whose type is constructed using that monad. (Call them state computations or state actions in this case.) It's similar to how in mathematics, you wouldn't refer to the elements of a group as "groups", unless somehow they actually were.

In any event, basically all Haskell implementations use lazy evaluation, which means that they evaluate outermost reducible expressions first, substituting the parameters into the body. (If a parameter occurs more than once in the body of a function, some provisions are made to ensure that work done evaluating one of the copies will be shared with the others.)

In lazy evaluation, the "stack" takes on a different meaning: it's essentially the depth of the outermost reducible subexpression. From an alternate standpoint, that's the number of nested 'case' expressions which can't yet match one of their patterns. (By contrast, you can think of 'let' as putting new expressions on the heap without evaluating them.)

newtype State s a = S (s -> (s,a))

runState x s = case x of S f -> f s

return v = S (\s -> (s,v))

x >>= f = S (\s -> let (s',v) = runState x s in runState (f v) s')

get = S (\s -> (s,s))

put s' = S (\s -> (s',()))

The only case used in these definitions is the one in the definition of runState. *All* of the primitive computations, including bind, immediately produce something of the form (S f), so the scrutinee of that case expression will basically always be evaluated enough in just one step (or one, plus the number of steps required to reach a primitive definition, if the thing which runState is applied to is not built directly from these primitives).

Thus, when evaluating the runState of a chain of binds and primitive State computations, the stack will basically never grow larger than one element. On the other hand, if a State computation makes a lot of modifications to the state without inspecting it until the very end, and those modifications consist of strict functions, the expression representing the current state may grow to be very large and have a very deep outermost redex, which can cause a stack overflow. But by that point, the state monad machinery will have all been carried out.

One can help to prevent this by using the so-called strict state monad whose definition is the same as the above except that:

x >>= f = S (\s -> case runState x s of (s',v) -> runState (f v) s')

But even in this case, a long chain of binds will not cause a stack overflow. The stack depth here will reflect the *leftward* nesting depth of binds, but more commonly the next bind will occur inside the lambda on the right, so there's little to fear. The stack will grow in a similar (though not quite identical) way to standard imperative languages.

This discussion is all *prior* to optimisation as well.

Now, it may be the case that copying these definitions directly into a language with strict evaluation has problems... but, that's the case with many things. As Edward points out, there are other implementations of these primitives which would fare better under strict evaluation.

Edward Kmett said...

> I assume you mean lazy/strict in the technical sense. So you are saying that tail recursion doesn't matter when you are using a lazy language (because it isn't yet a well-defined concept), but in a strict language you'd perform a CPS transform to restore the tail recursion. Did I get this right?

In a call-by-need setting, as Cale pointed out, the stack works differently, so State is well behaved there. And yes, you got it right. =)

> The `codensity monad' stuff has me completely baffled, but I'm guessing it is something like this: CPS transformation is the programmer's version of Yoneda embedding. Since all control structure has a CPS equivalent, the State monad itself must have a CPS equivalent. When you use Yoneda embedding, you use the `codensity monad' of the thing you are embedding. In this case, you are embedding the State monad, so you get the codensity monad of the State monad. How does this sound?

Remarkably close and it is admittedly pretty obscure stuff. =) Actually the Yoneda embedding is just a particular right Kan extension, and so is the codensity monad of a given functor. In category-extras in Haskell I used to define both as specializations of a more general right Kan extension operator 'Ran'. Ran of a functor along itself yields (more or less) a CPS transform, with the aforementioned minor quibble about quantification. Ran of a functor along identity yields the Yoneda embedding. So you have the right intuition that they are similar. =)

Ran f f ~ Codensity f
Ran Identity f ~ Yoneda f

http://hackage.haskell.org/packages/archive/category-extras/0.53.5/doc/html/src/Control-Functor-KanExtension.html
http://hackage.haskell.org/packages/archive/category-extras/0.53.5/doc/html/src/Control-Monad-Codensity.html
http://hackage.haskell.org/packages/archive/category-extras/0.53.5/doc/html/src/Control-Functor-Yoneda.html

Ran is just a generalization of ContT that lets the two functors involved in its type signature vary, and then a restriction on the use of the type you are stuffing into the result functor.
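OCaml has no higher-kinded type variables, so one way to write the two equations above is with functors. This is a sketch assuming the usual Haskell definition `Ran g h a = forall b. (a -> g b) -> h b`; the module names (`Ran`, `YonedaList`, `CodensityOption`, and so on) are hypothetical and are not a port of category-extras. Applying `Ran` to the same functor twice gives the codensity shape, and applying it to `Identity` and a functor gives the Yoneda shape.

```ocaml
(* Any one-parameter type constructor. *)
module type TYPE1 = sig type 'a t end

(* Right Kan extension: Ran g h a = forall b. (a -> g b) -> h b *)
module Ran (G : TYPE1) (H : TYPE1) = struct
  type 'a t = { run : 'b. ('a -> 'b G.t) -> 'b H.t }
end

module Identity = struct type 'a t = 'a end
module ListT = struct type 'a t = 'a list end
module OptionT = struct type 'a t = 'a option end

(* Yoneda f ~ Ran Identity f, specialized here to f = list. *)
module YonedaList = struct
  include Ran (Identity) (ListT)
  let lift xs = { run = fun k -> List.map k xs }
  let lower y = y.run (fun a -> a)
end

(* Codensity f ~ Ran f f, specialized here to f = option. *)
module CodensityOption = struct
  include Ran (OptionT) (OptionT)
  let return x = { run = fun k -> k x }
  let bind m f = { run = fun k -> m.run (fun a -> (f a).run k) }
  let lift x = { run = fun k -> Option.bind x k }
  let lower c = c.run (fun a -> Some a)
end
```

`lift` and `lower` convert between a functor and its encoding; for example, lowering a lifted list gives back the same list.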

> This is because the codensity monad `encapsulates' the image of the State monad and hides the continuation. If you expose the continuation to the State monad, you could, for example, use it to backtrack instead of for the purpose of implementing State. Right?

Exactly. The ContT monad yields a slightly 'larger' type than the original because you can abuse the continuation. Quantifying over it prevents you from directly invoking the continuation with anything but bottom, which keeps backtracking and other craziness from seeping in and enlarging the scope of what you are allowed. In essence it enforces the fact that you are CPS transforming for purely administrative purposes.
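The kind of abuse described here can be sketched in a strict language too. In this hedged OCaml sketch (all names hypothetical), the answer type `'r` of a CPS state computation is a type parameter chosen by whoever runs it, so a computation is free to drop its continuation and return an answer directly. That is an early exit slipped into what was supposed to be purely administrative CPS.

```ocaml
(* CPS state where the answer type 'r is a parameter, not quantified. *)
type ('r, 's, 'a) kstate = 's -> ('a -> 's -> 'r) -> 'r

let return_ v : ('r, 's, 'a) kstate = fun s k -> k v s

let bind (m : ('r, 's, 'a) kstate) (f : 'a -> ('r, 's, 'b) kstate)
    : ('r, 's, 'b) kstate =
  fun s k -> m s (fun v s' -> f v s' k)

(* Hypothetical primitive: return the old state and increment it. *)
let tick : ('r, int, int) kstate = fun s k -> k s (s + 1)

(* Ignores the continuation k and returns the final answer r at once.
   Nothing about threading state requires this; it is an abort. *)
let escape r : ('r, 's, 'a) kstate = fun _ _ -> r

(* Run from state s with a final continuation that builds the pair. *)
let run m s = m s (fun v s' -> (s', v))
```

If the type were instead a record with a quantified field, `run : 'r. 's -> ('a -> 's -> 'r) -> 'r`, then `escape` could not be written: its body returns a fixed `r`, which cannot have every answer type `'r`, so the type checker rejects it.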

jrm said...

> You're using the word "monad" strangely in the latter part of this article.

Oops. Yes, I am. When you're using a monad, you project a set of operations into a space where they all have the same type so you can plug them together like combinators. The monad is the space, not the operations within it.

> In lazy evaluation, the "stack" takes on a different meaning...

I'm going to add option number 5: space usage in lazy languages is much more difficult to reason about than in strict languages, so it isn't clear what tail-recursion actually means.

Here's a question. This let expression is used to sequence the call chain in a `forward' order.

x >>= f = S (\s -> let (s',v) = runState x s in runState (f v) s')

If you wrote your code `backward', couldn't you re-order the way this is written to avoid the let and prevent long chains of binds?

Muad`Dib said...

> You're an idiot because this obvious thing you clearly didn't understand fixes it.

Muad`Dib said...

Hello, I have written about this topic a couple of times. It's quite a subtle thing so I certainly don't think you are an idiot for noticing it.

Here is my view on how it works in the context of Haskell:

http://muaddibspace.blogspot.com/2008/08/tail-call-optimization-doesnt-exist-in.html
http://muaddibspace.blogspot.com/2008/08/tail-calls-dont-exist-so-why-look-for.html


And this explains why we do not use monads in applicative languages (like Scheme), though sometimes they are useful, for example for nondeterminism as in the Kanren implementation.

Brian said...

This is a problem, but I don't think it's a problem with *monads*; I think it's a problem with OCaml. The OCaml compiler does almost no higher-level optimization of the code: it has fairly good instruction selection and register allocation, but any optimizations of a higher sort are absent. This forces the programmer to contort his code in order to optimize it, in many places, not just with monads. For example, consider the "natural" definition of map:

let rec map f = function
| [] -> []
| x :: xs -> (f x) :: (map f xs)  (* not a tail call: the cons waits for map *)


Now, every experienced OCaml programmer is going "Argh! That's not how you write it, you write it like this!" followed by one or another contorted version of map which is tail recursive or otherwise more efficient. But that's exactly my point. The compiler *should* be able to transform the nice, simple, clear expression of the solution into an efficient one. It just doesn't. And that bites us in a thousand different ways.

State monads, and pretty much all other monads, being non-tail-recursive, is just one of the thousand.
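For readers who have not seen one, here are two common shapes of the tail-recursive map that Brian alludes to (a sketch; `map_tr` and `map_acc` are hypothetical names). Both build the result in reverse with an accumulator and then reverse it once, trading an extra pass and an intermediate list for constant stack use.

```ocaml
(* List.rev_map is tail recursive and yields the results in reverse;
   List.rev restores the original order. *)
let map_tr f xs = List.rev (List.rev_map f xs)

(* The same idea written out by hand with an explicit accumulator. *)
let map_acc f xs =
  let rec go acc = function
    | [] -> List.rev acc
    | x :: rest -> go (f x :: acc) rest   (* tail call *)
  in
  go [] xs
```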