Lecture 3.3: More PBT
Last class, we talked about the correctness of sorting algorithms. This class, we will see another case study: the correctness of traversals of binary trees.
Let's take a look at a type of binary tree in OCaml:
type entry = {
  key : string;
  value : int
}

type bintree =
  | Leaf of entry
  | Branch of bintree * bintree
This binary tree stores two things at its leaves: a key, given by a string; and a value, given by an int. We store these two fields in a record called entry for convenience.
As an example, this binary tree could be storing account balances: the key is the account name, while the value is the amount of money in the account.
If we are the bank, we could be concerned with how much money is in all accounts that start with "A" (think of "A" here as a tag of a certain kind of account). We could do it directly:
let is_tagged (s : string) =
  if String.length s >= 1 then
    String.get s 0 = 'A'
  else false

let rec sum_tagged (t : bintree) =
  match t with
  | Leaf e -> if is_tagged e.key then e.value else 0
  | Branch (t1, t2) -> sum_tagged t1 + sum_tagged t2
(String.get has type string -> int -> char, where char is the type of a single character, such as 'A'.)
While this works, we find that it is very slow, because the total set of accounts is large (say, 20 gigabytes).
However, we notice that the set of accounts tagged with 'A' is actually small: maybe only a few thousand!
This gives us an idea: let's precompute the 'A' accounts, store them to the side, and then sum those up separately.
let rec collect_tagged (t : bintree) : entry list =
  match t with
  | Leaf e -> if is_tagged e.key then [e] else []
  | Branch (t1, t2) -> collect_tagged t1 @ collect_tagged t2

let rec sum_collected (xs : entry list) =
  match xs with
  | [] -> 0
  | x::xs' -> x.value + sum_collected xs'
While we are still traversing the whole data structure during collect_tagged, we only need to do this once. Once we have the list of entries tagged with 'A', we can store it on disk and then cheaply compute statistics over it (such as the sum).
How can we reason about the correctness of our optimization? We need to specify it, by stating that the optimized code is equivalent to the original:
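Written out as a formula, the specification presumably reads as follows. This is a reconstruction from the function names defined above, not a quotation from the original notes:

```latex
\forall t : \texttt{bintree}.\quad
  \texttt{sum\_tagged}\ t \;=\; \texttt{sum\_collected}\ (\texttt{collect\_tagged}\ t)
```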
Now that we have a quantified formula we want the truth value of, we can use PBT to see if it's actually true.
The first step of applying PBT is to turn the universal quantifier into a random sampling, so let's write a generator for bintree:
let tags = ["A_alice"; "B_dave"; "C_carl"; "A_amy"; "B_bob"; "A_anna"; "C_cara"; "A_arthur"; "B_brian"; "C_cindy"]

let gen_entry () =
  let k = List.nth tags (Random.int (List.length tags)) in
  let v = Random.int 500 in
  { key = k; value = v }

let rec gen_bintree (n : int) = fun () ->
  if n <= 0 then Leaf (gen_entry ()) else
  if Random.int 10 < 7 then
    Branch (gen_bintree (n - 1) (), gen_bintree (n - 1) ())
  else Leaf (gen_entry ())
The generator gen_bintree is a little smarter than a previous generator we've seen for binary trees. Instead of generating balanced binary trees, this generator will generate imbalanced ones (where one branch might hold more leaves than another).
We do this by rolling a die from 0..9, and asking if the result is less than 7. Thus, 7/10 times, we will continue with a branch. But 3/10 times, we will stop early with a leaf.
Thus, if we had code that only worked with balanced trees, but did something wrong with an imbalanced one, we would have a good chance of catching this bug.
Next, we want to implement our specification, for each t:
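One way to write this specification as an executable predicate is sketched below. The name spec_sum is an assumption made for this example (it does not appear in the original notes); the other definitions are repeated from earlier in the lecture, lightly condensed, so that the block runs on its own.

```ocaml
(* Definitions repeated from earlier in the lecture. *)
type entry = { key : string; value : int }

type bintree =
  | Leaf of entry
  | Branch of bintree * bintree

let is_tagged (s : string) =
  String.length s >= 1 && String.get s 0 = 'A'

let rec sum_tagged (t : bintree) =
  match t with
  | Leaf e -> if is_tagged e.key then e.value else 0
  | Branch (t1, t2) -> sum_tagged t1 + sum_tagged t2

let rec collect_tagged (t : bintree) : entry list =
  match t with
  | Leaf e -> if is_tagged e.key then [e] else []
  | Branch (t1, t2) -> collect_tagged t1 @ collect_tagged t2

let rec sum_collected (xs : entry list) =
  match xs with
  | [] -> 0
  | x :: xs' -> x.value + sum_collected xs'

(* Hypothetical name: the specification, checked for a single tree t. *)
let spec_sum (t : bintree) : bool =
  sum_tagged t = sum_collected (collect_tagged t)
```

The predicate takes one tree and returns a bool, which is the shape a sampling loop needs for its property argument.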
Finally, we will pull in the logic from last class for repeatedly calling the sampler to evaluate the specification:
let rec forall_bintree (gen : unit -> bintree) (prop : bintree -> bool) (trials : int) : bintree option =
  if trials <= 0 then
    None (* No counter-example found *)
  else
    let xs = gen () in
    if prop xs then
      forall_bintree gen prop (trials - 1)
    else Some xs
At this point, we can see that our specification actually holds:
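A sketch of what that check could look like, written as a standalone program rather than a utop session. The predicate name spec_sum, the depth 8, and the 32 trials are assumptions chosen for this example (the last two mirror the later checks in this lecture), and the tag list is shortened; the remaining definitions are condensed from the ones above.

```ocaml
type entry = { key : string; value : int }
type bintree = Leaf of entry | Branch of bintree * bintree

let is_tagged s = String.length s >= 1 && String.get s 0 = 'A'

let rec sum_tagged t =
  match t with
  | Leaf e -> if is_tagged e.key then e.value else 0
  | Branch (t1, t2) -> sum_tagged t1 + sum_tagged t2

let rec collect_tagged t =
  match t with
  | Leaf e -> if is_tagged e.key then [e] else []
  | Branch (t1, t2) -> collect_tagged t1 @ collect_tagged t2

let rec sum_collected xs =
  match xs with
  | [] -> 0
  | x :: xs' -> x.value + sum_collected xs'

(* Shortened tag list, for illustration only. *)
let tags = ["A_alice"; "B_dave"; "C_carl"; "A_amy"; "B_bob"]

let gen_entry () =
  { key = List.nth tags (Random.int (List.length tags));
    value = Random.int 500 }

let rec gen_bintree (n : int) = fun () ->
  if n <= 0 then Leaf (gen_entry ())
  else if Random.int 10 < 7 then
    Branch (gen_bintree (n - 1) (), gen_bintree (n - 1) ())
  else Leaf (gen_entry ())

let rec forall_bintree gen prop trials =
  if trials <= 0 then None (* no counterexample found *)
  else
    let t = gen () in
    if prop t then forall_bintree gen prop (trials - 1) else Some t

(* Hypothetical name for the specification on a single tree. *)
let spec_sum t = sum_tagged t = sum_collected (collect_tagged t)

(* None means that no counterexample turned up in 32 random trials. *)
let result = forall_bintree (gen_bintree 8) spec_sum 32
```

A result of None is evidence that the specification holds, not a proof: it only says that none of the sampled trees violated it.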
Working with averages, floating point
Great! Now, let's do a similar spec for computing the average.
let rec average_tagged (t : bintree) : float =
  match t with
  | Leaf e -> if is_tagged e.key then float_of_int e.value else 0.0
  | Branch (t1, t2) -> (average_tagged t1 +. average_tagged t2) /. 2.0

let average_collected (xs : entry list) =
  if List.length xs = 0 then 0.0 else
  (float_of_int (sum_collected xs)) /. (float_of_int (List.length xs))
In the first function, we compute the average recursively: the average of a leaf is its value (or 0.0 if the leaf is not tagged), while the average of a branch is computed by recursively computing the averages of its two subtrees and averaging those. In the second function, we flatten the tree into a list and compute the average over that list. We are careful to avoid dividing by zero, by treating the empty list as a special case.
Is this right? Let's see:
utop # forall_bintree (gen_bintree 1) (fun t -> average_tagged t = average_collected (collect_tagged t)) 32;;
- : bintree option =
Some
(Branch (Leaf {key = "A_alice"; value = 277},
Leaf {key = "B_dave"; value = 252}))
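Oops! What went wrong? The recursive computation divides by two even if one of the two subtrees has no tagged entries, while the flattened computation only averages over the entries tagged with 'A'. In the counterexample, average_tagged returns (277.0 + 0.0) / 2.0 = 138.5, while the flattened version returns 277.0. More generally, averaging the two subtree averages weights both subtrees equally, no matter how many tagged entries each one holds.
How can we fix average_tagged to be correct?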
let average_tagged' (t : bintree) : float =
  let rec sum_and_count (t : bintree) : int * int =
    match t with
    | Leaf e -> if is_tagged e.key then (e.value, 1) else (0, 0)
    | Branch (t1, t2) ->
      let (s1, c1) = sum_and_count t1 in
      let (s2, c2) = sum_and_count t2 in
      (s1 + s2, c1 + c2)
  in
  let (sum, count) = sum_and_count t in
  if count = 0 then 0.0 else (float_of_int sum) /. (float_of_int count)
We will compute the sum recursively along with a count of how many elements we want to average over. We can see if this works:
utop # forall_bintree (gen_bintree 8) (fun t -> average_collected (collect_tagged t) = average_tagged' t) 32;;
- : bintree option = None
Indeed, it does!
A Note about floating point
An earlier version of the code for this lecture forgot to check if we were dividing by zero:
let average_collected (xs : entry list) =
  (float_of_int (sum_collected xs)) /. (float_of_int (List.length xs))
This caused issues with our final theorem about average_tagged'. Let's see what went wrong:
utop # let t = (Branch (Leaf {key = "C_cindy"; value = 446},
                        Leaf {key = "B_brian"; value = 154}));;
val t : bintree = ...
utop # average_tagged' t;;
- : float = 0.
utop # average_collected (collect_tagged t);;
- : float = nan
While average_tagged' returned 0. (as expected), we see that the flattened version returned NaN ("not a number"), which is what we get from an invalid operation. Here, the list of tagged entries is empty, so both the sum and the length are zero, and we computed 0.0 /. 0.0.
The other "invalid state" we can get is infinity. For example, if we divide a nonzero float by zero, we get these issues:
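A few illustrative cases, using the float constants infinity, neg_infinity, and nan and the function classify_float from OCaml's standard library. The variable names are made up for this sketch.

```ocaml
(* Dividing a nonzero float by zero gives an infinity, not an exception. *)
let pos = 1.0 /. 0.0        (* infinity *)
let neg = (-1.0) /. 0.0     (* neg_infinity *)

(* Dividing zero by zero is an invalid operation and gives nan. *)
let bad = 0.0 /. 0.0

let () =
  assert (pos = infinity);
  assert (neg = neg_infinity);
  assert (classify_float bad = FP_nan);
  (* nan is not equal to anything, not even to itself. *)
  assert (not (bad = bad))
```

The last assertion is why a property that compares two nan results with = reports a counterexample even when both sides "agree".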
It is very easy to run into issues when using floating point! Another issue that can easily come up has to do with numerical error:
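The classic illustration, assuming the IEEE 754 double-precision representation that OCaml's float type uses: decimal fractions such as 0.1 have no exact binary representation, so sums pick up small rounding errors, and the grouping of additions can change the result. The variable names are made up for this sketch.

```ocaml
(* 0.1 +. 0.2 is not exactly 0.3 in double precision. *)
let a = 0.1 +. 0.2
let exact = (a = 0.3)                  (* false *)

(* Float addition is not associative: grouping changes the rounding. *)
let left = (0.1 +. 0.2) +. 0.3
let right = 0.1 +. (0.2 +. 0.3)
let same = (left = right)              (* false *)

(* Print both values with enough digits to see the difference. *)
let () = Printf.printf "%.17g vs %.17g\n" a 0.3
```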
While our property-based tests went through fine in our case, we got lucky. We should really not be testing floats for equality, but checking whether they are close enough:
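A minimal sketch of such a comparison. The name close_enough and the default tolerance 1e-9 are assumptions made for this example; a sensible tolerance depends on the magnitudes and the amount of arithmetic involved.

```ocaml
(* True when a and b differ by at most eps, scaled by their magnitude.
   The floor of 1.0 means values near zero use an absolute tolerance. *)
let close_enough ?(eps = 1e-9) (a : float) (b : float) : bool =
  let scale = max 1.0 (max (abs_float a) (abs_float b)) in
  abs_float (a -. b) <= eps *. scale

let ok = close_enough (0.1 +. 0.2) 0.3     (* true *)
let far = close_enough 1.0 1.1             (* false *)

(* Comparisons involving nan are always false, so nan is never "close". *)
let nan_case = close_enough nan nan        (* false *)
```

With a helper like this, the earlier property could be phrased as fun t -> close_enough (average_collected (collect_tagged t)) (average_tagged' t) instead of using = on floats.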