# 11. State Monad

25 Dec 2014

The use of the `Either` monad helped us simplify error processing in the last tutorial. I promised to show you how another monad, the state monad, can eliminate explicit symbol-table threading. But before I do that, let's have a short refresher on currying, since it's relevant to the construction of the state monad (there is actually some beautiful math behind the relationship between currying and the state monad).

There are two ways of encoding a two-argument function and, in Haskell, they are equivalent. One is to implement a function that takes two values:

``````fPair :: a -> b -> c
fPair x y = ...``````

The other is to implement a function that takes one argument and returns another function of one argument (parentheses added for emphasis):

``````fCurry :: a -> (b -> c)
fCurry x = \y -> ...``````

This might seem like a trivial transformation, but I'll show you how it can help us in coding the evaluator.
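
As a concrete illustration, here is a minimal sketch of the two encodings side by side. The names `addPair`, `addCurry`, `addTuple`, `increment`, and `results` are made up for this example; `curry` and `uncurry` are the standard Prelude functions that convert to and from the tupled form.

```haskell
-- Two-argument style: reads as if it took two values.
addPair :: Int -> Int -> Int
addPair x y = x + y

-- Explicitly curried style: takes one value and returns a function.
addCurry :: Int -> (Int -> Int)
addCurry x = \y -> x + y

-- Tupled style, for comparison; curry and uncurry translate between the forms.
addTuple :: (Int, Int) -> Int
addTuple (x, y) = x + y

-- Partial application works the same way for both curried spellings.
increment :: Int -> Int
increment = addPair 1

-- Every entry evaluates to the same number (hypothetical demo value).
results :: [Int]
results =
    [ addPair 2 3
    , addCurry 2 3
    , curry addTuple 2 3
    , uncurry addCurry (2, 3)
    , increment 4
    ]
```

Loading this in GHCi and evaluating `results` should give a list of five equal values, which is the sense in which the two encodings are interchangeable.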

## Curried Evaluator

Let me remind you what the signature of the function `evaluate` was -- to make things simpler, let's consider the version from before the introduction of the `Either` monad:

``evaluate :: Tree -> SymTab -> (Double, SymTab)``

I'm going to parenthesize it the way that highlights the currying interpretation:

``evaluate :: Tree -> (SymTab -> (Double, SymTab))``

Let's read this signature carefully: `evaluate` is a function that takes a `Tree` and returns a function, which takes a `SymTab` and returns a pair `(Double, SymTab)`. What if we take this reading to heart and rewrite `evaluate` so that it actually returns a function (a lambda)?

Let's start with the `UnaryNode` evaluator, which used to look like this:

``````evaluate (UnaryNode op tree) symTab =
    let (x, symTab') = evaluate tree symTab
    in case op of
         Plus  -> ( x, symTab')
         Minus -> (-x, symTab')``````

and let's try something like this:

``````evaluate :: Tree -> (SymTab -> (Double, SymTab))
evaluate (UnaryNode op tree) =
    \symTab ->
        let (x, symTab') = evaluate tree symTab --??
        in case op of
             Plus  -> ( x, symTab')
             Minus -> (-x, symTab')``````

You see what the problem is? In the new scheme, the inner call to `evaluate` will no longer return a pair `(x, symTab')` but a function `(SymTab -> (Double, SymTab))`. Let me call this function `act` for action. How can we extract `x` and `symTab'` from that action? By running it! We do have an argument `symTab` to pass to it -- it's the argument of the lambda:

``````evaluate :: Tree -> (SymTab -> (Double, SymTab))
evaluate (UnaryNode op tree) =
    \symTab ->
        let act = evaluate tree
            (x, symTab') = act symTab
        in case op of
             Plus  -> ( x, symTab')
             Minus -> (-x, symTab')``````

What has just happened? We called the new `evaluate` only to immediately execute the resulting action? Then why even bother with the intermediate step?

First of all, it's a neat idea that evaluation can be separated into two phases: one for creating a network of functions like `evaluate` calling each other but not actually evaluating the result; and another phase for executing this network, starting with a particular state -- the symbol table in this case. Obviously, if you provide a different starting symbol table, you will obtain a different final result. But the network of functions depends only on the original parse tree.

The second reason is that this form brings us closer to our goal of abstracting away the tedium of symbol-table passing. Symbol table passing is what "actions" are supposed to do; `evaluate` should only construct the tracks for the symbol-table train.
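
The two phases can be made visible with a small sketch. It assumes a `Data.Map`-based symbol table and adds a `VarNode` case so that the result actually depends on the table; the names `action`, `resultA`, and `resultB`, and the choice to read unknown variables as `0`, are assumptions made only for this illustration.

```haskell
import qualified Data.Map as M

data Operator = Plus | Minus
data Tree = UnaryNode Operator Tree
          | VarNode String

type SymTab = M.Map String Double

-- Phase one: build an action from the tree. No symbol table is needed yet.
evaluate :: Tree -> (SymTab -> (Double, SymTab))
evaluate (UnaryNode op tree) =
    \symTab ->
        let act = evaluate tree
            (x, symTab') = act symTab
        in case op of
             Plus  -> ( x, symTab')
             Minus -> (-x, symTab')
evaluate (VarNode name) =
    -- Assumption for this sketch: an unknown name reads as 0.
    \symTab -> (M.findWithDefault 0 name symTab, symTab)

-- The action is built once, from the parse tree alone...
action :: SymTab -> (Double, SymTab)
action = evaluate (UnaryNode Minus (VarNode "x"))

-- ...and phase two runs it against different starting symbol tables.
resultA, resultB :: Double
resultA = fst (action (M.fromList [("x", 2)]))
resultB = fst (action (M.fromList [("x", 10)]))
```

Here `action` never changes, while `resultA` and `resultB` differ because the starting tables differ.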

Interestingly, this separation between creating an action and running it turned out to be quite useful in C++, as I showed in my old post Monads in C++. There, the actions were constructed at compile time using an EDSL, and then executed at runtime.

Going back to our program, we'll try to follow the same procedure we used to derive the `Either` monad. The most important part of a monad is the bind function. Remember, bind is the glue that binds the output of one function to the input of another function -- the one we call a continuation. The signature of bind is determined by the definition of the `Monad` class. It has the form:

``````bind :: Blob a
     -> (a -> Blob b)
     -> Blob b``````

where `Blob` stands for the type constructor we are trying to monadize. In our case, this type constructor is of the form `(SymTab -> (a, SymTab))`, with the type parameter `a` nested inside the return type of an action. I'll call this function type the new `Evaluator`:

``type Evaluator a = SymTab -> (a, SymTab)``

We'll standardize it later using a `newtype` definition, which is required by an `instance` declaration, but for now let's just work with a type synonym.

So here's what monadic bind should look like for our type (yes, it's exactly the same as for our `Either` monad, except that `Evaluator` now hides a function):

``````bindS :: Evaluator a
      -> (a -> Evaluator b)
      -> Evaluator b``````

The client of bind is supposed to pass an evaluator as the first argument and a continuation as the second. The continuation is a function that returns an evaluator. Let's look for this pattern in our implementation of `evaluate` of the `UnaryNode`:

``````evaluate :: Tree -> (SymTab -> (Double, SymTab))
evaluate (UnaryNode op tree) =
    (\symTab ->
        let act = evaluate tree
            (x, symTab') = act symTab
        in case op of
             Plus  -> ( x, symTab')
             Minus -> (-x, symTab'))``````

We are looking for a piece of code that can be interpreted as "the rest of the code." On first attempt we might think of the following lambda as our continuation:

``````\x' -> case op of
         Plus  -> ( x', symTab')
         Minus -> (-x', symTab')``````

but it's the wrong type. Our continuation is supposed to be returning an `Evaluator`, not a pair `(Double, SymTab)`. How can we turn this value into an evaluator? That's what monadic `return` is supposed to do. Its signature, again, is determined by the `Monad` class (I'm calling it `returnS` for now to avoid name conflicts):

``returnS :: a -> Evaluator a``

The implementation is a no-brainer, really. We turn `x` into a function that returns this `x` with a side of `symTab`:

``returnS x = \symTab -> (x, symTab)``

So here's the candidate that fulfills all our requirements for a continuation:

``````\x' -> case op of
         Plus  -> returnS x'
         Minus -> returnS (-x')``````

This is indeed a fine monadic function (returning a value of the soon-to-be monadic `Evaluator`), and it fits the type signature of the continuation required by bind; except that we don't see it in the original code. We can't carve it out of the current implementation of `evaluate`. If we could only find a way to insert this `returnS` and then immediately cancel it. But how can one undo `returnS`? Well, how about executing its result? Check this out:

``````(returnS x) symTab' = (\symTab -> (x, symTab)) symTab'
                    = (x, symTab')``````

When you execute a lambda, you simply replace it with its body and replace the formal parameter with the actual argument. Here, I replaced `symTab` (formal parameter, or bound variable) with `symTab'` (the argument). In general, the argument may be a whole expression. You just stick it in at every place the formal parameter appears in the body. (You have to be careful though not to introduce name conflicts.)

So here's the final rewrite:

``````data Operator = Plus | Minus
data Tree = UnaryNode Operator Tree
type SymTab = ()

type Evaluator a = SymTab -> (a, SymTab)

returnS :: a -> Evaluator a
returnS x = \symTab -> (x, symTab)

evaluate :: Tree -> (SymTab -> (Double, SymTab))
evaluate (UnaryNode op tree) =
    \symTab ->
        let act = evaluate tree
            (x, symTab') = act symTab
            k = \x' -> case op of
                         Plus  -> returnS x'
                         Minus -> returnS (-x')
            act' = k x
        in
            act' symTab'

main = putStrLn "It type checks!"``````

If it type checks, it must be correct, right? To convince yourself that this indeed works, first apply `k` to `x` -- this will just replace `x'` with `x`. Then apply the resulting action to `symTab'` to cancel out the `returnS`s.

Let's continue with our program to define a new monad. To this end, we need to identify the pattern we've been looking for. We want to pick the implementation of `bindS` from `evaluate`.

We can clearly see the two arguments to bind: one is `act`, the result of `evaluate tree`, and the other is the continuation `k`. The rest must be bind. Here it is, together with `returnS` and the new version of `evaluate`:

``````data Operator = Plus | Minus
data Tree = UnaryNode Operator Tree
type SymTab = ()

type Evaluator a = SymTab -> (a, SymTab)

returnS :: a -> Evaluator a
returnS x = \symTab -> (x, symTab)

bindS :: Evaluator a
      -> (a -> Evaluator b)
      -> Evaluator b
bindS act k =
    \symTab ->
        let (x, symTab') = act symTab
            act' = k x
        in
            act' symTab'

evaluate :: Tree -> (SymTab -> (Double, SymTab))
evaluate (UnaryNode op tree) =
    bindS (evaluate tree)
          (\x -> case op of
                   Plus  -> returnS x
                   Minus -> returnS (-x))

main = putStrLn "It type checks!"``````
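
To watch `bindS` thread a state that actually changes, here is a small sketch that swaps the unit symbol table for an `Int` counter. The choice of `Int` for `SymTab` and the names `tick`, `twoTicks`, and `demo` are assumptions made purely for illustration; `returnS` and `bindS` are the definitions from above.

```haskell
type SymTab = Int   -- assumption: a counter stands in for the symbol table
type Evaluator a = SymTab -> (a, SymTab)

returnS :: a -> Evaluator a
returnS x = \symTab -> (x, symTab)

bindS :: Evaluator a -> (a -> Evaluator b) -> Evaluator b
bindS act k =
    \symTab ->
        let (x, symTab') = act symTab
            act' = k x
        in act' symTab'

-- An action that returns the current counter and increments the state.
tick :: Evaluator Int
tick = \n -> (n, n + 1)

-- Two ticks glued together; no state variable is mentioned here.
twoTicks :: Evaluator (Int, Int)
twoTicks =
    tick `bindS` \a ->
    tick `bindS` \b ->
    returnS (a, b)

-- Running the composite action from 10 yields ((10, 11), 12).
demo :: ((Int, Int), SymTab)
demo = twoTicks 10
```

The second `tick` sees the state left behind by the first one, even though `twoTicks` never names it. That is the bookkeeping `bindS` takes over.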

## Symbol Table Monad

Let's formalize what we've done so far using an actual instance of the `Monad` typeclass. First, we need to encapsulate our evaluator type in a `newtype` declaration. This muddles things a little, but is necessary if we want to use it in an `instance` declaration. Here's a type that contains nothing but a function:

``newtype Evaluator a = Ev (SymTab -> (a, SymTab))``

And here are our return and bind functions in their cleaned up form:

``````instance Monad Evaluator where
    return x = Ev (\symTab -> (x, symTab))
    (Ev act) >>= k = Ev $ \symTab ->
        let (x, symTab') = act symTab
            (Ev act')    = k x
        in
            act' symTab'``````
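
One caveat for readers on a current compiler: since GHC 7.10, `Applicative` (and therefore `Functor`) is a superclass of `Monad`, so a bare `Monad` instance is rejected. A minimal sketch of the companion instances follows. The `Int` symbol table and the helper names `runEv`, `tick`, and `demo` are assumptions introduced only to keep the example self-contained.

```haskell
type SymTab = Int   -- assumption: placeholder state for this sketch

newtype Evaluator a = Ev (SymTab -> (a, SymTab))

-- Unwrap and run an action (hypothetical helper, not part of the tutorial).
runEv :: Evaluator a -> SymTab -> (a, SymTab)
runEv (Ev act) = act

instance Functor Evaluator where
    fmap f (Ev act) = Ev $ \symTab ->
        let (x, symTab') = act symTab
        in (f x, symTab')

instance Applicative Evaluator where
    pure x = Ev (\symTab -> (x, symTab))
    Ev actF <*> Ev actX = Ev $ \symTab ->
        let (f, symTab')  = actF symTab
            (x, symTab'') = actX symTab'
        in (f x, symTab'')

instance Monad Evaluator where
    return = pure
    (Ev act) >>= k = Ev $ \symTab ->
        let (x, symTab') = act symTab
            (Ev act')    = k x
        in act' symTab'

-- A sample action that returns the counter and increments it.
tick :: Evaluator Int
tick = Ev (\n -> (n, n + 1))

-- Starting from 1: a = 1, b = 2, final state 3, so demo is (3, 3).
demo :: (Int, SymTab)
demo = runEv (do { a <- tick; b <- tick; return (a + b) }) 1
```

The `Functor` and `Applicative` instances repeat the same state-threading pattern as `>>=`, so they add no new ideas, only the boilerplate the class hierarchy asks for.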

Now that the paperwork is done, we can start using the `do` notation. Here's our monadic `UnaryNode` evaluator:

``````evaluate (UnaryNode op tree) = do
    x <- evaluate tree
    case op of
      Plus  -> return x
      Minus -> return (-x)``````

`SumNode` is even more spectacular:

``````evaluate (SumNode op left right) = do
    lft <- evaluate left
    rgt <- evaluate right
    case op of
      Plus  -> return (lft + rgt)
      Minus -> return (lft - rgt)``````

Compare it with the original:

``````evaluate (SumNode op left right) symTab =
    let (lft, symTab')  = evaluate left symTab
        (rgt, symTab'') = evaluate right symTab'
    in
        case op of
          Plus  -> (lft + rgt, symTab'')
          Minus -> (lft - rgt, symTab'')``````

All references to the symbol table are magically gone. The code is not only cleaner, but also less error prone. In the original code there were way too many opportunities to use the wrong symbol table for the wrong call. That's all taken care of now.

There are only three places where you'll see explicit use of the symbol table: `lookUp`, `addSymbol`, and the main loop -- as it should be! I recommend studying the complete code for the calculator listed at the end of this tutorial, with special attention to those functions.
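
For orientation, here is one plausible shape for the two symbol-table primitives, written against the `Ev` wrapper. This is a sketch under assumptions, not necessarily the code from the full listing: it assumes the symbol table is a `Data.Map` from names to `Double`, and it assumes that an undefined variable simply aborts with `error`. The name `demo` is hypothetical.

```haskell
import qualified Data.Map as M

type SymTab = M.Map String Double

newtype Evaluator a = Ev (SymTab -> (a, SymTab))

-- Reads the table and passes it along unchanged.
lookUp :: String -> Evaluator Double
lookUp str = Ev $ \symTab ->
    case M.lookup str symTab of
      Just v  -> (v, symTab)
      Nothing -> error ("Undefined variable " ++ str)  -- assumption: fail loudly

-- Returns unit and passes along an updated table.
addSymbol :: String -> Double -> Evaluator ()
addSymbol str val = Ev $ \symTab -> ((), M.insert str val symTab)

-- Run the two actions by hand, threading the table between them.
demo :: Double
demo =
    let Ev add  = addSymbol "x" 2.5
        ((), t) = add M.empty
        Ev look = lookUp "x"
    in fst (look t)
```

These are the only functions that need to open the `Ev` wrapper and touch the table directly; everything built on top of them can stay inside `do` blocks.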

Now you have seen with your own eyes that all this can be done with pure functions. We managed to manipulate state -- the symbol table -- in a purely functional way.

There is a popular misconception that you must use impure code to deal with mutable state, and that Haskell monads are impure. There are ways to introduce impurities in Haskell: there's a bunch of functions whose names start with `unsafe`, there is `trace` for debugging, and there is the `ST` monad (not to be confused with the `State` monad), all of which (carefully) let you inject impurity into your code. Sometimes it's done for debugging, sometimes for performance. In general, though, you can and should stick to the purely functional style.

What we have just done is to create our own version of a generic state monad. It was, hopefully, a good learning experience, but one that shouldn't be repeated when writing production code. So let's familiarize ourselves with the `Control.Monad.State` version of the state monad (strictly speaking the state monad is defined using a monad transformer, so the actual code in the library may look a bit different from what I present). The state monad is defined by a new type `State`, which is parameterized by two type variables. The first one is used to represent the state (in our case that would be `SymTab`), and the second is the generic type parameter of every monad type constructor.

``newtype State s a = State (s -> (a, s))``

`State` has one data constructor also called `State`. It takes a function as an argument. The interesting thing is that this constructor is not exported from the library so you can't pattern match on it. If you want to create a new monadic `State`, use the function `state`:

``state :: (s -> (a, s)) -> State s a``

Instead of extracting an action from `State`, which you can't do, and acting with it on some state, you call the function `runState` which does it for you:

``runState :: State s a -> s -> (a, s)``
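
A minimal sketch of these two functions in use, assuming the `mtl` package is installed so that `Control.Monad.State` is available. The names `tick`, `both`, `resultOnly`, and `stateOnly` are made up for the example; `evalState` and `execState` are library functions exported by the same module.

```haskell
import Control.Monad.State (State, state, runState, evalState, execState)

-- Wrap a plain state-passing function into a State action.
tick :: State Int Int
tick = state (\n -> (n, n + 1))

-- runState returns both the result and the final state: (5, 6).
both :: (Int, Int)
both = runState tick 5

-- evalState keeps only the result; execState keeps only the final state.
resultOnly, stateOnly :: Int
resultOnly = evalState tick 5
stateOnly  = execState tick 5
```

Nothing happens when `tick` is defined; the wrapped function runs only when one of the `run`/`eval`/`exec` functions supplies an initial state.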

The `Monad` instance declaration for `State` looks something like this:

``````instance Monad (State s) where
    return x = state (\st -> (x, st))
    act >>= k = state $ \st ->
        let (x, st') = runState act st
        in runState (k x) st'``````

Notice that `State s` is not a type but a type constructor: it needs one more type variable to become a type. As I mentioned before, the `Monad` class can only be instantiated with type constructors.

I've shown you how to extract the bind operator from state-threading code, but there is a more general derivation that's based on types. In Haskell you often see functions whose implementation is determined by their signatures. Sometimes it's determined uniquely, more often we pick the simplest non-trivial implementation that type checks. Here's the signature of `>>=` that is required by the `Monad` class as applied to `State s`:

``````(>>=) :: State s a -> (a -> State s b) -> State s b
act >>= k = ...``````

The first observation is that, in order to run the continuation `k`, we need a value of type `a`. The only source of such a value could be the first argument, `act`, and the only way to retrieve it is to run `act` with some state. But we don't have any state yet.

But notice that bind itself doesn't produce a value -- it produces a `State` object. How do you construct a `State`? By calling `state` with a function. Bind must therefore define a lambda of the signature `s -> (b, s)` and pass it to `state`. The outer shell of `>>=` must therefore have the form:

``act >>= k = state $ \st -> ...``

Now, inside the lambda, we do have access to a state variable `st` and we can use it to run `act`.

``````act >>= k = state $ \st ->
    let (x, st') = runState act st
    ...
``````

Now we have `x` of type `a` so we can call the continuation `k`:

``````act >>= k = state $ \st ->
    let (x, st') = runState act st
        act' = k x
    ...
``````

The continuation returns an action `act'` of the type `State s b`. Our lambda, though, must return a pair of the type `(b, s)`. The only way to generate a value of the type `b` is to run `act'` with some state. Here we have a choice: we can run it with the original `st` or with the new `st'`. The first choice would mean that the state never changes and, in fact, doesn't even have to be returned by the action. There is a perfectly good monad built on this assumption: it's called the reader monad (see the exercise at the end of this tutorial). But since here we are modeling mutable state, we choose to use `st'` to run `act'`:

``````act >>= k = state $ \st ->
    let (x, st') = runState act st
        act' = k x
    in runState act' st'
``````

There is one more ingredient necessary to make the state monad usable: the ability to access and modify the state. There are two generic functions `get` and `put` that provide this functionality:

``````get :: State s s
get = state $ \st -> (st, st)

put :: s -> State s ()
put newState = state $ \_ -> ((), newState)``````

`get` returns the value of the state. `put` returns unit, but has a "side effect" of injecting new state into subsequent computations.
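
Here is a short sketch of `get` and `put` working together inside `do` blocks, again assuming `Control.Monad.State` from the `mtl` package. The names `nextId`, `labels`, and `demo` are hypothetical.

```haskell
import Control.Monad.State (State, get, put, runState)

-- Return the current counter value, then store the incremented one.
nextId :: State Int Int
nextId = do
    n <- get
    put (n + 1)
    return n

-- Three ids drawn in sequence; the counter is threaded by >>= behind the do block.
labels :: State Int [String]
labels = do
    a <- nextId
    b <- nextId
    c <- nextId
    return (map (\i -> "item" ++ show i) [a, b, c])

-- Starting from 0 this gives (["item0","item1","item2"], 3).
demo :: ([String], Int)
demo = runState labels 0
```

Each call to `nextId` observes the state left by the previous one, which is the "side effect" of `put` described above, yet `labels` remains an ordinary pure value until `runState` is applied to it.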

## What Is a Monad?

We've seen two seemingly disparate examples of a monad and I will show you some more in the next tutorial. What do they have in common, other than implementing the functions `return` and `>>=`? Why are these two functions so important? It's time for some deeper insights.

The basic premise of all programming is that you can decompose a complex computation into a set of simpler ones.

The difference between various programming paradigms is in the mechanics of composing smaller computations into larger ones. For instance, in C you use a combination of functions and side effects. You call a function (procedure) whose effects can be:

1. Returning a value
2. Modifying an argument (when it's a reference)
3. Modifying global variables
4. Interacting with the external world

Some of the effects are visible in the signature of the function (types of input and output parameters), others are implicit. The compiler may help with flagging explicit mismatches, but it can't check the implicit ones. So when you're composing functions in C, you have to keep in mind all the hidden interactions between them.

In OO programming, side effects are somewhat tamed with data hiding. Although arguments are mostly passed by reference, including the implicit `this` pointer, the things you can do to them are restricted by their interfaces. Still, hidden dependencies make composition fragile. This is especially painful when dealing with concurrency.

The starting point of functional programming is that functions have no side effects whatsoever, so function composition is a straightforward matter of passing the results of one function as the input to the next. This is a great starting point from the point of view of composability. However, many of the traditional notions of computation don't have straightforward translations into pure functions. This has been a huge problem in the adoption of functional languages.

Two things happened (not necessarily in that order) to change this situation:

1. We learned how to translate most computations into functions.
2. We learned to use monads to abstract the tedium of this translation.

I tried to emphasize the same two steps when introducing monads.

First, I showed you how to translate partial computations into total functions. These functions encapsulate their results into `Maybe` or `Either` types. I also showed you how to deal with mutable state by passing it as an additional parameter into and out of a function.

This is a very general pattern: Take a computation that takes input and produces output but does it in a non-functional way, and modify input and output types in such a way that the computation becomes functional.

Next, I showed you a way to do the same thing by modifying only the return types of the computation. If the translation of a computation required adding input parameters to the original signature (passing the symbol table in, for instance), I used currying and turned the output type into a function type. (In Exercise 1 you'll use the same trick to implement the reader monad.)

So this is lesson one: A computation can be turned into a function that encapsulates the originally non-functional bits into its modified (decorated, fancified, or whatever you call it) output data type.

The great thing about it is that now all this additional information is visible to the compiler and the type checker. There is even a name for this system in type theory: the effect system. A function signature may expose the effects of a function in addition to just turning input into output types. These effects are propagated when composing functions (as was the effect of modifying the symbol table, or being undefined for some values of arguments) and can be checked at compile time.

A potential shortcoming of this approach is that the composition of such fancified functions requires writing some boilerplate code. In the case of `Maybe`- or `Either`-returning functions, we have to pattern match the results and fork the execution. In the case of action-returning functions, we need to run these actions, provide the additional parameters they need, and pass results to the next action.

To our great relief, this highly repetitive (and error-prone) glue code can be abstracted into just two functions: `>>=` and `return` (optionally `>>` and `fail`). Now we can test our implementation of the glue code in one place, or still better, use the library code. And to make our lives even better, we have this wonderful syntactic sugar in the shape of the `do` notation.
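
The contrast is easy to see on a small `Maybe` example. The sketch below uses only the Prelude; `safeDiv`, `manual`, `withBind`, and `withDo` are names invented for this illustration.

```haskell
-- A partial computation made total by returning Maybe.
safeDiv :: Double -> Double -> Maybe Double
safeDiv _ 0 = Nothing
safeDiv x y = Just (x / y)

-- Hand-written glue: pattern match after every step and fork.
manual :: Double -> Double -> Double -> Maybe Double
manual a b c =
    case safeDiv a b of
      Nothing -> Nothing
      Just q  ->
        case safeDiv q c of
          Nothing -> Nothing
          Just r  -> Just (r + 1)

-- The same logic with the glue moved into >>= and return.
withBind :: Double -> Double -> Double -> Maybe Double
withBind a b c = safeDiv a b >>= \q -> safeDiv q c >>= \r -> return (r + 1)

-- And once more with do notation.
withDo :: Double -> Double -> Double -> Maybe Double
withDo a b c = do
    q <- safeDiv a b
    r <- safeDiv q c
    return (r + 1)
```

All three functions compute the same thing; only the first one spells out the `Nothing` branches by hand, and each extra step would add another level of nesting to it.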

But now, when you look at a do block, it looks very much like imperative code with hidden side effects. The `Either` monadic code looks like using functions that can throw exceptions. `State` monad code looks as if the state were a global mutable variable. You access it using `get` with no arguments, and you modify it by calling `put` that returns no value. So what have we gained in comparison to C?

We might not see the hidden effects, but the compiler does. It desugars every do block and type-checks it. The state might look like a global variable but it's not. Monadic bind makes sure that the state is threaded from function to function. It's never shared. If you make your Haskell code concurrent, there will be no data races.

## Exercises

Ex 1. Define the reader monad. It's supposed to model computations that have access to some read-only environment. In imperative code such an environment is often implemented as a global object. In functional languages we need to pass it as an argument to every function that might potentially need access to it. The reader monad hides this process.

``````newtype Reader e a = Reader (e -> a)

reader :: (e -> a) -> Reader e a
reader f = undefined

runReader :: Reader e a -> e -> a

...

type Env = Reader String
-- curried version of
-- type Env a = Reader String a

test :: Env Int
test = do
    s <- reader id
    return $ read s + 1

main = print $ runReader test "13"``````

Ex 2. Use the `State` monad from `Control.Monad.State` to re-implement the evaluator.

``````import Data.Char
import qualified Data.Map as M
import Control.Monad.State

data Operator = Plus | Minus | Times | Div
    deriving (Show, Eq)

data Tree = SumNode Operator Tree Tree
          | ProdNode Operator Tree Tree
          | AssignNode String Tree
          | UnaryNode Operator Tree
          | NumNode Double
          | VarNode String
    deriving Show

type SymTab = M.Map String Double

type Evaluator a = State SymTab a

lookUp :: String -> Evaluator Double
lookUp str = do ...

addSymbol :: String -> Double -> Evaluator ()
addSymbol str val = do ...

evaluate :: Tree -> Evaluator Double

evaluate (SumNode op left right) = ...

evaluate (ProdNode op left right) = ...

evaluate (UnaryNode op tree) = ...

evaluate (NumNode x) = ...

evaluate (VarNode str) = ...

evaluate (AssignNode str tree) = ...

expr = AssignNode "x" (ProdNode Times (VarNode "pi")
                (ProdNode Times (NumNode 4) (NumNode 6)))

main = print $ runState (evaluate expr) (M.fromList [("pi", pi)])``````

## Calculator with the Symbol Table Monad

Here's the complete runnable version of the calculator that uses our Symbol Table Monad.