Lecture 3.3: More PBT

Last class, we talked about the correctness of sorting algorithms. This class, we will see another case study: the correctness of traversals of binary trees.

Let's take a look at a type of binary tree in OCaml:

type entry = {
    key : string;
    value : int
}

type bintree = 
    | Leaf of entry
    | Branch of bintree * bintree

This binary tree, at its leaves, stores two things: a key, given by a string; and a value, given by an int. We store these two entries in a record called entry for convenience. As an example, this binary tree could be storing account balances: the key is the account name, while the value is the amount of money in the account.
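To make the type concrete, here is a small hand-built tree. The account names and balances are invented for illustration, and `example_tree` and `count_leaves` are names introduced for this sketch, not part of the lecture code.

```ocaml
type entry = {
    key : string;
    value : int
}

type bintree =
    | Leaf of entry
    | Branch of bintree * bintree

(* A hypothetical three-account tree; names and balances are invented. *)
let example_tree : bintree =
    Branch (
        Leaf { key = "A_alice"; value = 100 },
        Branch (
            Leaf { key = "B_bob"; value = 250 },
            Leaf { key = "A_amy"; value = 40 }))

(* Count the leaves, i.e. the number of accounts stored in the tree. *)
let rec count_leaves (t : bintree) : int =
    match t with
    | Leaf _ -> 1
    | Branch (t1, t2) -> count_leaves t1 + count_leaves t2

let () = assert (count_leaves example_tree = 3)
```

Since data lives only at the leaves and there is no empty constructor, every bintree holds at least one entry.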

If we are the bank, we could be concerned with how much money is in all accounts that start with "A" (think of "A" here as a tag of a certain kind of account). We could do it directly:

let is_tagged (s : string) = 
    if String.length s >= 1 then 
        String.get s 0 = 'A'
    else false

let rec sum_tagged (t : bintree) = 
    match t with
    | Leaf e -> if is_tagged e.key then e.value else 0
    | Branch (t1, t2) -> sum_tagged t1 + sum_tagged t2

(Above, String.get has type string -> int -> char, where char is the type of a single character, such as 'A'.)
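As a quick sanity check, the sketch below runs is_tagged and sum_tagged on a few inputs. The tree `sample` and its balances are invented for illustration.

```ocaml
type entry = { key : string; value : int }
type bintree = Leaf of entry | Branch of bintree * bintree

let is_tagged (s : string) =
    if String.length s >= 1 then String.get s 0 = 'A' else false

let rec sum_tagged (t : bintree) =
    match t with
    | Leaf e -> if is_tagged e.key then e.value else 0
    | Branch (t1, t2) -> sum_tagged t1 + sum_tagged t2

(* Invented sample data: two tagged accounts and one untagged. *)
let sample =
    Branch (Leaf { key = "A_alice"; value = 100 },
            Branch (Leaf { key = "B_bob"; value = 250 },
                    Leaf { key = "A_amy"; value = 40 }))

let () =
    (* The length check handles the empty key, so String.get never raises. *)
    assert (not (is_tagged ""));
    assert (is_tagged "A_alice");
    (* The tag is case-sensitive. *)
    assert (not (is_tagged "a_lower"));
    (* Only the two 'A' accounts contribute: 100 + 40. *)
    assert (sum_tagged sample = 140)
```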

While this works, we find that it is very slow, because the full set of accounts is large (say, 20 gigabytes). However, we notice that the set of accounts tagged with 'A' is actually small: maybe only a few thousand! This gives us an idea: let's precompute the 'A' accounts, store them to the side, and then sum those up separately.

let rec collect_tagged (t : bintree) : entry list = 
    match t with 
    | Leaf e -> if is_tagged e.key then [e] else []
    | Branch (t1, t2) -> collect_tagged t1 @ collect_tagged t2

let rec sum_collected (xs : entry list) = 
    match xs with
    | [] -> 0
    | x::xs' -> x.value + sum_collected xs'

While collect_tagged still traverses the whole data structure, we only need to run it once. Once we have stored the list of entries tagged with 'A' (for example, on disk), we can cheaply compute statistics over it, such as the sum.
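The sum is only one statistic we might compute from the precomputed list. As a sketch, here are two folds over a stand-in for the output of collect_tagged. The list `collected` and the helpers `sum_values` and `max_value` are names introduced here, and the data is invented.

```ocaml
type entry = { key : string; value : int }

(* Stand-in for the output of collect_tagged; the data is invented. *)
let collected : entry list =
    [ { key = "A_alice"; value = 100 };
      { key = "A_amy"; value = 40 };
      { key = "A_anna"; value = 310 } ]

(* Sum of the balances, written as a fold instead of explicit recursion. *)
let sum_values (xs : entry list) : int =
    List.fold_left (fun acc e -> acc + e.value) 0 xs

(* Largest balance, or None for an empty list. *)
let max_value (xs : entry list) : int option =
    match xs with
    | [] -> None
    | x :: rest ->
        Some (List.fold_left (fun m e -> max m e.value) x.value rest)

let () =
    assert (sum_values collected = 450);
    assert (max_value collected = Some 310);
    assert (max_value [] = None)
```

Each of these touches only the few tagged entries, not the full tree.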

How can we reason about the correctness of our optimization? We first need to specify it, by stating that the optimized code is equivalent to the original:

\[ \forall t \in \texttt{bintree}, \texttt{sum_tagged}(t) = \texttt{sum_collected}(\texttt{collect_tagged}(t)). \]

Now that we have a quantified formula, we can use PBT to search for counterexamples to it. The first step of applying PBT is to turn the universal quantifier into a random sampling, so let's write a generator for bintree:

let tags = ["A_alice"; "B_dave"; "C_carl"; "A_amy"; "B_bob"; "A_anna"; "C_cara"; "A_arthur"; "B_brian"; "C_cindy"]

let gen_entry () = 
    let k = List.nth tags (Random.int (List.length tags)) in 
    let v = Random.int 500 in 
    { key = k; value = v }

let rec gen_bintree (n : int) = fun () -> 
    if n <= 0 then Leaf (gen_entry ()) else 
     if Random.int 10 < 7 then 
        Branch (gen_bintree (n - 1) (), gen_bintree (n - 1) ())
    else Leaf (gen_entry ())

The generator gen_bintree is a little smarter than the previous generator we've seen for binary trees. Instead of generating balanced binary trees, it generates imbalanced ones (where one branch might hold more leaves than the other). We do this by rolling a die from 0..9, and asking if the result is less than 7. Thus, 7/10 times, we continue with a branch. But 3/10 times, we stop early with a leaf. The argument n bounds the depth, so generation always terminates. Thus, if we had code that only worked with balanced trees, but did something wrong with an imbalanced one, we would have a good chance of catching this bug.
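One way to see what shapes the generator produces is to measure them. The sketch below samples 200 trees and checks the depth bound. The helpers `size`, `height`, and `lopsided_at_root` are names introduced here, and the tag list is shortened to keep the example small.

```ocaml
type entry = { key : string; value : int }
type bintree = Leaf of entry | Branch of bintree * bintree

let tags = ["A_alice"; "B_dave"; "C_carl"]   (* shortened list for the sketch *)

let gen_entry () =
    let k = List.nth tags (Random.int (List.length tags)) in
    { key = k; value = Random.int 500 }

let rec gen_bintree (n : int) = fun () ->
    if n <= 0 then Leaf (gen_entry ()) else
    if Random.int 10 < 7 then
        Branch (gen_bintree (n - 1) (), gen_bintree (n - 1) ())
    else Leaf (gen_entry ())

(* Number of leaves in the tree. *)
let rec size t = match t with
    | Leaf _ -> 1
    | Branch (l, r) -> size l + size r

(* Length of the longest root-to-leaf path, counted in branches. *)
let rec height t = match t with
    | Leaf _ -> 0
    | Branch (l, r) -> 1 + max (height l) (height r)

(* True when the two children of the root hold different numbers of leaves. *)
let lopsided_at_root t = match t with
    | Leaf _ -> false
    | Branch (l, r) -> size l <> size r

let () =
    let samples = List.init 200 (fun _ -> gen_bintree 8 ()) in
    (* The bound n = 8 caps the height at 8, and hence the leaves at 2^8. *)
    assert (List.for_all (fun t -> height t <= 8 && size t <= 256) samples);
    let n_lopsided = List.length (List.filter lopsided_at_root samples) in
    Printf.printf "%d of 200 samples are lopsided at the root\n" n_lopsided
```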

Next, we want to implement our specification, for each t:

let sum_spec (t : bintree) = 
    sum_tagged t = sum_collected (collect_tagged t)

Finally, we will pull in the logic from last class for repeatedly calling the sampler to evaluate the specification:

let rec forall_bintree (gen : unit -> bintree) (prop : bintree -> bool) (trials : int) : bintree option =
    if trials <= 0 then     
        None (* No counter-example found *)
    else 
        let xs = gen () in 
        if prop xs then 
            forall_bintree gen prop (trials - 1)
        else Some xs

At this point, we can check our specification. No counterexample turns up in 100 trials, which is good evidence (though not a proof) that it holds:

utop # forall_bintree (gen_bintree 8) sum_spec 100;;
- : bintree option = None
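A result of None is only reassuring if the harness is able to fail. A common sanity check is to plant a bug on purpose and confirm that a counterexample comes back. In the sketch below, `collect_buggy` is a hypothetical broken variant of collect_tagged that drops every right subtree. The seed 42 is arbitrary, and the tag list is shortened.

```ocaml
type entry = { key : string; value : int }
type bintree = Leaf of entry | Branch of bintree * bintree

let is_tagged (s : string) = String.length s >= 1 && String.get s 0 = 'A'

let rec sum_tagged t = match t with
    | Leaf e -> if is_tagged e.key then e.value else 0
    | Branch (t1, t2) -> sum_tagged t1 + sum_tagged t2

(* Deliberately WRONG (hypothetical bug): the right subtree is dropped. *)
let rec collect_buggy t = match t with
    | Leaf e -> if is_tagged e.key then [e] else []
    | Branch (t1, _) -> collect_buggy t1

let rec sum_collected xs = match xs with
    | [] -> 0
    | x :: xs' -> x.value + sum_collected xs'

let tags = ["A_alice"; "B_dave"; "C_carl"; "A_amy"]

let gen_entry () =
    { key = List.nth tags (Random.int (List.length tags));
      value = Random.int 500 }

let rec gen_bintree n = fun () ->
    if n <= 0 then Leaf (gen_entry ()) else
    if Random.int 10 < 7 then
        Branch (gen_bintree (n - 1) (), gen_bintree (n - 1) ())
    else Leaf (gen_entry ())

let rec forall_bintree gen prop trials =
    if trials <= 0 then None
    else
        let t = gen () in
        if prop t then forall_bintree gen prop (trials - 1) else Some t

(* The same specification as sum_spec, but against the buggy collector. *)
let buggy_spec t = sum_tagged t = sum_collected (collect_buggy t)

let () =
    Random.init 42;   (* arbitrary fixed seed, so reruns sample the same trees *)
    match forall_bintree (gen_bintree 8) buggy_spec 100 with
    | Some _ -> print_endline "counterexample found: the harness catches the bug"
    | None -> print_endline "no counterexample found in 100 trials"
```

If a planted bug like this one slipped through, that would suggest the generator or the number of trials needs attention.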

Working with averages, floating point

Great! Now, let's do a similar spec for computing the average.

let rec average_tagged (t : bintree) : float = 
    match t with
    | Leaf e -> if is_tagged e.key then float_of_int e.value else 0.0
    | Branch (t1, t2) -> (average_tagged t1 +. average_tagged t2) /. 2.0

let average_collected (xs : entry list) =
    if List.length xs = 0 then 0.0 else 
        (float_of_int (sum_collected xs)) /. (float_of_int (List.length xs))

In the first function, we compute the average recursively: the average of a leaf is its value if the leaf is tagged (and 0.0 otherwise), while the average of a branch is computed by recursively computing the two subtree averages, and averaging those. In the second function, we flatten the tree into a list and compute the average over that list. We are careful to avoid dividing by zero, by treating the empty list as a special case.

Is this right? Let's see:

utop # forall_bintree (gen_bintree 1) (fun t -> average_tagged t = average_collected (collect_tagged t)) 32;;
- : bintree option =
Some
 (Branch (Leaf {key = "A_alice"; value = 277},
          Leaf {key = "B_dave"; value = 252}))

Oops! What went wrong? The recursive version divides by two at every branch, even if one of the two subtrees contains no tagged entries, while the flattened computation only averages over the entries tagged with 'A'. In the counterexample, average_tagged returns (277 + 0) / 2 = 138.5, while the flattened version returns 277 / 1 = 277. How can we fix average_tagged to be correct?
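The sketch below replays the counterexample and also tries a second tree, `all_tagged` (an invented example), in which every leaf is tagged. It suggests the recursive version has a further problem: an average of averages weights leaves by depth, so it can disagree with the flat average on an imbalanced tree even when nothing is filtered out. Here average_collected is written with a fold so that the block is self-contained.

```ocaml
type entry = { key : string; value : int }
type bintree = Leaf of entry | Branch of bintree * bintree

let is_tagged (s : string) = String.length s >= 1 && String.get s 0 = 'A'

let rec average_tagged (t : bintree) : float =
    match t with
    | Leaf e -> if is_tagged e.key then float_of_int e.value else 0.0
    | Branch (t1, t2) -> (average_tagged t1 +. average_tagged t2) /. 2.0

let rec collect_tagged (t : bintree) : entry list =
    match t with
    | Leaf e -> if is_tagged e.key then [e] else []
    | Branch (t1, t2) -> collect_tagged t1 @ collect_tagged t2

let average_collected (xs : entry list) : float =
    match xs with
    | [] -> 0.0
    | _ ->
        let sum = List.fold_left (fun acc e -> acc + e.value) 0 xs in
        float_of_int sum /. float_of_int (List.length xs)

(* The counterexample reported by the property-based test. *)
let cex =
    Branch (Leaf { key = "A_alice"; value = 277 },
            Leaf { key = "B_dave"; value = 252 })

(* Invented example: every leaf is tagged, but the tree is imbalanced. *)
let all_tagged =
    Branch (Leaf { key = "A_a"; value = 100 },
            Branch (Leaf { key = "A_b"; value = 200 },
                    Leaf { key = "A_c"; value = 300 }))

let () =
    (* These values are exactly representable, so (=) is safe here. *)
    assert (average_tagged cex = 138.5);
    assert (average_collected (collect_tagged cex) = 277.0);
    (* (100 + (200 + 300) / 2) / 2 = 175, but (100 + 200 + 300) / 3 = 200. *)
    assert (average_tagged all_tagged = 175.0);
    assert (average_collected (collect_tagged all_tagged) = 200.0)
```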

let average_tagged' (t : bintree) : float =
    let rec sum_and_count (t : bintree) : int * int =
        match t with
        | Leaf e -> if is_tagged e.key then (e.value, 1) else (0, 0)
        | Branch (t1, t2) ->
            let (s1, c1) = sum_and_count t1 in
            let (s2, c2) = sum_and_count t2 in
            (s1 + s2, c1 + c2)
    in
    let (sum, count) = sum_and_count t in
    if count = 0 then 0.0 else (float_of_int sum) /. (float_of_int count)

We will compute the sum recursively along with a count of how many elements we want to average over. We can see if this works:

utop # forall_bintree (gen_bintree 8) (fun t -> average_collected (collect_tagged t) = average_tagged' t) 32;;
- : bintree option = None

Indeed, it does!

A Note about floating point

An earlier version of the code for this lecture forgot to check if we were dividing by zero:

let average_collected (xs : entry list) =
        (float_of_int (sum_collected xs)) /. (float_of_int (List.length xs))

This caused issues with our final property about average_tagged'. Let's see what went wrong:

utop # let t = (Branch (Leaf {key = "C_cindy"; value = 446},
                        Leaf {key = "B_brian"; value = 154}));;
val t : bintree = ...

utop # average_tagged' t;;
- : float = 0.

utop # average_collected (collect_tagged t);;
- : float = nan

While average_tagged' returned 0. (as expected), the flattened version returned nan ("not a number"), which is the result of an invalid operation such as 0.0 /. 0.0. The other special value we can get is infinity. Dividing by zero produces one or the other, depending on the numerator:

utop # 3.0 /. 0.0;;
- : float = infinity

utop # 0.0 /. 0.0;;
- : float = nan
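These special values also behave oddly under comparison, which matters for a property that uses (=). The sketch below shows a few facts about nan and a hypothetical helper, `checked_div`, that reports a non-finite quotient as None. It uses classify_float from the standard library, and Float.is_nan, which is assumed to be available (it is in recent OCaml releases).

```ocaml
(* Hypothetical helper: divide, but report nan or infinity as None. *)
let checked_div (x : float) (y : float) : float option =
    let q = x /. y in
    match classify_float q with
    | FP_nan | FP_infinite -> None
    | FP_normal | FP_subnormal | FP_zero -> Some q

let () =
    let bad = 0.0 /. 0.0 in
    (* nan is not equal to anything under (=), not even itself... *)
    assert (not (bad = bad));
    (* ...so a property comparing two nan results with (=) would fail. *)
    assert (Float.is_nan bad);
    (* compare, unlike (=), treats nan as equal to itself. *)
    assert (compare bad bad = 0);
    assert (checked_div 3.0 0.0 = None);
    assert (checked_div 0.0 0.0 = None);
    assert (checked_div 6.0 3.0 = Some 2.0)
```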

It is very easy to run into issues when using floating point! Another issue that can easily come up has to do with numerical error:

utop # let x = 0.1 +. 0.2;;
val x : float = 0.300000000000000044

utop # x = 0.3 ;;
- : bool = false

While our property-based tests went through fine in our case, we got lucky. We should really not test floats for exact equality, but instead check whether they are close enough:

let equal_float (x : float) (y : float) = abs_float (x -. y) < 1e-10
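A quick check of equal_float follows, together with one possible variant. The absolute threshold 1e-10 suits values near 1, but can be too strict for very large values, where neighbouring floats are further apart. The function `close_float` below is a hypothetical alternative that scales the tolerance with the magnitude of its inputs; its default 1e-9 is an arbitrary choice for this sketch.

```ocaml
let equal_float (x : float) (y : float) = abs_float (x -. y) < 1e-10

(* Hypothetical variant: tolerance relative to the larger magnitude,
   with a floor of 1.0 so that values near zero still compare sensibly. *)
let close_float ?(rel = 1e-9) (x : float) (y : float) =
    let scale = max 1.0 (max (abs_float x) (abs_float y)) in
    abs_float (x -. y) <= rel *. scale

let () =
    (* Exact equality fails, but the tolerant comparison succeeds. *)
    assert (not (0.1 +. 0.2 = 0.3));
    assert (equal_float (0.1 +. 0.2) 0.3);
    (* nan is never close to anything, so a nan result still fails. *)
    assert (not (equal_float nan nan));
    (* At magnitude 1e15, a difference of 1.0 is tiny in relative terms. *)
    assert (not (equal_float 1e15 (1e15 +. 1.0)));
    assert (close_float 1e15 (1e15 +. 1.0));
    assert (not (close_float 1.0 2.0))
```

With either helper, the averaging property would compare the two results through the tolerant check instead of (=).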