Lecture 17: Type Inference

In the last lecture, we designed a system for categorizing the expressions in our programs with types, and created an algorithm for checking whether our programs had a self-consistent type. However, that system ran aground in terms of usability, when it came to primitives like isbool or ==, which accepted arguments of “any type” and produced boolean answers, or worse yet, print, which took arguments of “any type” and produced answers of that type. In order to handle this polymorphism, we had to resort to hacks: either explicitly annotating the program everywhere there was ambiguity, or splitting the primitives into separate, monomorphic versions that all behaved the same way. Can we do better?

1 Type inference: guessing correctly, every time

The key observation we made last time was that our type checker must be syntax-directed, meaning at every expression of our program, we had enough local information to know how to recur on each subexpression, and we never had to “guess and check” our results. However, when we write down a polymorphic type and informally use it, we walk through a reasoning process that says “In this case, assume that this type variable means this particular type, then check that everything is self-consistent.” In other words, we mentally perform a kind of substitution of types for type-variables, then continue with the same type-checking algorithm as before.

But this is in fact overkill. The self-consistency requirement alone is enough to type-check our program, if we exploit it correctly. The key is to reframe the question from “Does this program have this particular type?” to “Does there exist a type for this program such that the program is self-consistent?” We may not know yet what that type is, but we can deduce certain facts about it. This process is called type inference, and we will follow the classic Hindley-Milner algorithm to deduce types for our programs.

The trick to getting started is to say, “Sure! Every expression can have a type – let’s just make up a fresh type variable 'X and claim that as the type.” Of course, there may be constraints on 'X based on how we use the value in our program; these are called type constraints, and type inference is simply the process of collecting and solving these constraints. The easiest way to see this is by example.

In order to make our type inference tractable, we are going to distinguish types from type schemes. Intuitively, a type is a single thing, while a type scheme can quantify over type variables. Our goal will be to infer a type scheme for each function in our program, and a type for each expression.[1]

1.1 Example 1

Suppose we have the program

def f(x):
  x + 6

f(38)

Let’s start by inferring a type for f. We have no annotations, so we simply create new type variables as needed, and assert that f has type 'T1 -> 'T2, for two unknown type variables 'T1 and 'T2. We begin to infer a type for the body of f by binding the arguments to their types in our type environment – here, x : 'T1.

We recur into the body of f, and examine x + 6. This is an EPrim2 expression, so we look up the type scheme for Plus, and see that it is Forall [], (Int, Int -> Int). We look at the arguments to the operator and infer their types: looking up x in the environment gives 'T1, and the literal 6 has type Int. We know that this operation will produce some result type, so we make up a new type variable 'prim2result. We can now construct the type ('T1, Int -> 'prim2result) and unify it with (Int, Int -> Int). This produces the substitution ['T1 = Int, 'prim2result = Int]. We apply that substitution, and deduce that x : Int. The result of our expression has type 'prim2result, and since it is the entirety of the body of f, we unify that with the result type 'T2, making sure to preserve the substitution ['prim2result = Int]. Our net result is f : Int -> Int, which we can generalize to the type scheme Forall [], Int -> Int, and update our function environment accordingly.

We next need to infer a type for f(38). We repeat a similar process: we look up the type scheme for f and instantiate it (trivially, since it quantifies over no type variables), we infer a type for 38, we make up a new result type 'app_result, and we unify Int -> Int (the type of f) with Int -> 'app_result. We deduce that ['app_result = Int], so Int is the final type of our program.

1.2 Example 2

Suppose we have the following trickier program:

def f(x, y):
  isnum(print(x)) && isbool(y)

def g(z):
  f(z, 5)

g(7)

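Since our functions are not mutually recursive, we can handle them independently, in order.

We start with f, guessing an initial type of ('T1, 'T2 -> 'T3) and binding x : 'T1 and y : 'T2. The primitives print, isnum and isbool have the type schemes Forall 'X, ('X -> 'X), Forall 'X, ('X -> Bool) and Forall 'X, ('X -> Bool), so each use instantiates its scheme with fresh type variables. Unifying those instantiations with the argument types places no constraint at all on 'T1 or 'T2, while && forces the body, and hence 'T3, to be Bool. Generalizing, we get f : Forall ['T1, 'T2], ('T1, 'T2 -> Bool).

We can repeat the process for g, guessing an initial type of 'T7 -> 'T8. When we get to the function call, we instantiate the type for f to f : ('T9, 'T10 -> Bool), and proceed. Unification produces the substitution ['T7 = 'T9, 'T10 = Int, 'T8 = Bool], and generalizing, we get a final type scheme of g : Forall ['T9], ('T9 -> Bool).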

Finally, we proceed to g(7). Again we instantiate the scheme for g, unify its argument type with Int, and obtain a final type of Bool. Our program is self-consistent.

2 Type inference implementation

2.1 Preparation

To implement type inference ourselves, we’ll need lots of preparatory definitions:
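The exact definitions are up to you. As one possible starting point, here is a minimal OCaml sketch. It assumes that types carry an 'a payload (such as a sourcespan), matching the TyCon("Int", dummy_sourcespan) used later in these notes; the constructor names TyVar, TyArr and SForall, the list-based substitution, and the use of StringMap for environments are our assumptions, not official definitions. A substitution here is applied binding by binding, left to right, which makes composition a simple list append.

```ocaml
module StringMap = Map.Make (String)

(* Every type carries an 'a payload, e.g. a sourcespan for error messages. *)
type 'a typ =
  | TyVar of string * 'a               (* type variables like 'X *)
  | TyCon of string * 'a               (* type constants like Int or Bool *)
  | TyArr of 'a typ list * 'a typ * 'a (* arrows like (Int, Bool -> Int) *)

(* Forall 'X, 'Y, ..., typ  (SForall is a hypothetical constructor name) *)
type 'a scheme = SForall of string list * 'a typ * 'a

(* A substitution is a list of (variable, replacement) bindings. *)
type 'a subst = (string * 'a) list

(* Environments map names to types (for variables) or schemes (for functions). *)
type 'a envt = 'a StringMap.t

(* Replace every occurrence of the variable [x] in [t] with [s]. *)
let rec subst_var (x : string) (s : 'a typ) (t : 'a typ) : 'a typ =
  match t with
  | TyVar (y, _) when x = y -> s
  | TyVar _ | TyCon _ -> t
  | TyArr (args, ret, loc) ->
    TyArr (List.map (subst_var x s) args, subst_var x s ret, loc)

(* Apply the bindings one at a time, left to right. *)
let apply_subst_typ (sub : 'a typ subst) (t : 'a typ) : 'a typ =
  List.fold_left (fun t (x, s) -> subst_var x s t) t sub

(* Rewrite every type in an environment. *)
let apply_subst_env (sub : 'a typ subst) (env : 'a typ envt) : 'a typ envt =
  StringMap.map (apply_subst_typ sub) env

(* Applying [compose_subst s1 s2] is the same as applying s1, then s2. *)
let compose_subst (s1 : 'a typ subst) (s2 : 'a typ subst) : 'a typ subst = s1 @ s2
```

For example, composing ['X = 'Y] with ['Y = Bool] and applying the result to 'X yields Bool.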

Exercise

Do these.

Do Now!

You are welcome to clean up your free-type-variable functions and make them more efficient, by passing along a StringSet.t accumulator parameter of all the free type variables seen so far.
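For instance, an accumulator-style computation of free type variables might look like the following minimal sketch. The TyVar/TyCon/TyArr and SForall constructors are assumptions about the representation (declared here so the snippet stands alone), and ftv_type_acc is a hypothetical helper name; ftv_type matches the name used by occurs later in these notes.

```ocaml
module StringSet = Set.Make (String)

type 'a typ =
  | TyVar of string * 'a
  | TyCon of string * 'a
  | TyArr of 'a typ list * 'a typ * 'a

type 'a scheme = SForall of string list * 'a typ * 'a

(* Add the free type variables of [t] to [acc]. A single set is threaded
   through the traversal instead of building and unioning many small sets. *)
let rec ftv_type_acc (acc : StringSet.t) (t : 'a typ) : StringSet.t =
  match t with
  | TyVar (x, _) -> StringSet.add x acc
  | TyCon _ -> acc
  | TyArr (args, ret, _) -> List.fold_left ftv_type_acc (ftv_type_acc acc ret) args

let ftv_type (t : 'a typ) : StringSet.t = ftv_type_acc StringSet.empty t

(* Variables bound by the Forall are not free in the scheme. *)
let ftv_scheme (SForall (vars, t, _) : 'a scheme) : StringSet.t =
  StringSet.diff (ftv_type t) (StringSet.of_list vars)
```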

2.2 Unification

The central operation of type inference is unification, which tries to make two types equal by computing a substitution for their free type variables under which the two types become identical:

let rec unify (t1 : 'a typ) (t2 : 'a typ) (* other bookkeeping *) : 'a typ subst = ...

For example, we can unify 'X -> Int and Bool -> 'Y under the substitution ['X = Bool, 'Y = Int]. If we apply this substitution to both types, they both become Bool -> Int and are identical.

However, there is no substitution that can unify Int -> 'X with Bool -> 'Y. Even if we set ['X = 'Y], we cannot unify Bool with Int. We get a unification error, which we report as a type error.

Our goal is to produce as many substitutions as necessary: each substitution asserts that a particular type variable must be equal to a particular type. This is the key behind Hindley-Milner type inference. We can always guess a new type variable, and unification will steadily constrain that guess with more and more equality constraints until we have a final answer, or a contradiction.

Note that we have to be careful not to unify too recklessly. We cannot unify 'A with 'A -> 'B, because we would get the absurd substitution ['A = 'A -> 'B]. These two types can never be equal: whatever 'A stands for, the right side always contains one more arrow than the left side. To rule this out, we must perform an occurs check, to see whether the variable we’re trying to constrain appears within its constraint:

let occurs (name : string) (t : 'a typ) =
  StringSet.mem name (ftv_type t)

Our unification proceeds by cases:
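One way to organize those cases is sketched below, under stated assumptions: the type representation, the list-based substitution (applied left to right, so composition is list append) and the Unification_error exception are hypothetical, and the bookkeeping parameters such as sourcespans are omitted. The cases are identical constants, identical variables, a variable against any other type (guarded by the occurs check), and arrows of equal arity, whose arguments and results are unified pairwise while threading the substitution found so far. Everything else is a unification error.

```ocaml
module StringSet = Set.Make (String)

type 'a typ =
  | TyVar of string * 'a
  | TyCon of string * 'a
  | TyArr of 'a typ list * 'a typ * 'a

type 'a subst = (string * 'a) list

(* Hypothetical error; a real implementation would carry sourcespans. *)
exception Unification_error of string

let rec subst_var x s t =
  match t with
  | TyVar (y, _) when x = y -> s
  | TyVar _ | TyCon _ -> t
  | TyArr (args, ret, loc) -> TyArr (List.map (subst_var x s) args, subst_var x s ret, loc)

let apply_subst_typ (sub : 'a typ subst) (t : 'a typ) : 'a typ =
  List.fold_left (fun t (x, s) -> subst_var x s t) t sub

let rec ftv_type = function
  | TyVar (x, _) -> StringSet.singleton x
  | TyCon _ -> StringSet.empty
  | TyArr (args, ret, _) ->
    List.fold_left (fun acc a -> StringSet.union acc (ftv_type a)) (ftv_type ret) args

let occurs (name : string) (t : 'a typ) = StringSet.mem name (ftv_type t)

let rec unify (t1 : 'a typ) (t2 : 'a typ) : 'a typ subst =
  match (t1, t2) with
  (* Identical type constants need no substitution. *)
  | TyCon (c1, _), TyCon (c2, _) when c1 = c2 -> []
  (* The same variable on both sides needs no substitution either. *)
  | TyVar (x, _), TyVar (y, _) when x = y -> []
  (* A variable on either side is bound to the other type, unless it occurs in it. *)
  | TyVar (x, _), t | t, TyVar (x, _) ->
    if occurs x t then raise (Unification_error ("occurs check failed for " ^ x))
    else [ (x, t) ]
  (* Arrows of equal arity: unify arguments and results pairwise, applying
     the substitution found so far before each step. *)
  | TyArr (args1, ret1, _), TyArr (args2, ret2, _)
    when List.length args1 = List.length args2 ->
    List.fold_left2
      (fun sub a b -> sub @ unify (apply_subst_typ sub a) (apply_subst_typ sub b))
      [] (args1 @ [ ret1 ]) (args2 @ [ ret2 ])
  (* Anything else (Int vs Bool, constant vs arrow, arity mismatch) fails. *)
  | _ -> raise (Unification_error "types cannot be unified")
```

With this sketch, unifying 'X -> Int with Bool -> 'Y yields ['X = Bool, 'Y = Int], while Int -> 'X against Bool -> 'Y, or 'A against 'A -> 'B, raises Unification_error.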

Exercise

In the implementation of unify, it is quite helpful to pass around diagnostic information, such as a sourcespan of the expression that triggered the unification request (and possibly more information) so that your error message can direct the programmer to where the problem occurred. Consider what information might be useful, and extend the signature and implementation of unify accordingly.

2.3 Inference and Instantiation

Our goal in inferring types for our program is to compute a type for every subexpression, such that the program is entirely self-consistent. To do this, our inference function should return not only the computed type, but also the substitution that we’ve accumulated so far.[2] Our primary function for type inference of expressions will be

let rec infer_exp (funenv : sourcespan scheme envt) (* environment for functions *)
                  (env : sourcespan typ envt) (* environment for variables *)
                  (e : sourcespan expr)
                  (* ... some bookkeeping ... *)
                  : ( sourcespan typ subst (* the resulting substitution *)
                    * sourcespan typ)      (* the resulting type *) =
  match e with
  ...

Suppose we construct two base types,

let tInt = TyCon("Int", dummy_sourcespan)
let tBool = TyCon("Bool", dummy_sourcespan)

Then the first two cases of infer_exp are easy:

let rec infer_exp funenv env e ... =
 match e with
  | ENumber _ -> ([], tInt)
  | EBool _ -> ([], tBool)
  ...

Variables can be handled by looking them up in env. But how should we handle things like EPrim1(Print, e, _)? After all, polymorphic operators like these were why we built this whole new infrastructure!

The key observation is that print has the type scheme Forall 'X, ('X -> 'X). Therefore, it has the type foo -> foo, for any type foo we want. So our key step is simply to create a fresh type variable for each quantified variable, one that has never been used before in our program, and substitute it into the body of the type scheme.

open Printf (* for sprintf *)

let gensym =
  let count = ref 0 in
  let next () =
    count := !count + 1;
    !count
  in fun str -> sprintf "%s_%d" str (next ());;
let instantiate (s : 'a scheme) : 'a typ = ...

Define a function instantiate that accomplishes this specialization step. It will use the gensym function to create any necessary new type variables.[3]
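A minimal sketch of one way to write instantiate follows, assuming a scheme is represented as SForall (vars, typ, loc) (a hypothetical constructor) with quantified variables stored by name. It renames each quantified variable to a fresh gensym name and leaves everything else alone; renaming by lookup, rather than applying a sequence of substitutions, keeps all the replacements simultaneous.

```ocaml
open Printf

type 'a typ =
  | TyVar of string * 'a
  | TyCon of string * 'a
  | TyArr of 'a typ list * 'a typ * 'a

type 'a scheme = SForall of string list * 'a typ * 'a

(* The gensym from these notes: each call yields a name never seen before. *)
let gensym =
  let count = ref 0 in
  let next () =
    count := !count + 1;
    !count
  in
  fun str -> sprintf "%s_%d" str (next ())

(* Rename type variables according to [renaming]; others are left alone. *)
let rec rename (renaming : (string * string) list) (t : 'a typ) : 'a typ =
  match t with
  | TyVar (x, loc) -> (
    match List.assoc_opt x renaming with
    | Some fresh -> TyVar (fresh, loc)
    | None -> t)
  | TyCon _ -> t
  | TyArr (args, ret, loc) -> TyArr (List.map (rename renaming) args, rename renaming ret, loc)

(* Give each quantified variable a fresh name; variables that are free in
   the scheme are not quantified, so they are kept unchanged. *)
let instantiate (SForall (vars, t, _) : 'a scheme) : 'a typ =
  let renaming = List.map (fun v -> (v, gensym v)) vars in
  rename renaming t
```

Instantiating the scheme for print twice this way gives two arrow types over two different fresh variables, so the two uses cannot constrain each other.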

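We can now use this function to infer types for prim1 and prim2 constructs. Populate the initial funenv with type schemes for all the primitives, giving each one a unique name. Each primitive should have some arrow type, to record its argument type(s) and return type. Then, in the EPrim1 case, instantiate the primitive’s type scheme, infer a type for the argument, make up a fresh type variable for the result, and unify the instantiated type with the arrow type from the argument’s type to that result variable, just as in Example 1.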
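The sketch below illustrates that EPrim1 recipe on a deliberately tiny, hypothetical expression language (numbers, booleans and three unary primitives), with no variable environment and no sourcespan bookkeeping. The type representation, the SForall and Type_error names, and the contents of prim_env are assumptions; unify and instantiate are compact versions of the functions described in these notes, included so that the snippet compiles on its own.

```ocaml
module StringSet = Set.Make (String)
module StringMap = Map.Make (String)

type 'a typ = TyVar of string * 'a | TyCon of string * 'a | TyArr of 'a typ list * 'a typ * 'a
type 'a scheme = SForall of string list * 'a typ * 'a
type 'a subst = (string * 'a) list

(* A hypothetical, cut-down expression language: just enough to show EPrim1. *)
type prim1 = Print | IsNum | Not
type 'a expr = ENumber of int * 'a | EBool of bool * 'a | EPrim1 of prim1 * 'a expr * 'a

exception Type_error of string

let rec subst_var x s t =
  match t with
  | TyVar (y, _) when x = y -> s
  | TyVar _ | TyCon _ -> t
  | TyArr (args, ret, l) -> TyArr (List.map (subst_var x s) args, subst_var x s ret, l)

let apply_subst_typ sub t = List.fold_left (fun t (x, s) -> subst_var x s t) t sub
let compose_subst s1 s2 = s1 @ s2 (* apply s1, then s2 *)

let rec ftv_type = function
  | TyVar (x, _) -> StringSet.singleton x
  | TyCon _ -> StringSet.empty
  | TyArr (args, ret, _) ->
    List.fold_left (fun acc a -> StringSet.union acc (ftv_type a)) (ftv_type ret) args

let rec unify t1 t2 =
  match (t1, t2) with
  | TyCon (c1, _), TyCon (c2, _) when c1 = c2 -> []
  | TyVar (x, _), TyVar (y, _) when x = y -> []
  | TyVar (x, _), t | t, TyVar (x, _) ->
    if StringSet.mem x (ftv_type t) then raise (Type_error "occurs check") else [ (x, t) ]
  | TyArr (a1, r1, _), TyArr (a2, r2, _) when List.length a1 = List.length a2 ->
    List.fold_left2
      (fun s x y -> compose_subst s (unify (apply_subst_typ s x) (apply_subst_typ s y)))
      [] (a1 @ [ r1 ]) (a2 @ [ r2 ])
  | _ -> raise (Type_error "cannot unify")

let gensym =
  let count = ref 0 in
  fun str -> incr count; Printf.sprintf "%s_%d" str !count

let instantiate (SForall (vars, t, l)) =
  apply_subst_typ (List.map (fun v -> (v, TyVar (gensym v, l))) vars) t

(* Each primitive is looked up in funenv under a unique name. *)
let name_of_prim1 = function Print -> "print" | IsNum -> "isnum" | Not -> "not"

let rec infer_exp (funenv : 'a scheme StringMap.t) (e : 'a expr) : 'a typ subst * 'a typ =
  match e with
  | ENumber (_, l) -> ([], TyCon ("Int", l))
  | EBool (_, l) -> ([], TyCon ("Bool", l))
  | EPrim1 (op, arg, l) ->
    (* 1. Instantiate the primitive's scheme with fresh type variables. *)
    let op_typ = instantiate (StringMap.find (name_of_prim1 op) funenv) in
    (* 2. Infer a type for the argument. *)
    let arg_subst, arg_typ = infer_exp funenv arg in
    (* 3. Make up a result type and unify op_typ with (arg_typ -> result). *)
    let result = TyVar (gensym "prim1result", l) in
    let unif_subst = unify op_typ (TyArr ([ arg_typ ], result, l)) in
    let final_subst = compose_subst arg_subst unif_subst in
    (final_subst, apply_subst_typ final_subst result)

(* Hypothetical initial funenv for the three primitives above. *)
let prim_env : unit scheme StringMap.t =
  let tv x = TyVar (x, ()) and tc c = TyCon (c, ()) in
  StringMap.empty
  |> StringMap.add "print" (SForall ([ "X" ], TyArr ([ tv "X" ], tv "X", ()), ()))
  |> StringMap.add "isnum" (SForall ([ "X" ], TyArr ([ tv "X" ], tc "Bool", ()), ()))
  |> StringMap.add "not" (SForall ([], TyArr ([ tc "Bool" ], tc "Bool", ()), ()))
```

In this sketch, print(5) infers Int, isnum(print(true)) infers Bool, and not(1) raises Type_error because Bool cannot be unified with Int.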

A similar process applies for function-application expressions.

For conditionals, recursively infer types for the condition and the two branches. Then unify the condition’s type with tBool, and unify the two branches’ types with each other. The resulting type is the type of a branch, and the resulting substitution is the composition of five smaller substitutions: one from each of the three subexpressions, plus one from each of the two unification calls:

let rec infer_exp funenv env e ... =
 match e with
   ...
 | EIf(c, t, f, _) ->
   (* After each inference, update the type environment *)
   let (c_subst, c_typ) = infer_exp funenv env c ... in
   let env = apply_subst_env c_subst env in
   let (t_subst, t_typ) = infer_exp funenv env t ... in
   let env = apply_subst_env t_subst env in
   let (f_subst, f_typ) = infer_exp funenv env f ... in
   let env = apply_subst_env f_subst env in
   (* Compose the substitutions together *)
   let subst_so_far = compose_subst (compose_subst c_subst t_subst) f_subst in
   (* rewrite the types *)
   let c_typ = apply_subst_typ subst_so_far c_typ in
   let t_typ = apply_subst_typ subst_so_far t_typ in
   let f_typ = apply_subst_typ subst_so_far f_typ in
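   (* unify condition with Bool *)
   let unif_subst1 = unify c_typ tBool ... in
   (* apply that result to the branch types first, so that no constraint
      the condition placed on them is lost when we unify the branches *)
   let t_typ = apply_subst_typ unif_subst1 t_typ in
   let f_typ = apply_subst_typ unif_subst1 f_typ in
   (* unify two branches *)
   let unif_subst2 = unify t_typ f_typ ... in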
   (* compose all substitutions *)
   let final_subst = compose_subst (compose_subst subst_so_far unif_subst1) unif_subst2 in
   let final_typ = apply_subst_typ final_subst t_typ in
   (final_subst, final_typ)
 | ...

Handling let-bindings requires careful management of the environment. Infer a type for the first binding. Add the variable and that type to the environment, and recursively process the remaining bindings. When no bindings remain, process the body. The resulting type is the type of the body, and the resulting substitution is the composition of all the intermediate substitutions.
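Here is a minimal sketch of that recipe for a sequence of let-bindings. To keep it independent of any particular expression type, the expression inferencer is passed in as the parameter infer (a hypothetical design choice, not necessarily how your infer_exp is structured), and substitutions are assumed to be lists of bindings applied left to right, so composition is list append.

```ocaml
module StringMap = Map.Make (String)

type 'a typ = TyVar of string * 'a | TyCon of string * 'a | TyArr of 'a typ list * 'a typ * 'a
type 'a subst = (string * 'a) list

let rec subst_var x s t =
  match t with
  | TyVar (y, _) when x = y -> s
  | TyVar _ | TyCon _ -> t
  | TyArr (args, ret, l) -> TyArr (List.map (subst_var x s) args, subst_var x s ret, l)

let apply_subst_typ sub t = List.fold_left (fun t (x, s) -> subst_var x s t) t sub
let apply_subst_env sub env = StringMap.map (apply_subst_typ sub) env
let compose_subst s1 s2 = s1 @ s2 (* apply s1, then s2 *)

(* [bindings] are processed in order; each bound name is visible to the
   bindings after it and to the body. *)
let rec infer_let infer env bindings body =
  match bindings with
  | [] -> infer env body
  | (x, e) :: rest ->
    let s1, t1 = infer env e in
    (* propagate what we just learned into the existing environment *)
    let env = apply_subst_env s1 env in
    let s2, t2 = infer_let infer (StringMap.add x t1 env) rest body in
    let s = compose_subst s1 s2 in
    (s, apply_subst_typ s t2)
```

Threading the rewritten environment into each later binding is what lets a constraint discovered in one binding affect the types seen by the next.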

2.4 Inference and Generalization

It stands to reason that if using a function (in a function call) instantiates its type scheme, then defining a function must somehow generalize the function’s type into a type scheme. This process is straightforward: collect all the free type variables in the function’s type (its arrow type, not just the type of its body), subtract away any type variables that appear free in the type environment, and generalize over whatever is left over.

Exercise

Define a function

let generalize (env : 'a typ envt) (t : 'a typ) : 'a scheme = ...

that accomplishes this task.
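A minimal sketch of one possible generalize follows, assuming environments are StringMap maps (consistent with StringMap.empty being passed as an environment later on) and schemes are written SForall (vars, typ, loc), a hypothetical constructor. It takes the scheme’s payload from the type being generalized.

```ocaml
module StringSet = Set.Make (String)
module StringMap = Map.Make (String)

type 'a typ = TyVar of string * 'a | TyCon of string * 'a | TyArr of 'a typ list * 'a typ * 'a
type 'a scheme = SForall of string list * 'a typ * 'a
type 'a envt = 'a StringMap.t

let rec ftv_type = function
  | TyVar (x, _) -> StringSet.singleton x
  | TyCon _ -> StringSet.empty
  | TyArr (args, ret, _) ->
    List.fold_left (fun acc a -> StringSet.union acc (ftv_type a)) (ftv_type ret) args

(* Free type variables of every type in the environment. *)
let ftv_env (env : 'a typ envt) : StringSet.t =
  StringMap.fold (fun _ t acc -> StringSet.union acc (ftv_type t)) env StringSet.empty

let loc_of = function TyVar (_, l) | TyCon (_, l) | TyArr (_, _, l) -> l

(* Quantify over the variables free in [t] but not free in [env]: those are
   the ones no other part of the program can still constrain. *)
let generalize (env : 'a typ envt) (t : 'a typ) : 'a scheme =
  let quantified = StringSet.diff (ftv_type t) (ftv_env env) in
  SForall (StringSet.elements quantified, t, loc_of t)
```

For example, with y : 'B in the environment, generalizing ('A, 'B -> Bool) quantifies only over 'A, since 'B may still be constrained by whatever bound y.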

The challenges here mostly have to do with when to generalize which function bodies. To handle this, we will collect function definitions into groups of mutually-recursive definitions. To infer type schemes for a definition group, instantiate the type schemes for the functions all at once. Infer types for each function body, and accumulate the substitutions that result. Finally, generalize all the remaining types all at once.

Exercise

Define two functions

let infer_decl funenv env (decl : sourcespan decl) (* ... bookkeeping ... *)
    : (sourcespan scheme envt * sourcespan typ) = ...

let infer_group funenv env (group : sourcespan decl list) (* ... bookkeeping ...*)
    : sourcespan scheme envt = ...

to accomplish this. In infer_decl, you should assume that the function you’re inferring has already been instantiated to a (potentially unknown) monomorphic type, while in infer_group, you should do the instantiation and generalization.

Consider the following two programs:

def f(x): # should have scheme Forall 'X, ('X -> 'X)
  print(x)

def ab_bool(a, b): # should have scheme Forall 'A, ('A, Bool -> Bool), since f(b) is an operand of &&
  isnum(f(a)) && f(b)

ab_bool(3, true) && ab_bool(true, false)

versus

def f(x):
  print(x)

and def ab_bool(a, b):
  isnum(f(a)) && f(b) # ???

ab_bool(3, true) && ab_bool(true, false)

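In this version of the program, f and ab_bool are mutually recursive, so they are inferred together and f is not generalized before the body of ab_bool is checked. Within that body, f can therefore have only one type, so the line with the question marks forces a to have the same type as b, namely Bool, and the call ab_bool(3, true) no longer typechecks.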

On the other hand, consider a simple pair of mutually recursive functions:

def even(n):
  not(odd(n))

and def odd(n):
  if n == 0: false
  else: if n == 1: true
  else:
    even(n - 1)

odd(5)

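Because these two functions are mutually recursive (marked by the and keyword), they are inferred together, and both get the type Int -> Bool: odd compares n with 0 and subtracts 1 from it. Looked at in isolation, even never uses n in any particular way, so without the constraints flowing from odd it would appear to have the type Forall 'X, 'X -> Bool.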

3 Putting it all together

We’re nearly finished. Define a function

let infer_prog (funenv : sourcespan scheme envt) (env : sourcespan typ envt)
               (p : sourcespan program) : sourcespan program = ...

that either returns the program if it successfully infers a type, or raises an exception if it fails. Finally, define

let type_synth (p : sourcespan program) : sourcespan program fallible =
  try
    Ok(infer_prog initial_env StringMap.empty p)
  with e -> Error([e])
;;

And we’re done!

3.1 Challenges and hints

At the end of every inference call, be sure to apply the resulting substitution to the resulting type, to ensure that the types are as fully-rewritten as possible. This helps ensure that we only need to apply each substitution once to get the correct results.

Also be sure never to reuse type variables. When in doubt, gensym an extra one and unify it wherever needed. But be sure never to discard the resulting substitutions!

Be careful with the occurs check. Right now it won’t impact us, because we don’t have higher-order functions or other type constructors, but once we do, this can be the source of some very subtle bugs.

Test thoroughly! You will need to write larger programs than before, to test how unification constraints are propagated throughout your inference process.

Avail yourself of debug-printing output, at many stages of the process. You will always want more information than you have, in order to debug why something went wrong during inference. Once you’ve implemented this inference process, you should have a much greater sympathy for compiler authors who provide inscrutable unification error messages! Try to record as many bits of debugging information as possible.

[1] When we get to higher-order languages, this split is not quite so straightforward, and we’ll need to revisit it slightly.

[2] If we had monadic notation, as in Haskell, this syntactic overhead could be avoided.

[3] Yes, we’re using a stateful gensym instead of pure tags. Since we don’t know a priori how many times we’ll instantiate a type scheme, it’s hard to know in advance what tags to even use.