───────────────
The State monad
───────────────

This article tries to be the kind of description I would have loved to find
back when I had learned the basics of monads but was still having trouble with
the State monad.

*Some* knowledge of the concept of monads is required for this to make any
sense. Try to understand simpler monads such as List and Maybe first.

The examples use Haskell syntax, but the concept can be used in any language
that has lambdas.

──────────────
The State type
──────────────

A computation that takes an initial state (of type ‘s’) and returns a tuple
consisting of the result value of the computation (of type ‘a’) and a new state
(of type ‘s’) can be expressed as a function of type ‘s -> (a, s)’.

For instance, ‘\s -> ("foo", s)’ would be a computation that returns ‘"foo"’ as
the result value and the state unchanged.

As another example, ‘\s -> (s, s+1)’ would be a computation that returns the
initial state as the result value and the state value incremented by one as the
new state.
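
To see these two computations in action before any wrapping takes place, they
can be written as ordinary functions and applied to an initial state directly.
A minimal sketch; the names ‘constFoo’ and ‘tick’ are made up for this
illustration, and the state type is fixed to Int where a concrete type is
needed.

```haskell
-- Hypothetical names, used only for this illustration.

-- Returns "foo" as the result value and leaves the state unchanged.
constFoo :: s -> (String, s)
constFoo = \s -> ("foo", s)

-- Returns the initial state as the result value and the state
-- incremented by one as the new state.
tick :: Int -> (Int, Int)
tick = \s -> (s, s + 1)

-- Applying both computations to an initial state of 42:
fooResult :: (String, Int)
fooResult = constFoo 42   -- ("foo", 42)

tickResult :: (Int, Int)
tickResult = tick 42      -- (42, 43)
```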

For the purposes of abstraction we’ll wrap these computations in the following
type:

  newtype State s a = S (s -> (a, s))

─────────────────────
The runState function
─────────────────────

The runState function takes a state computation and an initial state as
parameters. It unwraps the computation from the State type and simply applies
it to the initial state.

  runState :: State s a -> s -> (a, s)
  runState (S f) s = f s
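
As a quick check, here is a sketch that wraps the incrementing computation
from the previous section in the State type and runs it. The name ‘tick’ is an
assumption made for this example; the type and runState are the ones defined
above.

```haskell
newtype State s a = S (s -> (a, s))

runState :: State s a -> s -> (a, s)
runState (S f) s = f s

-- 'tick' is a hypothetical name for the wrapped incrementing computation.
tick :: State Int Int
tick = S $ \s -> (s, s + 1)

-- Unwrap the computation and apply it to the initial state 42.
tickFrom42 :: (Int, Int)
tickFrom42 = runState tick 42   -- (42, 43)
```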

─────────────────────────────────
The functions return, get and put
─────────────────────────────────

A number of useful functions for generating state computations can be defined
(get and put are provided by Control.Monad.State in the mtl package; return is
a method of the Monad class). Let’s look at the most important ones.

Running with the initial state of 42:

expression                 │ returns
───────────────────────────┼────────────
runState (return "foo") 42 │ ("foo", 42)
runState (get)          42 │ (42,    42)
runState (put 43)       42 │ ((),    43)

‘return "foo"’ generates a computation that returns ‘"foo"’ as the result value
and the state unchanged. It is implemented as follows:

  return a = …

We’ll wrap the computation in the State type by doing ‘S (computation)’, or
using the alternative syntax ‘S $ computation’ to avoid parentheses.

  return a = S $ …

As always, the computation is a function that takes the initial state as a
parameter…

  return a = S $ \s -> …

…and returns a tuple consisting of the result value and the new state. In this
case, the result value is the parameter passed to return (‘a’), and the new
state is the initial state unchanged (‘s’).

  return a = S $ \s -> (a,  s)

‘get’ generates a computation that returns the unchanged state as both the
result value and the new state.

  get      = S $ \s -> (s,  s)

‘put 43’ generates a computation that discards the initial state and returns
‘()’ (the unit value, which carries no information) as the result value and
‘43’ as the new state.

  put s    = S $ \_ -> ((), s)
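
Put together, the three definitions can be loaded into GHCi and checked
against the table at the start of this section. This is a sketch under a few
assumptions: the type signatures are not part of the text above (they are what
I would expect GHC to infer), Prelude’s own ‘return’ is hidden to avoid a name
clash, and the ‘row…’ names are made up for the example.

```haskell
import Prelude hiding (return)

newtype State s a = S (s -> (a, s))

runState :: State s a -> s -> (a, s)
runState (S f) s = f s

-- The given value becomes the result; the state passes through unchanged.
return :: a -> State s a
return a = S $ \s -> (a, s)

-- The state is both the result value and the new state.
get :: State s s
get = S $ \s -> (s, s)

-- The old state is ignored; the given value becomes the new state.
put :: s -> State s ()
put s = S $ \_ -> ((), s)

-- The three rows of the table, with the initial state fixed to 42.
rowReturn :: (String, Int)
rowReturn = runState (return "foo") 42   -- ("foo", 42)

rowGet :: (Int, Int)
rowGet = runState get 42                 -- (42, 42)

rowPut :: ((), Int)
rowPut = runState (put 43) 42            -- ((), 43)
```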

────────────────────────────────────────────────
Sequencing state computations: the bind operator
────────────────────────────────────────────────

For combining state computations, we’ll want the >>= (bind) operator to return
a new state computation that, given an initial state,

0. runs the left-hand side computation,
1. passes its result value to the right-hand side
   (in ‘return "foo" >>= \a -> …’, ‘a’ should be ‘"foo"’;
   in ‘get >>= \s -> …’, ‘s’ should be the state),
2. runs the right-hand side computation, passing the new state resulting from
   the first computation as its initial state, and
3. returns the result value and the final state returned by the second
   computation.

It should behave as follows:

  example0 = runState f 42 where
    f = do
      s <- get
      put (s+1)
      return "foo"

…being syntactic sugar for…

  runState (get >>= \s -> put (s+1) >>= \_ -> return "foo") 42

…should result in ‘("foo", 43)’.
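
To make it clearer what >>= has to automate, here is a sketch of the same
three steps with the state threaded through by hand, using bare functions
instead of the State type. The names ‘getRaw’, ‘putRaw’, ‘returnRaw’ and
‘manual’ are hypothetical and exist only for this illustration.

```haskell
-- Unwrapped counterparts of get, put and return (hypothetical names).
getRaw :: s -> (s, s)
getRaw s = (s, s)

putRaw :: s -> s -> ((), s)
putRaw new _ = ((), new)

returnRaw :: a -> s -> (a, s)
returnRaw a s = (a, s)

-- The do block above, with every intermediate state named explicitly.
manual :: Int -> (String, Int)
manual s0 =
  let (s, s1) = getRaw s0            -- s <- get
      (_, s2) = putRaw (s + 1) s1    -- put (s+1)
      (r, s3) = returnRaw "foo" s2   -- return "foo"
  in  (r, s3)
```

‘manual 42’ evaluates to ‘("foo", 43)’. Passing s0, s1, s2 and s3 along by
hand is exactly the bookkeeping that >>= is meant to hide.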

As >>= should return a new state computation, it should begin just like the
functions defined earlier, that is, by wrapping a function from an initial
state to a (result value, new state) tuple in the State type.

  m >>= k = S $ \s -> …

‘m’ is the first state computation, and ‘k’ takes its evaluated result value
and returns a second state computation.

  m :: State s a
  k :: a -> State s b

We should start by running ‘m’, the left-hand side parameter of the >>= call:

  m >>= k = S $ \s ->
    let (a, s') = runState m s
    in  …

value of m   │ resulting a │ resulting s'
─────────────┼─────────────┼───────────────
return "foo" │ "foo"       │ s  (unchanged)
get          │ s           │ s  (unchanged)
put 43       │ ()          │ 43

The right-hand side parameter ‘k’ expects to get ‘a’, the result value of the
first computation, as a parameter. Evaluating ‘k a’ returns the second state
computation (‘n’).

  m >>= k = S $ \s ->
    let (a, s') = runState m s
        n = k a
    in  …

“s'” is the new state which we’ll need to give as the initial state for the
second computation ‘n’.

  m >>= k = S $ \s ->
    let (a, s')  = runState m s
        n = k a
        (b, s'') = runState n s'
    in  …

…or just…

  m >>= k = S $ \s ->
    let (a, s')  = runState m s
        (b, s'') = runState (k a) s'
    in  …

‘b’ is now the result value of the right-hand side computation, and “s''” is
the final state.

As “(b, s'')” – the final result value and the final state from sequencing the
two state computations – is exactly what we’ll need to return, we finally
arrive at the complete definition of >>=:

  m >>= k = S $ \s ->
    let (a, s') = runState m s
    in  runState (k a) s'
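
The finished operator can be tried out on the desugared example from earlier
in this section. A sketch with some assumptions: Prelude’s ‘return’ and ‘>>=’
are hidden so the standalone definitions from this article can be used as they
are, the type signatures and the fixity declaration (copied from Prelude’s
‘infixl 1 >>=’) are additions, and ‘result’ is a made-up name.

```haskell
import Prelude hiding (return, (>>=))

-- Same fixity as Prelude's >>=.
infixl 1 >>=

newtype State s a = S (s -> (a, s))

runState :: State s a -> s -> (a, s)
runState (S f) s = f s

return :: a -> State s a
return a = S $ \s -> (a, s)

get :: State s s
get = S $ \s -> (s, s)

put :: s -> State s ()
put s = S $ \_ -> ((), s)

-- Run m, feed its result to k, run the resulting computation on m's new state.
(>>=) :: State s a -> (a -> State s b) -> State s b
m >>= k = S $ \s ->
  let (a, s') = runState m s
  in  runState (k a) s'

-- The desugared example: read the state, store state + 1, return "foo".
result :: (String, Int)
result = runState (get >>= \s -> put (s+1) >>= \_ -> return "foo") 42
```

‘result’ evaluates to ‘("foo", 43)’, as required. Because >>= is a plain
function here rather than a method of a Monad instance, the do notation is not
available; the appendix shows the instance-based version.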

─────────
Afterword
─────────

I’d appreciate corrections to any errors this article may contain.

Johan Kiviniemi, 2011

2011-01-31: Yan Tayga seems to have translated this into Russian. :-)
http://yantayga.livejournal.com/16281.html

──────────────────────────────────
Appendix: a working implementation
──────────────────────────────────

┌─────────────────┐
│ ExampleState.hs │
└─────────────────┘

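module ExampleState (State(..), evalState, execState, get, put) where

import Control.Monad (ap, liftM)

newtype State s a = State { runState :: s -> (a, s) }

evalState f s = fst $ runState f s  -- result value only
execState f s = snd $ runState f s  -- final state only

get   = State $ \s -> (s,  s)
put s = State $ \_ -> ((), s)

-- GHC 7.10 and later require Functor and Applicative instances for every
-- Monad; both are derived here from the Monad instance below.
instance Functor (State s) where
  fmap = liftM

instance Applicative (State s) where
  pure a = State $ \s -> (a, s)
  (<*>)  = ap

-- return defaults to pure.
instance Monad (State s) where
  m >>= k = State $ \s ->
    let (a, s') = runState m s
    in  runState (k a) s'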

┌─────────────┐
│ Examples.hs │
└─────────────┘

module Examples (example0, example1) where

import ExampleState

-- for example1 (System.Random comes from the ‘random’ package)
import Control.Monad
import System.Random

example0 = runState f 42 where
  f = do
    s <- get
    put (s+1)
    return "foo"

example1 = evalState f (mkStdGen 0) where
  f = do
    a <- die 6        -- throw 1d6
    b <- 2 `dice` 6   -- throw 2d6
    c <- 6 `dice` 10  -- throw 6d10
    return [a, b, c]

  -- randomR conveniently follows the ‘s -> (a, s)’ pattern.
  die sides = State $ randomR (1, sides :: Integer)

  n `dice` sides = sum `liftM` replicateM n (die sides)
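
One more usage example, as a sketch: numbering the elements of a list with a
counter kept in the state. To keep it loadable on its own, it repeats a
compact copy of the State definitions (including the Functor and Applicative
instances that current GHC versions require) instead of importing
ExampleState. The names ‘next’, ‘number’ and ‘example2’ are hypothetical and
not part of the modules above.

```haskell
import Control.Monad (ap, liftM)

newtype State s a = State { runState :: s -> (a, s) }

instance Functor (State s) where
  fmap = liftM

instance Applicative (State s) where
  pure a = State $ \s -> (a, s)
  (<*>)  = ap

instance Monad (State s) where
  m >>= k = State $ \s ->
    let (a, s') = runState m s
    in  runState (k a) s'

get :: State s s
get = State $ \s -> (s, s)

put :: s -> State s ()
put s = State $ \_ -> ((), s)

-- Hypothetical helper: return the current counter and increment it.
next :: State Int Int
next = do
  n <- get
  put (n + 1)
  return n

-- Pair every element with a running number taken from the state.
number :: [a] -> State Int [(Int, a)]
number = mapM (\x -> do { n <- next; return (n, x) })

-- Number the characters of "abc", starting the counter at 10.
example2 :: ([(Int, Char)], Int)
example2 = runState (number "abc") 10
```

‘example2’ evaluates to ‘([(10,'a'),(11,'b'),(12,'c')], 13)’: the counter is
threaded through mapM from left to right without ever being passed around
explicitly.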