10.3: Special Topic: Dimension-Safe Matrix Multiplication

Last class, we saw one example extension to our type system: handling even and odd integers. This allows us to, for example, do the following:

let x = 1 in // TNat odd
let y = 2 in // TNat even

let b = ... in // bool
if b then x + y else y // TNat any

The main purpose of this extension is to guarantee that if e has type TNat even, then when e evaluates to a value, it must be an even integer (and similarly for TNat odd).

Today, we will show you another extension to the type system: this time, about matrices, which are ubiquitous in a number of computing contexts, including scientific computing and machine learning. Recall that a matrix is a table of numbers of size \(n\) by \(m\), for some number of rows \(n\) and columns \(m\). The fundamental operation on matrices, for us, is matrix multiplication. Given an \(n\)-by-\(m\) matrix \(A\) and an \(m\)-by-\(k\) matrix \(B\), we define an \(n\)-by-\(k\) matrix \(AB\) like so:

\[(AB)_{x,y} = \sum_{i=1}^{m} A_{x,i} B_{i,y}\]

where \((AB)_{x, y}\) is the element of the matrix \(AB\) at row \(x\) and column \(y\).
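
As a small worked instance of this formula (the matrices below are chosen here purely for illustration), take \(n = m = k = 2\):

\[A = \begin{pmatrix} 1 & 2 \\ 3 & 4 \end{pmatrix}, \qquad B = \begin{pmatrix} 5 & 6 \\ 7 & 8 \end{pmatrix}\]

Each entry of \(AB\) pairs a row of \(A\) with a column of \(B\). For example:

\[(AB)_{1,1} = A_{1,1} B_{1,1} + A_{1,2} B_{2,1} = 1 \cdot 5 + 2 \cdot 7 = 19\]

Computing the remaining three entries the same way gives:

\[AB = \begin{pmatrix} 19 & 22 \\ 43 & 50 \end{pmatrix}\]

Note that the sum only makes sense because the number of columns of \(A\) equals the number of rows of \(B\).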

Prelude: Matrix Multiplication in OCaml

Before thinking about how we would implement this in our language, let's think about how we would do so in OCaml. We could implement the matrix as a list of lists of floats, along with its dimensions:

type matrix = {
    ncols : int;
    nrows : int;
    (* Stored as list of rows. The outer list should have nrows elements, 
       and the inner lists should have ncols elements. *)
    values : float list list 
}

Now, let's implement matrix multiplication. First, with some helper functions:

let sum_list (xs : float list) = 
  List.fold_left (+.) 0. xs

(* Builds an n-by-m matrix (n rows, m columns) whose entry at row i, column j is f i j. *)
let build_matrix (n : int) (m : int) (f : int -> int -> float) : matrix = 
  {nrows = n; ncols = m; 
    values = 
    List.init n (fun i -> 
        List.init m (fun j -> 
        f i j
        )
    )
  }

let get_matrix (i : int) (j : int) (m : matrix) : float =
 List.nth (List.nth m.values i) j

Now, for the main event:

(* Only works if a.ncols = b.nrows *)
let mmult (a : matrix) (b : matrix)  : matrix = 
  build_matrix a.nrows b.ncols (fun i j -> 
      let products = List.init a.ncols (fun x -> get_matrix i x a *. get_matrix x j b) in 
      sum_list products
  )

Now, let's see some examples.

let m1 : matrix = 
    { nrows = 2; ncols = 3; 
        values = 
        [ 
            [ 1.0; 2.0; 0.5]; 
            [ 0.1; 0.2; 1.2]; 
        ]
    }

(* 3 by 4 matrix - 3 rows, 4 columns *)
let m2 : matrix = 
    { nrows = 3; ncols = 4; 
    values = 
    [ 
        [ 0.2; 0.3; 5.6; 2.1]; 
        [ 3.8; 0.1; 0.01; 0.48]; 
        [ 1.62; 9.2; 10.0; 5.6]; 
    ]

}

(* 2 by 4 matrix*)
let m3 = mmult m1 m2

(* Calling bad () raises Exception: Failure "nth" *)
let bad () = mmult m2 m1

By this point, you can see the problem: it is easy to get mmult to fail due to a dimension mismatch. What if you were implementing a complicated pipeline with many matrices of different sizes? It would be extremely easy to multiply the wrong matrices together. While we could have mmult return a matrix option instead of throwing an exception, that doesn't really solve the problem: the mistake is still only discovered at run time. Even worse, suppose you wanted to implement a helper function:

let mmult_three (m1 : matrix) (m2 : matrix) (m3 : matrix) = 
  mmult m1 (mmult m2 m3)

Then you cannot guarantee that mmult_three won't break, because you know nothing about the dimensions of m1, m2, and m3. You could write the precondition (what must be true to call the function) as a comment:

(* Requires: m1.ncols = m2.nrows and m2.ncols = m3.nrows *)
let mmult_three (m1 : matrix) (m2 : matrix) (m3 : matrix) = 
  mmult m1 (mmult m2 m3)

There is no formal guarantee that this precondition will always be satisfied.
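
To make the matrix option alternative concrete, here is a minimal, self-contained sketch. The names mmult_opt, mmult_three_opt, and ones are hypothetical and introduced only for this illustration; the matrix type and helpers are repeated so the block runs on its own.

```ocaml
(* Self-contained sketch; mmult_opt, mmult_three_opt and ones are
   hypothetical names used only for illustration. *)
type matrix = { nrows : int; ncols : int; values : float list list }

let build_matrix (n : int) (m : int) (f : int -> int -> float) : matrix =
  { nrows = n; ncols = m;
    values = List.init n (fun i -> List.init m (fun j -> f i j)) }

let get_matrix (i : int) (j : int) (m : matrix) : float =
  List.nth (List.nth m.values i) j

(* Returns None instead of raising when the inner dimensions disagree. *)
let mmult_opt (a : matrix) (b : matrix) : matrix option =
  if a.ncols <> b.nrows then None
  else
    Some (build_matrix a.nrows b.ncols (fun i j ->
      List.init a.ncols (fun x -> get_matrix i x a *. get_matrix x j b)
      |> List.fold_left ( +. ) 0.))

(* Every caller now has to thread the failure case through by hand. *)
let mmult_three_opt (m1 : matrix) (m2 : matrix) (m3 : matrix) : matrix option =
  match mmult_opt m2 m3 with
  | None -> None
  | Some m23 -> mmult_opt m1 m23

(* An n-by-m matrix filled with 1.0, handy for trying things out. *)
let ones (n : int) (m : int) : matrix = build_matrix n m (fun _ _ -> 1.0)

let ok = mmult_three_opt (ones 2 3) (ones 3 4) (ones 4 5)   (* Some of a 2-by-5 matrix *)
let bad = mmult_three_opt (ones 2 3) (ones 4 4) (ones 4 5)  (* None: 3 <> 4 *)
```

The check no longer crashes, but the mismatch is still detected only when the program runs, and the option has to be unpacked at every call site.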

Today, we'll see a way to tackle this issue using types.

Extending our Language with Matrices

We start by taking our base language (from 10.1) and adding a type of matrices. First, we add it to the syntax and semantics:

type op =
  | Add
  | Mul
  | And
  | Or
  | LT
  | EQ
  | MMult (* Takes two arguments: both of type matrix *)


type value =
  | Nat  of int
  | Bool of bool
  | Matrix of matrix

type expr =
  | Value of value
  | Var   of string
  | Let   of string * expr * expr
  | ITE   of expr * expr * expr
  | App   of op * expr list
  | GenMat of int * int (* Expression for generating a matrix *)


module StringMap = Map.Make (String)

let eval_op (o : op) (vs : value list) : value = 
  match o, vs with 
  ...
  | (MMult, [Matrix m1; Matrix m2]) -> Matrix (mmult m1 m2)
  ...

let rec eval (env : value StringMap.t) (e : expr) : value = 
  match e with 
  ...
  | GenMat (n, m) -> Matrix (build_matrix n m (fun _ _ -> Random.float 1.0))
  ...

Now, let's think about the type. We could have this type:

type ty =
  | TNat
  | TBool
  | TMatrix

The only new typing rules would then be that typeof ty_env (GenMat (i, j)) = Some TMatrix, and that typeof ty_env (App (MMult, [e1; e2])) = Some TMatrix as long as e1 and e2 both have type TMatrix. But what's wrong with this?

This type system is unsound: the type TMatrix says nothing about dimensions, so we can construct a well-typed program that crashes:

Let ("x", GenMat (5, 6),
  Let ("y", GenMat (3, 2),
    App (MMult, [Var "x"; Var "y"])
  )
)

This is directly at odds with type soundness, which requires that a well-typed program never crashes: here the inner dimensions 6 and 3 disagree, so mmult fails at run time.
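
To see the coarse type system accept this program, here is a minimal sketch of a checker for a stripped-down fragment of the language. The base typeof from 10.1 is not reproduced in these notes, so the Var and Let cases below are an assumed reconstruction, and only the constructors needed for this example are included.

```ocaml
(* Hypothetical, stripped-down fragment: only what this example needs.
   The Var and Let cases are an assumed reconstruction of the base checker. *)
type op = MMult

type expr =
  | Var of string
  | Let of string * expr * expr
  | App of op * expr list
  | GenMat of int * int

(* Coarse types: a matrix type that ignores dimensions. *)
type ty = TNat | TBool | TMatrix

module StringMap = Map.Make (String)

let rec typeof (ty_env : ty StringMap.t) (e : expr) : ty option =
  match e with
  | Var x -> StringMap.find_opt x ty_env
  | GenMat (_, _) -> Some TMatrix
  | Let (x, e1, e2) ->
    (match typeof ty_env e1 with
     | Some t -> typeof (StringMap.add x t ty_env) e2
     | None -> None)
  | App (MMult, [e1; e2]) ->
    (match typeof ty_env e1, typeof ty_env e2 with
     | Some TMatrix, Some TMatrix -> Some TMatrix
     | _ -> None)
  | App (MMult, _) -> None

(* A 5-by-6 matrix times a 3-by-2 matrix: the dimensions do not line up. *)
let bad_program : expr =
  Let ("x", GenMat (5, 6),
    Let ("y", GenMat (3, 2),
      App (MMult, [Var "x"; Var "y"])))

(* The coarse checker accepts the program anyway. *)
let accepted : bool = typeof StringMap.empty bad_program = Some TMatrix
```

Here accepted is true, even though evaluating bad_program would make mmult fail.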

Aside

Comparing our notion of type soundness to OCaml is subtle. In OCaml, you can make things "crash" while still being well-typed:

utop # 3 / 0;;
Exception: Division_by_zero.

However, OCaml natively supports exceptions, and you can even recover from them:

utop # try 3 / 0 with | Division_by_zero -> 0;;
- : int = 0

An appropriate notion of type soundness for OCaml would allow for this possibility. Indeed, this isn't quite the evaluator of OCaml crashing; instead, it is the evaluator correctly throwing an exception. Said another way, utop itself doesn't crash! However, since our own language does not have exceptions --- and instead our evaluator itself fails --- we would consider this to be a type soundness issue. In a sense, this makes our type soundness guarantee a bit stronger than OCaml's.

How can we fix this? The trick --- like we saw last lecture with even and odd numbers --- is to make the type of matrices finer grained. We won't just have a type of "a matrix", but instead a type for "matrices with a certain dimension":

type ty =
  | TNat
  | TBool
  | TMatrix of int * int

Now, we just need to fix our rules for how to build and manipulate matrices. Importantly, no other parts of the type system need to change; only the part we care about!

(* typeof is otherwise identical, except: *)
typeof ty_env (GenMat (i, j)) = Some (TMatrix (i, j))

let type_of_op (op : op) (args : ty option list) : ty option =
  match (op, args) with
  (* Otherwise identical, except: *)
  ...
  | (MMult, [Some (TMatrix (i, j)); Some (TMatrix (k, l))]) -> 
    if j = k then Some (TMatrix (i, l)) else None

This allows us to recover type soundness. Why? Because if e has type TMatrix (n, m) and evaluates to a value v, we know that v is not only a matrix, but also has exactly n rows and m columns:

let interpret_ty (t : ty) (v : value) : bool = 
  match t with 
  | TNat -> (match v with | Nat _ -> true | _ -> false)
  | TBool -> (match v with | Bool _ -> true | _ -> false)
  | TMatrix (n, m) -> (match v with | Matrix mat -> mat.nrows = n && mat.ncols = m | _ -> false)
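
Putting the refined rules together, the following self-contained sketch type checks a stripped-down fragment of the language. As before, the base checker from 10.1 is not shown in these notes, so the Var and Let cases are an assumed reconstruction, and the program names mismatched and compatible are hypothetical.

```ocaml
(* Hypothetical, stripped-down fragment: only what this example needs.
   The Var and Let cases are an assumed reconstruction of the base checker. *)
type op = MMult

type expr =
  | Var of string
  | Let of string * expr * expr
  | App of op * expr list
  | GenMat of int * int

(* Refined types: matrix types now carry (rows, columns). *)
type ty = TNat | TBool | TMatrix of int * int

module StringMap = Map.Make (String)

let type_of_op (op : op) (args : ty option list) : ty option =
  match op, args with
  | MMult, [Some (TMatrix (i, j)); Some (TMatrix (k, l))] ->
    (* Inner dimensions must agree; the result is i-by-l. *)
    if j = k then Some (TMatrix (i, l)) else None
  | MMult, _ -> None

let rec typeof (ty_env : ty StringMap.t) (e : expr) : ty option =
  match e with
  | Var x -> StringMap.find_opt x ty_env
  | GenMat (i, j) -> Some (TMatrix (i, j))
  | Let (x, e1, e2) ->
    (match typeof ty_env e1 with
     | Some t -> typeof (StringMap.add x t ty_env) e2
     | None -> None)
  | App (op, args) -> type_of_op op (List.map (typeof ty_env) args)

(* 5-by-6 times 3-by-2: rejected, since 6 <> 3. *)
let mismatched : expr =
  Let ("x", GenMat (5, 6),
    Let ("y", GenMat (3, 2),
      App (MMult, [Var "x"; Var "y"])))

(* 5-by-6 times 6-by-2: accepted. *)
let compatible : expr =
  Let ("x", GenMat (5, 6),
    Let ("y", GenMat (6, 2),
      App (MMult, [Var "x"; Var "y"])))

let rejected : bool = typeof StringMap.empty mismatched = None
let result_ty : ty option = typeof StringMap.empty compatible  (* Some (TMatrix (5, 2)) *)
```

The mismatched program is now ruled out before it ever runs, while the compatible one is given the type of a 5-by-2 matrix.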

Exercise

How would you add matrix addition? How would you add scalar multiplication? How would you perform the transpose?

Note that the above story for matrices works easily because the dimensions of matrices are not values in the language, but rather OCaml integers (statically known values). Thus, the following wouldn't work:

let x = rand() in // Suppose we had a function rand to create a random TNat 
let m = GenMat(x, x) in // what would the type of m be?

To handle cases where the matrix dimensions are dynamic in full generality, more complicated solutions exist.