10.3: Special Topic: Dimension-Safe Matrix Multiplication
Last class, we saw one example extension to our type system: handling even and odd integers. This allows us to, for example, do the following:
let x = 1 in // TNat odd
let y = 2 in // TNat even
let b = ... in // bool
if b then x + y else y // TNat any
The main purpose of this extension is to guarantee that if e has type TNat even, then when e evaluates to a value, it must be an even integer (and similarly for TNat odd).
Today, we will show you another extension to the type system: this time, about matrices, which are ubiquitous in a number of computing contexts, including scientific computing and machine learning. Recall that a matrix is a table of numbers of size \(n\) by \(m\), for some number of rows \(n\) and columns \(m\). The fundamental operation on matrices, for us, is matrix multiplication. Given an \(n\)-by-\(m\) matrix \(A\), and an \(m\)-by-\(k\) matrix \(B\), we define an \(n\)-by-\(k\) matrix \(AB\) like so:
\[(AB)_{x, y} = \sum_{i=1}^{m} A_{x, i} B_{i, y}\]
where \((AB)_{x, y}\) is the element of the matrix \(AB\) at row \(x\) and column \(y\).
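As a small worked instance of this definition, take \(n = m = k = 2\) (the specific entries below are chosen only for illustration):
\[
\begin{pmatrix} 1 & 2 \\ 3 & 4 \end{pmatrix}
\begin{pmatrix} 5 & 6 \\ 7 & 8 \end{pmatrix}
=
\begin{pmatrix} 1 \cdot 5 + 2 \cdot 7 & 1 \cdot 6 + 2 \cdot 8 \\ 3 \cdot 5 + 4 \cdot 7 & 3 \cdot 6 + 4 \cdot 8 \end{pmatrix}
=
\begin{pmatrix} 19 & 22 \\ 43 & 50 \end{pmatrix}
\]
Each entry pairs one row of \(A\) with one column of \(B\), which is why the number of columns of \(A\) must equal the number of rows of \(B\).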
Prelude: Matrix Multiplication in OCaml
Before thinking about how we would implement this in our language, let's think about how we would do so in OCaml. We could implement the matrix as a list of lists of floats, along with its dimensions:
type matrix = {
  ncols : int;
  nrows : int;
  (* Stored as list of rows. The outer list should have nrows elements,
     and the inner lists should have ncols elements. *)
  values : float list list
}
Now, let's implement matrix multiplication. First, with some helper functions:
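let sum_list (xs : float list) : float =
  List.fold_left (+.) 0. xs
(* Builds an n-by-m matrix (n rows, m columns) whose entry at row i, column j is f i j. *)
let build_matrix (n : int) (m : int) (f : int -> int -> float) : matrix =
  { nrows = n; ncols = m;
    values =
      List.init n (fun i ->
        List.init m (fun j ->
          f i j
        )
      )
  }
(* Returns the entry at row i, column j (both 0-indexed). *)
let get_matrix (i : int) (j : int) (m : matrix) : float =
  List.nth (List.nth m.values i) j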
Now, for the main event:
(* Only works if a.ncols = b.nrows *)
let mmult (a : matrix) (b : matrix) : matrix =
  build_matrix a.nrows b.ncols (fun i j ->
    let products = List.init a.ncols (fun x -> get_matrix i x a *. get_matrix x j b) in
    sum_list products
  )
Now, let's see some examples.
(* 2 by 3 matrix - 2 rows, 3 columns *)
let m1 : matrix =
  { nrows = 2; ncols = 3;
    values =
      [
        [ 1.0; 2.0; 0.5];
        [ 0.1; 0.2; 1.2];
      ]
  }
(* 3 by 4 matrix - 3 rows, 4 columns *)
let m2 : matrix =
  { nrows = 3; ncols = 4;
    values =
      [
        [ 0.2; 0.3; 5.6; 2.1];
        [ 3.8; 0.1; 0.01; 0.48];
        [ 1.62; 9.2; 10.0; 5.6];
      ]
  }
(* 2 by 4 matrix *)
let m3 = mmult m1 m2
(* Calling bad () raises Exception: Failure "nth" *)
let bad () = mmult m2 m1
By this point, you can see the problem. It is easy to get mmult to fail due to a dimension mismatch.
What if you were implementing a complicated pipeline with many arrays of different sizes? It would be extremely easy to multiply the wrong matrices together.
While we could have mmult return a matrix option instead of throwing an exception, that still doesn't really solve the problem.
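For concreteness, an option-returning variant might look like the sketch below. The name mmult_opt is hypothetical (it does not appear elsewhere in these notes), and the matrix type and helpers are repeated only so that the block stands alone.

```ocaml
type matrix = {
  nrows : int;
  ncols : int;
  values : float list list; (* list of rows *)
}

(* Build an n-by-m matrix (n rows, m columns) from an index function. *)
let build_matrix (n : int) (m : int) (f : int -> int -> float) : matrix =
  { nrows = n; ncols = m;
    values = List.init n (fun i -> List.init m (fun j -> f i j)) }

let get_matrix (i : int) (j : int) (m : matrix) : float =
  List.nth (List.nth m.values i) j

(* Hypothetical variant of mmult: returns None on a dimension mismatch
   instead of raising Failure "nth". *)
let mmult_opt (a : matrix) (b : matrix) : matrix option =
  if a.ncols <> b.nrows then None
  else
    Some (build_matrix a.nrows b.ncols (fun i j ->
      List.init a.ncols (fun x -> get_matrix i x a *. get_matrix x j b)
      |> List.fold_left (+.) 0.))

(* Illustrative inputs: a 2-by-3 matrix of ones and a 3-by-4 matrix of twos. *)
let ones_2x3 = build_matrix 2 3 (fun _ _ -> 1.0)
let twos_3x4 = build_matrix 3 4 (fun _ _ -> 2.0)

let ok = mmult_opt ones_2x3 twos_3x4       (* Some of a 2-by-4 matrix *)
let mismatch = mmult_opt twos_3x4 ones_2x3 (* None: 4 columns vs 2 rows *)
```

The mismatch is still only discovered at run time, and every caller now has to decide what to do with None, which is why this does not really solve the problem.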
Even worse, suppose you wanted to implement a helper function, mmult_three, that multiplies three matrices together. You cannot guarantee that mmult_three won't break, because you know nothing about the dimensions of its arguments m1, m2, and m3.
You could write the precondition (what must be true to call the function) as a comment:
(* Requires: m1.ncols = m2.nrows and m2.ncols = m3.nrows *)
let mmult_three (m1 : matrix) (m2 : matrix) (m3 : matrix) =
  mmult m1 (mmult m2 m3)
But there is no formal guarantee that callers will always satisfy this precondition.
Today, we'll see a way to tackle this issue using types.
Extending our Language with Matrices
We start from our base language (from 10.1) and add matrices to it. First, we extend the syntax and semantics:
type op =
  | Add
  | Mul
  | And
  | Or
  | LT
  | EQ
  | MMult (* Takes two arguments: both of type matrix *)
type value =
  | Nat of int
  | Bool of bool
  | Matrix of matrix
type expr =
  | Value of value
  | Var of string
  | Let of string * expr * expr
  | ITE of expr * expr * expr
  | App of op * expr list
  | GenMat of int * int (* Expression for generating a matrix *)
module StringMap = Map.Make (String)
let eval_op (o : op) (vs : value list) : value =
  match o, vs with
  ...
  | (MMult, [Matrix m1; Matrix m2]) -> Matrix (mmult m1 m2)
  ...
let rec eval (env : value StringMap.t) (e : expr) : value =
  match e with
  ...
  | GenMat (n, m) -> Matrix (build_matrix n m (fun _ _ -> Random.float 1.0))
  ...
Now, let's think about the type. We could add a single new type, TMatrix, shared by all matrices. The only new typing rules would then be that typeof ty_env (GenMat (i, j)) = Some TMatrix, and that typeof ty_env (App (MMult, [e1; e2])) = Some TMatrix as long as e1 and e2 both have type TMatrix.
But what's wrong with this?
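This type system is unsound: we could construct a well-typed program that would crash. For example, App (MMult, [GenMat (3, 4); GenMat (2, 3)]) has type TMatrix under these rules, but evaluating it calls mmult on a 3-by-4 matrix and a 2-by-3 matrix, which fails.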
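A minimal sketch of such a program follows. It is not the full language from these notes: the expr type is cut down to the two matrix forms, MMult is written as a binary constructor standing in for App (MMult, [e1; e2]), GenMat fills the matrix with 1.0 rather than random floats so the run is deterministic, and crashes is a hypothetical helper added only to observe the failure.

```ocaml
type matrix = { nrows : int; ncols : int; values : float list list }

let build_matrix n m f =
  { nrows = n; ncols = m;
    values = List.init n (fun i -> List.init m (fun j -> f i j)) }

let get_matrix i j m = List.nth (List.nth m.values i) j

let mmult a b =
  build_matrix a.nrows b.ncols (fun i j ->
    List.init a.ncols (fun x -> get_matrix i x a *. get_matrix x j b)
    |> List.fold_left (+.) 0.)

(* Cut-down fragment of the language: only the matrix-related forms. *)
type expr =
  | GenMat of int * int
  | MMult of expr * expr (* stands in for App (MMult, [e1; e2]) *)

(* Coarse type: records that something is a matrix, but not its dimensions. *)
type ty = TMatrix

let rec typeof (e : expr) : ty option =
  match e with
  | GenMat (_, _) -> Some TMatrix
  | MMult (e1, e2) ->
    (match typeof e1, typeof e2 with
     | Some TMatrix, Some TMatrix -> Some TMatrix
     | _ -> None)

let rec eval (e : expr) : matrix =
  match e with
  | GenMat (n, m) -> build_matrix n m (fun _ _ -> 1.0)
  | MMult (e1, e2) -> mmult (eval e1) (eval e2)

(* Hypothetical helper: true if evaluating e raises a Failure. *)
let crashes (e : expr) : bool =
  match eval e with
  | _ -> false
  | exception Failure _ -> true

(* Accepted by typeof, yet multiplies a 3-by-4 matrix by a 2-by-3 one. *)
let bad = MMult (GenMat (3, 4), GenMat (2, 3))
let good = MMult (GenMat (2, 3), GenMat (3, 4))
```

Here typeof accepts both bad and good, but only good evaluates successfully; the coarse type cannot tell them apart.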
This is directly at odds with type soundness, which requires that no well-typed program crashes.
Aside
Comparing our notion of type soundness to OCaml's is subtle. In OCaml, a well-typed program can still "crash": for example, List.hd [] type checks but raises an exception.
However, OCaml natively supports exceptions, and a program can even recover from them using try ... with. An appropriate notion of type soundness for OCaml therefore allows for this possibility. Indeed, this isn't quite the evaluator of OCaml crashing; instead, it is the evaluator correctly throwing an exception. Said another way, utop itself doesn't crash! Our own language, however, does not have exceptions, and instead our evaluator itself fails, so we consider this to be a type soundness issue. In a sense, this makes our type soundness guarantee a bit stronger than OCaml's.
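The two points in this aside can be illustrated with a short sketch. The function names first and first_or_zero are hypothetical and chosen only for this example; List.hd is the standard library function, which raises Failure on the empty list.

```ocaml
(* Well-typed (int list -> int), yet raises Failure "hd" on the empty list. *)
let first (xs : int list) : int = List.hd xs

(* The exception is an ordinary outcome of evaluation that the program
   itself can catch and recover from. *)
let first_or_zero (xs : int list) : int =
  try first xs with Failure _ -> 0
```

Calling first [] raises an exception, while first_or_zero [] recovers from it and returns 0; in both cases the OCaml evaluator itself keeps running.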
How can we fix this?
The trick, like we saw last lecture with even and odd numbers, is to make the type of matrices finer grained. We won't just have a type of "a matrix", but instead a type TMatrix (n, m) of "matrices with n rows and m columns".
Now, we just need to fix our rules for how to build and manipulate matrices: GenMat (n, m) gets type TMatrix (n, m), and MMult checks that the inner dimensions agree. Importantly, no other part of the type system needs to change; only the part we care about!
let type_of_op (op : op) (args : ty option list) : ty option =
  match (op, args) with
  (* Otherwise identical, except: *)
  ...
  | (MMult, [Some (TMatrix (i, j)); Some (TMatrix (k, l))]) ->
    if j = k then Some (TMatrix (i, l)) else None
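To see the dimension check in action, here is a small self-contained sketch. As an assumption for brevity, it folds the rule into a single typeof over a cut-down expr type (only GenMat, plus a binary MMult standing in for App (MMult, [e1; e2])) instead of going through type_of_op and a typing environment.

```ocaml
type ty = TNat | TBool | TMatrix of int * int (* rows, columns *)

(* Cut-down fragment of the language: only the matrix-related forms. *)
type expr =
  | GenMat of int * int
  | MMult of expr * expr (* stands in for App (MMult, [e1; e2]) *)

let rec typeof (e : expr) : ty option =
  match e with
  | GenMat (n, m) -> Some (TMatrix (n, m))
  | MMult (e1, e2) ->
    (match typeof e1, typeof e2 with
     | Some (TMatrix (i, j)), Some (TMatrix (k, l)) ->
       (* Inner dimensions must agree; the result is i-by-l. *)
       if j = k then Some (TMatrix (i, l)) else None
     | _ -> None)

let good = MMult (GenMat (2, 3), GenMat (3, 4))
let bad = MMult (GenMat (3, 4), GenMat (2, 3))
(* The shape of mmult_three: a 2-by-3 times (a 3-by-4 times a 4-by-5). *)
let three = MMult (GenMat (2, 3), MMult (GenMat (3, 4), GenMat (4, 5)))
```

With these rules, typeof good is Some (TMatrix (2, 4)), typeof three is Some (TMatrix (2, 5)), and typeof bad is None, so the mismatched product is rejected before it is ever evaluated.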
This allows us to recover type soundness. Why? Because if e has type TMatrix (n, m) and evaluates to a value v, we know that not only is v a matrix, but it also has the correct dimensions:
let interpret_ty (t : ty) (v : value) : bool =
  match t with
  | TNat -> (match v with | Nat _ -> true | _ -> false)
  | TBool -> (match v with | Bool _ -> true | _ -> false)
  | TMatrix (n, m) -> (match v with | Matrix mat -> mat.nrows = n && mat.ncols = m | _ -> false)
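As a quick sanity check of this interpretation, the sketch below applies it to a freshly built matrix. The definitions are repeated so the block stands alone (interpret_ty is written as a single match on the pair, which should behave the same way), and well_formed is a hypothetical extra check, not part of the notes, that the stored rows really match the recorded dimensions.

```ocaml
type matrix = { nrows : int; ncols : int; values : float list list }
type value = Nat of int | Bool of bool | Matrix of matrix
type ty = TNat | TBool | TMatrix of int * int

(* Build an n-by-m matrix (n rows, m columns) from an index function. *)
let build_matrix n m f =
  { nrows = n; ncols = m;
    values = List.init n (fun i -> List.init m (fun j -> f i j)) }

let interpret_ty (t : ty) (v : value) : bool =
  match t, v with
  | TNat, Nat _ -> true
  | TBool, Bool _ -> true
  | TMatrix (n, m), Matrix mat -> mat.nrows = n && mat.ncols = m
  | _ -> false

(* Hypothetical extra check: interpret_ty trusts the nrows and ncols fields,
   so this verifies that the list of rows actually has that shape. *)
let well_formed (mat : matrix) : bool =
  List.length mat.values = mat.nrows
  && List.for_all (fun row -> List.length row = mat.ncols) mat.values

let mat = build_matrix 2 3 (fun i j -> float_of_int (i + j))
let v = Matrix mat
```

Here interpret_ty (TMatrix (2, 3)) v holds, while interpret_ty (TMatrix (3, 2)) v and interpret_ty TNat v do not.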
Exercise
How would you add matrix addition? How would you add scalar multiplication? How would you perform the transpose?
Note that the above story for matrices works easily because the dimensions of matrices are not values in the language, but rather OCaml integers (statically known values). Thus, the following wouldn't work:
let x = rand() in // Suppose we had a function rand to create a random TNat
let m = GenMat(x, x) in // what would the type of m be?
To handle choices where the matrix dimension is dynamic in full generality, more complicated solutions exist.