Memoization with StateM
Memoization is a technique for speeding up programs by storing the results of expensive function calls and returning the cached result when the same inputs occur again. In Lean, one way to implement it is to keep the cache in a HashMap and thread it through the computation with the StateM monad. Each step can read the cache with get and extend it with modify, so the map never has to be passed around by hand.
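As a minimal sketch of this lookup-or-compute pattern in isolation, the helper below keeps a HashMap Nat Nat as the StateM state. The names cached and demo are hypothetical and not part of any library, and the sketch assumes a recent Lean 4 toolchain where Std.HashMap ships with Lean.

```lean
import Std.Data.HashMap   -- assumes a Lean 4 toolchain that ships Std.HashMap
open Std (HashMap)

-- Hypothetical helper (not a library function): return the cached value for
-- `key` if present; otherwise compute `f key`, store it in the state, return it.
def cached (f : Nat → Nat) (key : Nat) : StateM (HashMap Nat Nat) Nat := do
  match (← get).get? key with
  | some v => return v              -- cache hit: no recomputation
  | none =>
    let v := f key                  -- cache miss: compute once
    modify fun m => m.insert key v  -- and remember the result
    return v

-- Hypothetical demo: look up the same key twice, then report the cache size.
def demo : StateM (HashMap Nat Nat) (Nat × Nat × Nat) := do
  let a ← cached (fun x => x * 2) 21
  let b ← cached (fun x => x * 2) 21
  return (a, b, (← get).size)

#eval Id.run (demo.run' {})  -- (42, 42, 1)
```

The second lookup finds the key already stored, so the cache still holds a single entry.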
Below is an example of the Fibonacci sequence implemented with memoization.
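import Std.Data.HashMap   -- assumes a recent Lean 4 toolchain that ships Std.HashMap
open Std (HashMap)

abbrev FibState := HashMap Nat Nat   -- the cache: n ↦ fib n
abbrev FibM := StateM FibState       -- computations that carry the cache

/--
Recursive Fibonacci with memoization.
We use `StateM FibState` (that is, `FibM`) to carry the cache.
Base cases are fib 0 = fib 1 = 1, so fib n equals the standard F(n+1).
-/
def fib (n : Nat) : FibM Nat := do
  match n with
  | 0 => return 1
  | 1 => return 1
  | k + 2 =>
    let m ← get
    match m.get? n with  -- check whether fib n was computed before
    | some v => return v
    | none =>
      let v1 ← fib k        -- compute missing values; this also fills the cache
      let v2 ← fib (k + 1)
      let v := v1 + v2
      modify (fun m => m.insert n v)  -- store fib n in the cache
      return v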
/-
- `run s`: runs the computation from the initial state `s` and
  returns the pair (result, final state).
- `run' s`: runs it from the initial state `s` and returns only the result.
Here `{}` is the empty HashMap, i.e. an empty cache.
-/
#eval fib 350 |>.run' {}
#eval fib 50 |>.run {}
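To see the difference between run and run' in isolation, here is a minimal sketch whose state is a plain Nat counter instead of a cache. The actions tick and threeTicks are hypothetical illustrations, not part of the original example.

```lean
-- Hypothetical counter action: return the current state, then increment it.
def tick : StateM Nat Nat := do
  let n ← get
  set (n + 1)
  return n

-- Hypothetical action that ticks three times and collects the results.
def threeTicks : StateM Nat (List Nat) := do
  let a ← tick
  let b ← tick
  let c ← tick
  return [a, b, c]

#eval Id.run (threeTicks.run 10)   -- ([10, 11, 12], 13): result and final state
#eval Id.run (threeTicks.run' 10)  -- [10, 11, 12]: result only
```

Applied to the Fibonacci example, the second component returned by fib 50 |>.run {} is the final cache, which can be inspected or passed as the initial state of a later run so that previously computed values are reused.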
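With memoization, each fib k for k ≥ 2 is computed once and then read from the cache, so fib 350 returns almost instantly. A naive recursive implementation recomputes the same subproblems over and over; its number of calls grows exponentially with n, so fib 350 would not finish in any practical amount of time.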
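To make the gap concrete, the sketch below instruments a naive version with a StateM Nat call counter. The name fibNaive is hypothetical; it uses the same base cases as fib above, so both compute the same values, but nothing is cached here.

```lean
-- Hypothetical naive version with the same base cases as fib
-- (fibNaive 0 = fibNaive 1 = 1). The StateM Nat state is only a call
-- counter, so nothing is cached and every subproblem is recomputed.
def fibNaive (n : Nat) : StateM Nat Nat := do
  modify fun calls => calls + 1   -- count this call
  match n with
  | 0 => return 1
  | 1 => return 1
  | k + 2 => return (← fibNaive k) + (← fibNaive (k + 1))

-- (result, number of calls); the call count is 2 * fibNaive n - 1
#eval Id.run ((fibNaive 10).run 0)  -- (89, 177)
#eval Id.run ((fibNaive 20).run 0)  -- (10946, 21891)
```

Because the call count is twice the result minus one, it grows as fast as the Fibonacci numbers themselves. The memoized fib, by contrast, inserts one cache entry per index from 2 to n, so its number of calls grows only linearly in n.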