Lecture 5.3: Type Invariants
Last class, we gained more practice with abstract types: how to write them using modules and module types, and how we can model various data structures and existing concepts using abstract types. Today, we will see how abstract types are useful for program correctness. Recall our slogan: "abstract types prevent users from doing bad things". Is there a way to make this more precise? What is a "bad thing"?
Let's think of a simple example first.
Suppose you wanted a datatype of "even integers". (For example, these are the quantities you can always split into two equal parts, which is useful if you are working with resources that should be divided between two processors.) We can do this by using abstract types: we will have a type even, and we will know, by design, that every value of type even really represents an even integer.
OK, suppose we have this type even. What operations can we do with it? That is, what is its API?
We should certainly be able to create even integers (say, starting from 0 and 2). And we should certainly be able to add even integers together, since the sum of two even integers is still even. But we definitely can't add an even integer to an arbitrary integer, since the result might not be even.
We can translate that into a module type by including the allowed operations, and not including the ones we shouldn't be able to do:
module type EvenSig = sig
  type t
  val zero : t
  val two : t
  val add : t -> t -> t
  val mult : t -> int -> t
  val get : t -> int
end
To create a t, there are two base cases: zero and two. Then, we can add values of type t together, and we can also multiply a t by an arbitrary integer (since that preserves the property of being even). We also have a function get, for turning a t back into an int.
Now that we have the module type, we can go back to our main question: what is the "bad thing" that we want to prevent users from doing?
We want to make sure that if x : t, then get x is never odd. Turned around: we want an invariant (something that should always be true) that if x : t, then (get x) mod 2 = 0.
Assuming that the functions are implemented in the ways we expect, this invariant should be true.
Now, let's implement EvenSig. To do this, we have to choose a type for t. Since this type is meant to encode even integers, we can just let the type be int itself:
module Even : EvenSig = struct
  type t = int
  let zero = 0
  let two = 2
  let add x y = x + y
  let mult x y = x * y
  let get i = i
end
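To see what the abstraction buys us, here is a small usage sketch. It repeats EvenSig and Even so that the snippet stands on its own; the names six and thirty are made up for illustration only.

```ocaml
module type EvenSig = sig
  type t
  val zero : t
  val two : t
  val add : t -> t -> t
  val mult : t -> int -> t
  val get : t -> int
end

module Even : EvenSig = struct
  type t = int
  let zero = 0
  let two = 2
  let add x y = x + y
  let mult x y = x * y
  let get i = i
end

(* Clients can only build values of type Even.t through the API. *)
let six : Even.t = Even.add Even.two (Even.mult Even.two 2)
let thirty : Even.t = Even.mult six 5

(* Because Even.t is abstract outside the module, lines like the following
   are rejected by the type checker, even though Even.t is implemented as int:

     let bad = Even.add Even.two 3
     let odd : Even.t = 7
*)
```

The only way to get an int back out is Even.get, so every int a client observes has passed through the operations listed in EvenSig.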
Now, how can we check that the invariant always holds? There are two options:
Checking Invariants via Traces
One way is to use property-based testing. Let's randomly create values of type Even.t, and see if they actually are even.
let rec mk_even (n : int) : Even.t =
  if n <= 0 then
    (* Base cases *)
    if Random.bool () then Even.zero else Even.two
  else
    if Random.bool () then
      Even.add (mk_even (n - 1)) (mk_even (n - 1))
    else
      Even.mult (mk_even (n - 1)) (Random.int 50)

(* Should always return true, for a valid implementation of Even *)
let check_even (n : int) : bool =
  let e : Even.t = mk_even n in
  (Even.get e) mod 2 = 0
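To see the test actually catch something, we can run it against a deliberately broken implementation. The following is only a sketch: BadEven, the Check functor, and count_failures are hypothetical names introduced here, and the generator is the same as mk_even above but parameterized over the implementation.

```ocaml
module type EvenSig = sig
  type t
  val zero : t
  val two : t
  val add : t -> t -> t
  val mult : t -> int -> t
  val get : t -> int
end

module Even : EvenSig = struct
  type t = int
  let zero = 0
  let two = 2
  let add x y = x + y
  let mult x y = x * y
  let get i = i
end

(* Hypothetical buggy implementation: the base case two is off by one. *)
module BadEven : EvenSig = struct
  type t = int
  let zero = 0
  let two = 3
  let add x y = x + y
  let mult x y = x * y
  let get i = i
end

(* The generator and check from above, for any implementation E. *)
module Check (E : EvenSig) = struct
  let rec mk_even (n : int) : E.t =
    if n <= 0 then
      if Random.bool () then E.zero else E.two
    else if Random.bool () then
      E.add (mk_even (n - 1)) (mk_even (n - 1))
    else
      E.mult (mk_even (n - 1)) (Random.int 50)

  let check_even (n : int) : bool =
    E.get (mk_even n) mod 2 = 0

  (* Count how many of the random trials violate the invariant. *)
  let count_failures ~(trials : int) ~(depth : int) : int =
    let failures = ref 0 in
    for _trial = 1 to trials do
      if not (check_even depth) then incr failures
    done;
    !failures
end

module CheckGood = Check (Even)
module CheckBad = Check (BadEven)

let good_failures = CheckGood.count_failures ~trials:1000 ~depth:3
let bad_failures = CheckBad.count_failures ~trials:1000 ~depth:0
```

For Even, no trial can fail. For BadEven, each depth-0 trial picks the bad base case about half the time, so the bug shows up almost immediately. Note that a passing run only says that no counterexample was found among the values we happened to generate.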
Checking Invariants via Dynamic Checks
Instead of doing property-based testing, we can have our functions throw an exception if they ever produce a "bad" value of type t:
module EvenChecked : EvenSig = struct
  type t = int
  (* Run f, then check that its result satisfies the invariant. *)
  let with_check (f : unit -> t) : t =
    let x = f () in
    if x mod 2 = 0 then x else failwith "INTERNAL ERROR: Invariant does not hold!"
  let zero = with_check (fun () -> 0)
  let two = with_check (fun () -> 2)
  let add x y = with_check (fun () -> x + y)
  let mult x y = with_check (fun () -> x * y)
  let get i = i
end
If our invariant actually does hold, then we won't see the internal error, and EvenChecked will behave exactly the same as Even. But if there is a bug, the check fails as soon as an operation produces an odd result.
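As a quick illustration of the dynamic check firing, here is a sketch of a checked module with a bug planted in add. BuggyChecked and bug_detected are hypothetical names used only for this example.

```ocaml
module type EvenSig = sig
  type t
  val zero : t
  val two : t
  val add : t -> t -> t
  val mult : t -> int -> t
  val get : t -> int
end

(* Hypothetical buggy variant of the checked module: add is off by one. *)
module BuggyChecked : EvenSig = struct
  type t = int
  let with_check (f : unit -> t) : t =
    let x = f () in
    if x mod 2 = 0 then x else failwith "INTERNAL ERROR: Invariant does not hold!"
  let zero = with_check (fun () -> 0)
  let two = with_check (fun () -> 2)
  let add x y = with_check (fun () -> x + y + 1) (* planted bug *)
  let mult x y = with_check (fun () -> x * y)
  let get i = i
end

(* The very first call to add trips the check and raises Failure. *)
let bug_detected : bool =
  match BuggyChecked.add BuggyChecked.two BuggyChecked.two with
  | _ -> false
  | exception Failure _ -> true
```

The bad value never escapes to the client: the exception is raised inside add, at the point where the invariant is first broken.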
Question
What are the pros and cons of using property-based testing (PBT) versus using a wrapper that can throw an exception?
Invariants on Data Structures
Using invariants, we can justify the correctness of more sophisticated data structures.
For example, let's think of a map from integers to strings. One option for this map is to use a simple association list of type (int * string) list, where a given key can be anywhere in the list. While this version is correct, it is inefficient: since the key can be anywhere, looking it up can potentially cause us to traverse the entire list. If we instead keep the list sorted by key, then we can stop traversing the list once we find a key larger than the one we want. Additionally, let's enforce that the list does not have duplicates for the same key, since duplicates make the list harder to reason about.
(For actual efficiency, one would use a binary search tree, like we saw in HW3; handling that case is more of the same.)
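For contrast, the unsorted association-list version mentioned above might look roughly like this. This is a sketch under our own naming (unsorted_map and the top-level empty, add, and lookup are not part of the lecture's modules).

```ocaml
(* Unsorted association list: the newest binding is consed onto the front. *)
type unsorted_map = (int * string) list

let empty : unsorted_map = []

(* Constant-time insert; an older binding for the same key is shadowed,
   not removed. *)
let add (k : int) (v : string) (m : unsorted_map) : unsorted_map =
  (k, v) :: m

(* In the worst case (for example, when the key is absent) this walks
   the entire list. *)
let lookup (k : int) (m : unsorted_map) : string option =
  List.assoc_opt k m

let m = add 3 "c" (add 1 "a" (add 2 "b" empty))
```

Nothing here relies on an invariant, which is why it is easy to get right. The sorted version trades that simplicity for the ability to stop early.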
First, let's come up with a signature for maps:
module type IntMap = sig
  type t
  val well_formed : t -> bool
  val empty : t
  val add : int -> string -> t -> t
  val remove : int -> t -> t
  val lookup : int -> t -> string option
end
In addition to the usual empty/add/remove/lookup, we also have a well_formed function, which has type t -> bool. This is generically baking in the notion of an invariant: it should always be true that well_formed x holds, for every value x of type t.
Now, let's see the implementation using sorted lists:
module SortedIntMap : IntMap = struct
  type t = (int * string) list

  (* Is the list sorted without duplicates? *)
  let rec well_formed xs =
    match xs with
    | [] -> true
    | [_] -> true
    | p1 :: p2 :: xs' -> fst p1 < fst p2 && well_formed (p2 :: xs')

  let empty = []

  (* Add is like "insert" from insertion sort, but we also ensure that there are no duplicates. *)
  let rec add i x xs =
    match xs with
    | [] -> [(i, x)]
    | p :: xs' ->
      (* If i is less than p's key, insert here. *)
      if i < fst p then (i, x) :: p :: xs'
      (* If i = p's key, then replace with the new element. *)
      else if i = fst p then (i, x) :: xs'
      (* Otherwise, keep inserting. *)
      else p :: add i x xs'

  let rec lookup i xs =
    match xs with
    | [] -> None
    | p :: xs' ->
      (* If we've found the key, we are done. *)
      if fst p = i then Some (snd p) else
      if fst p < i then lookup i xs' (* If i is bigger than the current key, keep looking. *)
      else None (* Otherwise, i is _less_ than the current key; since the list is sorted, we know that key i can't be in the list. *)

  let remove i xs =
    (* Could be made more efficient as well, similar to lookup and add. *)
    List.filter (fun p -> fst p <> i) xs
end
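The comment on remove says it could be made more efficient. One possible version, written here as a standalone function with the hypothetical name remove_sorted, stops early in the same way lookup does. It assumes its input already satisfies well_formed.

```ocaml
(* Remove key i from a list that is sorted by key with no duplicates. *)
let rec remove_sorted (i : int) (xs : (int * string) list) : (int * string) list =
  match xs with
  | [] -> []
  | p :: xs' ->
    if fst p = i then xs' (* no duplicates, so nothing more to remove *)
    else if fst p < i then p :: remove_sorted i xs' (* keep looking *)
    else xs (* every remaining key is larger than i, so i is absent *)

let example = [(1, "a"); (2, "b"); (3, "c")]
```

Like lookup, this version is only correct because of the invariant: on an unsorted list, or one with duplicate keys, it could leave a binding for i behind.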
The point here is that the correctness of add and lookup (and remove, if we optimized it) crucially depends on the invariant that the list is sorted without duplicates. Thus, we need to be certain that the operations that build maps (empty, add, and remove) preserve this invariant. In other words, well_formed should hold for every map these operations can produce. Let's check it, in the same way we did for the Even module:
let gen_string () = "abc" (* String doesn't need to be random, since invariant is about keys *)
let gen_key () = Random.int 50

let rec gen_map (n : int) : SortedIntMap.t =
  if n <= 0 then SortedIntMap.empty else
  let m = gen_map (n - 1) in
  if Random.bool () then
    SortedIntMap.add (gen_key ()) (gen_string ()) m
  else
    SortedIntMap.remove (gen_key ()) m

let test_map (n : int) : bool =
  let m = gen_map n in
  SortedIntMap.well_formed m
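As with Even, it helps to see the check reject a broken implementation. In the sketch below, ConsIntMap is a hypothetical buggy map whose add just conses onto the front, and TestMap is a hypothetical functor that packages the generator above for any implementation of IntMap.

```ocaml
module type IntMap = sig
  type t
  val well_formed : t -> bool
  val empty : t
  val add : int -> string -> t -> t
  val remove : int -> t -> t
  val lookup : int -> t -> string option
end

(* Hypothetical buggy implementation: add ignores the sorted order. *)
module ConsIntMap : IntMap = struct
  type t = (int * string) list
  let rec well_formed xs =
    match xs with
    | [] | [_] -> true
    | p1 :: p2 :: xs' -> fst p1 < fst p2 && well_formed (p2 :: xs')
  let empty = []
  let add i x xs = (i, x) :: xs (* planted bug: no sorting, no dedup *)
  let remove i xs = List.filter (fun p -> fst p <> i) xs
  let lookup i xs = List.assoc_opt i xs
end

(* The random test from above, for any implementation M. *)
module TestMap (M : IntMap) = struct
  let rec gen_map (n : int) : M.t =
    if n <= 0 then M.empty else
    let m = gen_map (n - 1) in
    if Random.bool () then M.add (Random.int 50) "abc" m
    else M.remove (Random.int 50) m

  let test_map (n : int) : bool = M.well_formed (gen_map n)
end

module TestCons = TestMap (ConsIntMap)

(* A deterministic counterexample: adding key 1 and then key 2 leaves
   the larger key in front of the smaller one. *)
let broken = ConsIntMap.add 2 "b" (ConsIntMap.add 1 "a" ConsIntMap.empty)
let violation_found = not (ConsIntMap.well_formed broken)
```

Running TestCons.test_map on a map of moderate size will very likely hit a similar counterexample at random, since any two adds in increasing key order break the invariant.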