───────────────
The State monad
───────────────

This article tries to be the kind of description Yours Truly would have loved
to find at the point of learning monads but still having trouble with the State
monad.

*Some* knowledge of the concept of monads is required for this to make any
sense. Try to understand simpler monads such as List and Maybe first. I’m also
using the Haskell syntax, but the concept can be used in any language that has
lambdas.

──────────────
The State type
──────────────

A computation that takes an initial state (of type ‘s’) and returns a tuple
consisting of the result value of the computation (of type ‘a’) and a new state
(of type ‘s’) can be expressed as a function of type ‘s -> (a, s)’.

For instance, ‘\s -> ("foo", s)’ would be a computation that returns ‘"foo"’ as
the result value and the state unchanged.

As another example, ‘\s -> (s, s+1)’ would be a computation that returns the
initial state as the result value and the state value incremented by one as the
new state.

For the purposes of abstraction we’ll wrap these computations in the following
type:

  newtype State s a = S (s -> (a, s))

─────────────────────
The runState function
─────────────────────

The runState function takes a state computation and an initial state as
parameters. It unwraps the computation from the State type and simply applies
it to the initial state.

  runState :: State s a -> s -> (a, s)
  runState (S f) s = f s

─────────────────────────────────
The functions return, get and put
─────────────────────────────────

A number of useful functions for generating state computations can be defined
(and are in Haskell’s State module). Let’s look at the most important ones.

Running with the initial state of 42:

  expression                 │ returns
  ───────────────────────────┼────────────
  runState (return "foo") 42 │ ("foo", 42)
  runState (get) 42          │ (42, 42)
  runState (put 43) 42       │ ((), 43)

‘return "foo"’ generates a computation that returns ‘"foo"’ as the result value
and the state unchanged.
It is implemented as follows:

  return a = …

We’ll wrap the computation in the State type by doing ‘S (computation)’, or
using the alternative syntax ‘S $ computation’ to avoid parentheses.

  return a = S $ …

As always, the computation is a function that takes the initial state as a
parameter…

  return a = S $ \s -> …

…and returns a tuple consisting of the result value and the new state. In this
case, the result value is the parameter passed to return (‘a’), and the new
state is the initial state unchanged (‘s’).

  return a = S $ \s -> (a, s)

‘get’ generates a computation that returns the unchanged state as both the
result value and the new state.

  get = S $ \s -> (s, s)

‘put 43’ generates a computation that ignores the initial state and returns the
unit value ‘()’ as the result value (there is nothing meaningful to return) and
‘43’ as the new state.

  put s = S $ \_ -> ((), s)

────────────────────────────────────────────────
Sequencing state computations: the bind operator
────────────────────────────────────────────────

For combining state computations, we’ll want the >>= (bind) operator to return
a new state computation that, given an initial state,

  0. runs the left-hand side computation,

  1. passes its result value to the right-hand side (in
     ‘return "foo" >>= \a -> …’, ‘a’ should be ‘"foo"’; in ‘get >>= \s -> …’,
     ‘s’ should be the state),

  2. runs the right-hand side computation, passing the new state resulting
     from the first computation as its initial state, and

  3. returns the result value and the final state returned by the second
     computation.

It should behave as follows:

  runState f 42
    where
      f = do
        s <- get
        put (s+1)
        return "foo"

…being syntactic sugar for…

  runState (get >>= \s -> put (s+1) >>= \_ -> return "foo") 42

…should result in ‘("foo", 43)’.

As >>= should return a new state computation, it should begin just like the
functions defined earlier, that is, by wrapping a function of type
‘s -> (a, s)’ in the State type.

  m >>= k = S $ \s -> …

‘m’ is the first state computation, and ‘k’ takes its evaluated result value
and returns a second state computation.

  m :: State s a
  k :: a -> State s b

We should start by running ‘m’, the left-hand side parameter of the >>= call:

  m >>= k = S $ \s -> let (a, s') = runState m s
                      in  …

  value of m   │ resulting a │ resulting s'
  ─────────────┼─────────────┼───────────────
  return "foo" │ "foo"       │ s (unchanged)
  get          │ s           │ s (unchanged)
  put 43       │ ()          │ 43

The right-hand side parameter ‘k’ expects to get ‘a’, the result value of the
first computation, as a parameter. Evaluating ‘k a’ returns the second state
computation (‘n’).

  m >>= k = S $ \s -> let (a, s') = runState m s
                          n       = k a
                      in  …

“s'” is the new state which we’ll need to give as the initial state for the
second computation ‘n’.

  m >>= k = S $ \s -> let (a, s')  = runState m s
                          n        = k a
                          (b, s'') = runState n s'
                      in  …

…or just…

  m >>= k = S $ \s -> let (a, s')  = runState m s
                          (b, s'') = runState (k a) s'
                      in  …

‘b’ is now the result value of the right-hand side computation, and “s''” is
the final state.

As “(b, s'')” – the final result value and the final state from sequencing the
two state computations – is exactly what we’ll need to return, we finally
arrive at the complete definition of >>=:

  m >>= k = S $ \s -> let (a, s') = runState m s
                      in  runState (k a) s'

─────────
Afterword
─────────

I’ll appreciate any corrections to the errors this article may contain.

Johan Kiviniemi, 2011

2011-01-31: Yan Tayga seems to have translated this into Russian.
:-)
http://yantayga.livejournal.com/16281.html

──────────────────────────────────
Appendix: a working implementation
──────────────────────────────────

┌─────────────────┐
│ ExampleState.hs │
└─────────────────┘

  module ExampleState (State(..), evalState, execState, get, put) where

  import Control.Monad (ap, liftM)

  newtype State s a = State { runState :: s -> (a, s) }

  evalState f s = fst $ runState f s
  execState f s = snd $ runState f s

  get   = State $ \s -> (s, s)
  put s = State $ \_ -> ((), s)

  -- GHC 7.10 and later require Functor and Applicative instances for every
  -- Monad; both are derived from >>= here.
  instance Functor (State s) where
    fmap = liftM

  instance Applicative (State s) where
    pure a = State $ \s -> (a, s)  -- the article’s ‘return’
    (<*>)  = ap

  instance Monad (State s) where
    return  = pure
    m >>= k = State $ \s -> let (a, s') = runState m s
                            in  runState (k a) s'

┌─────────────┐
│ Examples.hs │
└─────────────┘

  module Examples (example0, example1) where

  import ExampleState

  -- for example1 (System.Random comes from the ‘random’ package)
  import Control.Monad
  import System.Random

  example0 = runState f 42
    where
      f = do
        s <- get
        put (s+1)
        return "foo"

  example1 = evalState f (mkStdGen 0)
    where
      f = do
        a <- die 6         -- throw 1d6
        b <- 2 `dice` 6    -- throw 2d6
        c <- 6 `dice` 10   -- throw 6d10
        return [a, b, c]

  -- randomR conveniently follows the ‘s -> (a, s)’ pattern.
  die sides = State $ randomR (1, sides :: Integer)

  n `dice` sides = sum `liftM` replicateM n (die sides)
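As a cross-check of the bind walkthrough in the main text, here is a minimal
single-file sketch that follows the article’s ‘S’-wrapped definitions without
going through the Monad class at all. The names ‘unit’ and ‘bind’ are
hypothetical stand-ins for ‘return’ and ‘>>=’, chosen only to avoid clashing
with the Prelude; they are not part of any library. ‘example’ is the article’s
do block desugared by hand.

```haskell
-- Same shape as the article's State type: S wraps a function s -> (a, s).
newtype State s a = S (s -> (a, s))

-- Unwrap the computation and apply it to the initial state.
runState :: State s a -> s -> (a, s)
runState (S f) s = f s

-- 'unit' is a stand-in name for the article's 'return' (an assumption made
-- here so that this file does not need a Monad instance).
unit :: a -> State s a
unit a = S $ \s -> (a, s)

get :: State s s
get = S $ \s -> (s, s)

put :: s -> State s ()
put s = S $ \_ -> ((), s)

-- 'bind' is a stand-in name for the article's '>>='.
bind :: State s a -> (a -> State s b) -> State s b
bind m k = S $ \s ->
  let (a, s') = runState m s   -- run the left-hand side
  in  runState (k a) s'        -- feed its result and new state to the right

-- The article's example, desugared by hand:
--   do { s <- get; put (s+1); return "foo" }
example :: State Int String
example = get `bind` \s -> put (s + 1) `bind` \_ -> unit "foo"
```

Loading the file into GHCi and evaluating ‘runState example 42’ should give the
same ‘("foo", 43)’ as the walkthrough: ‘get’ yields 42 as the result, ‘put’
replaces the state with 43, and ‘unit "foo"’ leaves that state untouched.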