
## Lecture 17: Type Inference

In the last lecture, we designed a system for categorizing the expressions in our programs with types, and created an algorithm for checking whether our programs had a self-consistent type. However, that system ran aground in terms of usability, when it came to primitives like isbool or ==, which accepted arguments of “any type” and produced boolean answers, or worse yet, print, which took arguments of “any type” and produced answers of that type. In order to handle this polymorphism, we had to resort to hacks: either explicitly annotating the program everywhere there was ambiguity, or splitting the primitives into separate, monomorphic versions that all behaved the same way. Can we do better?

### 1 Type inference: guessing correctly, every time

The key observation we made last time was that our type checker must be syntax-directed: at every expression of our program, we have enough local information to know how to recur on each subexpression, and we never have to “guess and check” our results. However, when we write down a polymorphic type and informally use it, we walk through a reasoning process that says “In this case, assume that this type variable means this particular type, then check that everything is self-consistent.” In other words, we mentally perform a kind of substitution of types for type variables, then continue with the same type-checking algorithm as before.

But this is in fact overkill. The self-consistency requirement alone is enough to type-check our program, if we exploit it correctly. The key is to reframe the question from “Does this program have this particular type?” to “Does there exist a type for this program such that the program is self-consistent?” We may not know what that type is yet, but we can deduce certain facts about it. This process is called type inference, and we will follow the classic Hindley-Milner algorithm to deduce types for our programs.

The trick to getting started is to say, “Sure! Every expression can have a type – let’s just make up a fresh type variable 'X and claim that as the type.” Of course, there may be constraints on 'X based on how we use the value in our program; these are called type constraints, and type inference is simply the process of collecting and solving these constraints. The easiest way to see this is by example.

In order to make our type inference tractable, we are going to separate the notions of types and type schemes. Intuitively, a type is a single thing, while a type scheme can quantify over type variables. Our goal will be to infer a type scheme for each function in our program, and a type for each expression. (When we get to higher-order languages, this split is not quite so straightforward, and we’ll need to revisit it slightly.)

#### 1.1 Example 1

Suppose we have the program

```
def f(x):
  x + 6

f(38)
```

Let’s start by inferring a type for f. We have no annotations, so we simply create new type variables as needed, and assert that f has type 'T1 -> 'T2, for two unknown type variables 'T1 and 'T2. We begin to infer a type for the body of f by binding the arguments to their types in our type environment – here, x : 'T1.

We recur into the body of f, and examine x + 6. This is an EPrim2 expression, so we look up the type scheme for Plus, and see that it is Forall [], (Int, Int -> Int). We look at the arguments to the operator, and infer their types: we look up x in the environment and obtain 'T1, and every numeric literal has type Int. We know that this operation will produce some result type, so we make up a new type variable 'prim2result. We can now construct a type ('T1, Int -> 'prim2result), and unify it with the type (Int, Int -> Int). This produces the substitution ['T1 = Int, 'prim2result = Int]. We apply that substitution, and deduce that x : Int. The result of our expression has type 'prim2result, and since it is the entirety of the body of f, we unify that with the result type 'T2, making sure to preserve the substitution where ['prim2result = Int]. Our net result is f : Int -> Int, which we can generalize to the type scheme Forall [], Int -> Int, and update our function environment accordingly.

We next need to infer a type for f(38). We repeat a similar process: we look up the type scheme for f, infer a type for 38, make up a new result type 'app_result, and unify Int -> Int (which is the type of f) with Int -> 'app_result. We deduce that ['app_result = Int], so the final type of our program is Int.

#### 1.2 Example 2

Suppose we have the following trickier program:

```
def f(x, y):
  isnum(print(x)) && isbool(y)

def g(z):
  f(z, 5)

g(7)
```

Since our functions are not mutually recursive, we can handle them independently.

• Start by guessing type variables for f, namely ('T1, 'T2 -> 'T3).

• Next, we bind the arguments in our environment, [x : 'T1, y : 'T2], and recur into the body.

• Again we encounter an EPrim2, so we look up the type scheme for the operator and get && : Forall [], (Bool, Bool -> Bool). We then infer types for both arguments:

• Our first argument is an EPrim1, so we look up the type scheme for the operator and find isnum : Forall ['X], ('X -> Bool). We cannot use this directly, so we instantiate the type by creating new type variables and substituting them: isnum : ('T4 -> Bool). We recur on the argument:

• Our argument is another EPrim1, so we look up the type scheme for the operator and find print : Forall ['X], ('X -> 'X). Again we instantiate the type to get print : ('T5 -> 'T5). Then we recur on the argument, and look it up to find x : 'T1. We create a return type 'prim1result1, and unify 'T1 -> 'prim1result1 with 'T5 -> 'T5, to obtain the substitution ['T1 = 'T5, 'prim1result1 = 'T5]. Our result type is 'prim1result1.

• We create a new result type for this primitive, and unify 'prim1result1 -> 'prim1result2 with 'T4 -> Bool, to obtain the substitution ['prim1result1 = 'T4, 'prim1result2 = Bool]. We combine that with the existing substitution to get ['T1 = 'T5, 'prim1result1 = 'T5, 'T4 = 'T5, 'prim1result2 = Bool], and return a result of 'prim1result2.

• Our second argument is another EPrim1, and the type inference for this is analogous to that for isnum. We look up the type scheme, instantiate it to isbool : 'T6 -> Bool, and infer a type for the nested expression. We look that expression up to obtain y : 'T2, generate a new result type 'prim1result3, and unify 'T2 -> 'prim1result3 with 'T6 -> Bool, to obtain ['T2 = 'T6, 'prim1result3 = Bool].

• We instantiate the type of the operator to (Bool, Bool -> Bool). We create a result type 'prim2result1, and unify the inferred type ('prim1result2, 'prim1result3 -> 'prim2result1) with (Bool, Bool -> Bool), to get ['prim1result2 = Bool, 'prim1result3 = Bool, 'prim2result1 = Bool]. We combine this with all our prior substitutions, and return the result type 'prim2result1.

• We unify 'prim2result1 with 'T3, then apply our overall substitution, and we obtain ['T1 = 'T5, 'T2 = 'T6, 'T3 = Bool] (along with other substitutions). Finally, we generalize this type, to get the final type scheme f : Forall ['T5, 'T6], ('T5, 'T6 -> Bool).

We can repeat the process for g, guessing an initial type of 'T7 -> 'T8. When we get to the function call, we instantiate the type scheme for f to f : ('T9, 'T10 -> Bool), and proceed. We wind up with the substitution ['T7 = 'T9, 'T10 = Int, 'T8 = Bool], and generalizing, we get a final type scheme of g : Forall ['T9], ('T9 -> Bool).

Finally we proceed to g(7). Again we instantiate the scheme for g, unify its argument type with Int, and obtain a final type of Bool. Our program is self-consistent.

### 2 Type inference implementation

#### 2.1 Preparation

To implement type inference ourselves, we’ll need lots of preparatory definitions:

Exercise

Do these.

• Define a type environment as a mapping from names to types, and a type scheme environment as a mapping from names to type schemes. Our initial type scheme environment will contain type schemes for all our primitives. Our initial type environment will be empty. We will use ML’s Maps, rather than association lists, because they are more efficient and our environments are unordered.

```ocaml
module StringMap = Map.Make(String)
type 'a envt = 'a StringMap.t
```

• Define the notion of substituting a type for a type variable, within another type. This converts each occurrence of the given type variable tyvar into the desired type to_typ in the target type in_typ.

```ocaml
let rec subst_var_typ ((tyvar : string), (to_typ : 'a typ)) (in_typ : 'a typ) : 'a typ = ...
```

• Define the notion of substituting a type for a type variable, within a type scheme. This is the same idea as above, except that if the type variable is quantified by the scheme, we do not substitute. (This is exactly analogous to capture-avoiding substitution of expressions for variables, underneath lambda or let bindings.)

```ocaml
let rec subst_var_scheme ((tyvar : string), (to_typ : 'a typ))
                         (in_scheme : 'a scheme) : 'a scheme = ...
```

• Define a type substitution as an ordered mapping from names to types. (We use lists here to ensure an ordering, if it is needed.)

```ocaml
type 'a subst = (string * 'a) list
```

• Define a related set of functions that apply a substitution to a type, a type scheme, or environments, by applying each individual substitution from left-to-right. You may not need all of these functions, but they’re good to have available.

```ocaml
let apply_subst_typ    (subst : 'a typ subst) (t : 'a typ) : 'a typ = ...
let apply_subst_scheme (subst : 'a typ subst) (s : 'a scheme) : 'a scheme = ...
let apply_subst_typenv (subst : 'a typ subst) (env : 'a typ envt) : 'a typ envt = ...
let apply_subst_schenv (subst : 'a typ subst) (env : 'a scheme envt) : 'a scheme envt = ...
let apply_subst_subst  (subst : 'a typ subst) (sub : 'a typ subst) : 'a typ subst = ...
```

• Define a function to combine two substitutions, by first applying one to the other, then concatenating the two.

```ocaml
let compose_subst (sub1 : 'a typ subst) (sub2 : 'a typ subst) : 'a typ subst =
  sub1 @ (apply_subst_subst sub1 sub2)
```

• Define the free type variables of a type as the set of any type variables that appear within it. Define the free type variables of a type scheme as the type variables that appear un-quantified within the scheme. Finally, define the free type variables of a type environment as simply the union of all the type variables in all the types within it.

```ocaml
let rec ftv_type (t : 'a typ) : StringSet.t =
  match t with
  | TyCon _ -> StringSet.empty
  | TyVar(name, _) -> StringSet.singleton name
  | TyArr(args, ret, _) ->
    List.fold_right (fun t ftvs -> StringSet.union (ftv_type t) ftvs)
      args
      (ftv_type ret)
  | TyApp(typ, args, _) ->
    List.fold_right (fun t ftvs -> StringSet.union (ftv_type t) ftvs)
      args
      (ftv_type typ)
;;
let ftv_scheme (s : 'a scheme) : StringSet.t =
  match s with
  | SForall(args, typ, _) ->
    StringSet.diff (ftv_type typ) (StringSet.of_list args)
let ftv_env (e : 'a typ envt) : StringSet.t = ...
```

Do Now!

You are welcome to clean this up and make it more efficient, by passing along a StringSet.t as an accumulator parameter, of all the free type variables seen so far.
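The preparatory definitions can be filled in roughly as follows. This is a minimal sketch and not the course's reference solution. The `typ` and `scheme` declarations are assumptions reconstructed from the constructors that `ftv_type` pattern-matches on. `ftv_type_acc` is a hypothetical helper that implements the accumulator-passing cleanup suggested in the Do Now. Substitutions are applied one binding at a time, left to right, as the text specifies.

```ocaml
module StringMap = Map.Make (String)
module StringSet = Set.Make (String)

(* Assumed declarations, reconstructed from the constructors used in ftv_type. *)
type 'a typ =
  | TyCon of string * 'a
  | TyVar of string * 'a
  | TyArr of 'a typ list * 'a typ * 'a
  | TyApp of 'a typ * 'a typ list * 'a

type 'a scheme = SForall of string list * 'a typ * 'a
type 'a envt = 'a StringMap.t
type 'a subst = (string * 'a) list

(* Replace every occurrence of tyvar in in_typ by to_typ. *)
let rec subst_var_typ ((tyvar : string), (to_typ : 'a typ)) (in_typ : 'a typ) : 'a typ =
  match in_typ with
  | TyCon _ -> in_typ
  | TyVar (name, _) -> if name = tyvar then to_typ else in_typ
  | TyArr (args, ret, loc) ->
      TyArr (List.map (subst_var_typ (tyvar, to_typ)) args, subst_var_typ (tyvar, to_typ) ret, loc)
  | TyApp (typ, args, loc) ->
      TyApp (subst_var_typ (tyvar, to_typ) typ, List.map (subst_var_typ (tyvar, to_typ)) args, loc)

(* A quantified variable is bound by the scheme, so it is left alone. This
   assumes to_typ never mentions a quantified name, which holds when every
   type variable is created by gensym. *)
let subst_var_scheme ((tyvar : string), (to_typ : 'a typ)) (in_scheme : 'a scheme) : 'a scheme =
  match in_scheme with
  | SForall (vars, body, loc) ->
      if List.mem tyvar vars then in_scheme
      else SForall (vars, subst_var_typ (tyvar, to_typ) body, loc)

(* Each binding of the substitution is applied in turn, from left to right. *)
let apply_subst_typ (subst : 'a typ subst) (t : 'a typ) : 'a typ =
  List.fold_left (fun t binding -> subst_var_typ binding t) t subst

let apply_subst_scheme (subst : 'a typ subst) (s : 'a scheme) : 'a scheme =
  List.fold_left (fun s binding -> subst_var_scheme binding s) s subst

let apply_subst_typenv (subst : 'a typ subst) (env : 'a typ envt) : 'a typ envt =
  StringMap.map (apply_subst_typ subst) env

let apply_subst_schenv (subst : 'a typ subst) (env : 'a scheme envt) : 'a scheme envt =
  StringMap.map (apply_subst_scheme subst) env

(* Rewrite the right-hand sides of sub; its bound names are unchanged. *)
let apply_subst_subst (subst : 'a typ subst) (sub : 'a typ subst) : 'a typ subst =
  List.map (fun (name, t) -> (name, apply_subst_typ subst t)) sub

let compose_subst (sub1 : 'a typ subst) (sub2 : 'a typ subst) : 'a typ subst =
  sub1 @ apply_subst_subst sub1 sub2

(* Accumulator-passing free type variables: acc holds the names seen so far. *)
let rec ftv_type_acc (acc : StringSet.t) (t : 'a typ) : StringSet.t =
  match t with
  | TyCon _ -> acc
  | TyVar (name, _) -> StringSet.add name acc
  | TyArr (args, ret, _) -> List.fold_left ftv_type_acc (ftv_type_acc acc ret) args
  | TyApp (typ, args, _) -> List.fold_left ftv_type_acc (ftv_type_acc acc typ) args

let ftv_type (t : 'a typ) : StringSet.t = ftv_type_acc StringSet.empty t

let ftv_scheme (s : 'a scheme) : StringSet.t =
  match s with
  | SForall (args, typ, _) -> StringSet.diff (ftv_type typ) (StringSet.of_list args)

let ftv_env (e : 'a typ envt) : StringSet.t =
  StringMap.fold (fun _ t acc -> ftv_type_acc acc t) e StringSet.empty
```

Because substitutions are read left to right, `compose_subst` has to rewrite `sub2` with `sub1`. For example, composing `['X = Int]` with `['Y = 'X -> 'X]` maps `'Y` to `Int -> Int`. Plain concatenation would instead leave `'Y` as `'X -> 'X`, because the `'X` binding has already been passed by the time `'Y` is rewritten.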

#### 2.2 Unification

The central operation of type inference is unification: given two types, compute a substitution for their free type variables under which the two types become identical:

```ocaml
let rec unify (t1 : 'a typ) (t2 : 'a typ) (* other bookkeeping *) : 'a typ subst = ...
```

For example, we can unify 'X -> Int and Bool -> 'Y under the substitution ['X = Bool, 'Y = Int]. If we apply this substitution to both types, they both become Bool -> Int and are identical.

However, there is no substitution that can unify Int -> 'X with Bool -> 'Y. Even if we set ['X = 'Y], we cannot unify Bool with Int. We get a unification error, which we report as a type error.

Our goal is to produce as many substitutions as necessary: each substitution asserts that a particular type variable must be equal to a particular type. This is the key behind Hindley-Milner type inference. We can always guess a new type variable, and unification will steadily constrain that guess with more and more equality constraints until we have a final answer, or a contradiction.

Note that we have to be careful not to unify too recklessly. We cannot unify 'A with 'A -> 'B, because we would get the absurd substitution ['A = 'A -> 'B]. It is immediately clear that these two types cannot be equal: however we fill in 'A, the right side always has one more arrow than the left side. To rule this out, we must perform an occurs check, to see whether the variable we’re trying to constrain appears within its constraint:

```ocaml
let occurs (name : string) (t : 'a typ) =
  StringSet.mem name (ftv_type t)
```

Our unification proceeds by cases:

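• If t1 and t2 are the same type variable, produce the empty substitution. Otherwise, if either t1 or t2 is a type variable, and it does not occur in the other type, then produce the substitution that binds it to the other type; if it does occur, unification fails.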

• If we have two type constants (like Int or Bool) that are equal, produce the empty substitution. Otherwise, unification fails.

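• If we have two arrow types of the same arity, then unify the corresponding arguments, and unify the return types. Unify these pairs one at a time, and apply the substitution accumulated from the earlier pairs to each later pair before unifying it. Otherwise, unifying ('X, 'X -> Int) with (Int, Bool -> Int) would produce both 'X = Int and 'X = Bool without ever noticing the conflict. The overall substitution that unifies the two arrows is the composition of all the resulting substitutions.

• If we have two type applications of the same arity, then unify the types and the corresponding type arguments, threading the substitution through in the same way. The overall substitution that unifies the two applications is the composition of all the resulting substitutions.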

• All other cases result in a unification failure.
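Putting these cases together gives a sketch like the following. The `typ` declaration is an assumption reconstructed from the constructors `ftv_type` uses, and it is repeated here along with small substitution helpers so the block stands alone. `UnificationError` is a hypothetical exception; a real version would carry the diagnostic information discussed in the exercise below. The arrow and application cases both go through `unify_pairs`, which applies the substitution found for earlier pairs before unifying each later pair.

```ocaml
module StringSet = Set.Make (String)

(* Assumed declarations, matching the constructors used in ftv_type. *)
type 'a typ =
  | TyCon of string * 'a
  | TyVar of string * 'a
  | TyArr of 'a typ list * 'a typ * 'a
  | TyApp of 'a typ * 'a typ list * 'a

type 'a subst = (string * 'a) list

(* Hypothetical error; a fuller version would also carry source spans. *)
exception UnificationError of string

let rec subst_var_typ (tyvar, to_typ) in_typ =
  match in_typ with
  | TyCon _ -> in_typ
  | TyVar (name, _) -> if name = tyvar then to_typ else in_typ
  | TyArr (args, ret, loc) ->
      TyArr (List.map (subst_var_typ (tyvar, to_typ)) args, subst_var_typ (tyvar, to_typ) ret, loc)
  | TyApp (typ, args, loc) ->
      TyApp (subst_var_typ (tyvar, to_typ) typ, List.map (subst_var_typ (tyvar, to_typ)) args, loc)

let apply_subst_typ (subst : 'a typ subst) (t : 'a typ) : 'a typ =
  List.fold_left (fun t binding -> subst_var_typ binding t) t subst

let compose_subst (sub1 : 'a typ subst) (sub2 : 'a typ subst) : 'a typ subst =
  sub1 @ List.map (fun (name, t) -> (name, apply_subst_typ sub1 t)) sub2

let rec ftv_type (t : 'a typ) : StringSet.t =
  match t with
  | TyCon _ -> StringSet.empty
  | TyVar (name, _) -> StringSet.singleton name
  | TyArr (args, ret, _) ->
      List.fold_left (fun acc t -> StringSet.union acc (ftv_type t)) (ftv_type ret) args
  | TyApp (typ, args, _) ->
      List.fold_left (fun acc t -> StringSet.union acc (ftv_type t)) (ftv_type typ) args

let occurs (name : string) (t : 'a typ) = StringSet.mem name (ftv_type t)

let rec unify (t1 : 'a typ) (t2 : 'a typ) : 'a typ subst =
  match (t1, t2) with
  | TyVar (a, _), TyVar (b, _) when a = b -> []
  | TyVar (a, _), t | t, TyVar (a, _) ->
      if occurs a t then raise (UnificationError ("occurs check failed for '" ^ a))
      else [ (a, t) ]
  | TyCon (c1, _), TyCon (c2, _) ->
      if c1 = c2 then [] else raise (UnificationError (c1 ^ " is not " ^ c2))
  | TyArr (args1, ret1, _), TyArr (args2, ret2, _)
    when List.length args1 = List.length args2 ->
      unify_pairs (args1 @ [ ret1 ]) (args2 @ [ ret2 ])
  | TyApp (c1, args1, _), TyApp (c2, args2, _)
    when List.length args1 = List.length args2 ->
      unify_pairs (c1 :: args1) (c2 :: args2)
  | _ -> raise (UnificationError "types have different shapes")

(* Unify corresponding elements, rewriting each pair with the substitution
   accumulated so far before unifying it. The callers guarantee equal lengths. *)
and unify_pairs (ts1 : 'a typ list) (ts2 : 'a typ list) : 'a typ subst =
  List.fold_left2
    (fun subst t1 t2 ->
      compose_subst subst (unify (apply_subst_typ subst t1) (apply_subst_typ subst t2)))
    [] ts1 ts2
```

With this threading, unifying `('X, 'X -> Int)` with `(Int, Bool -> Int)` fails as it should. The binding `'X = Int` found for the first argument is applied to the second argument before that argument is unified with `Bool`.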

Exercise

In the implementation of unify, it is quite helpful to pass around diagnostic information, such as a sourcespan of the expression that triggered the unification request (and possibly more information) so that your error message can direct the programmer to where the problem occurred. Consider what information might be useful, and extend the signature and implementation of unify accordingly.

#### 2.3 Inference and Instantiation

Our goal in inferring types for our program is to compute a type for every subexpression such that the program as a whole is self-consistent. To do this, our signature will benefit from returning not only the computed type, but also the substitution that we’ve accumulated so far. (If we had monadic notation, as in Haskell, this syntactic overhead could be avoided.) Our primary function for type inference of expressions will be

```ocaml
let rec infer_exp (funenv : sourcespan scheme envt) (* environment for functions *)
                  (env : sourcespan typ envt)       (* environment for variables *)
                  (e : sourcespan expr)
                  (* ... some bookkeeping ... *)
                  : (  sourcespan typ subst  (* the resulting substitution *)
                     * sourcespan typ)       (* the resulting type *) =
  match e with
  ...
```

Suppose we construct two base types,

```ocaml
let tInt = TyCon("Int", dummy_sourcespan)
let tBool = TyCon("Bool", dummy_sourcespan)
```

Then the first two cases of infer_exp are easy:

```ocaml
let rec infer_exp funenv env e ... =
  match e with
  | ENumber _ -> ([], tInt)
  | EBool _ -> ([], tBool)
  ...
```

Variables can be handled by looking them up in env. But how should we handle things like EPrim1(Print, e, _)? After all, polymorphic operators like these were why we built this whole new infrastructure!

The key observation is that print has the type scheme Forall 'X, ('X -> 'X). Therefore, it has the type foo -> foo, for any type foo we want. So our key step is simply to create a fresh new type variable, one that’s never been used before in our program, and substitute it into the type scheme.

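```ocaml
(* Each call returns a new name such as "X_7"; the counter is shared across calls. *)
let gensym =
  let count = ref 0 in
  let next () =
    count := !count + 1;
    !count
  in fun str -> Printf.sprintf "%s_%d" str (next ());;
let instantiate (s : 'a scheme) : 'a typ = ...
```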

Define a function instantiate that accomplishes this specialization step. It will use the gensym function to create any necessary new type variables. (Yes, we’re using a stateful gensym instead of pure tags. Since we don’t know a priori how many times we’ll instantiate a type scheme, it’s hard to know in advance what tags to even use.)
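Here is a minimal sketch of instantiate. It assumes `typ` and `scheme` declarations reconstructed from the constructors used in `ftv_type`, and it uses `Printf.sprintf` for `gensym`. Each quantified variable is replaced by a freshly gensym'd one, so two instantiations of the same scheme never share type variables. The fresh names keep the original name as a prefix, which helps when reading debug output.

```ocaml
(* Assumed declarations, matching the constructors used in ftv_type. *)
type 'a typ =
  | TyCon of string * 'a
  | TyVar of string * 'a
  | TyArr of 'a typ list * 'a typ * 'a
  | TyApp of 'a typ * 'a typ list * 'a

type 'a scheme = SForall of string list * 'a typ * 'a

let gensym =
  let count = ref 0 in
  fun (str : string) ->
    incr count;
    Printf.sprintf "%s_%d" str !count

let rec subst_var_typ (tyvar, to_typ) in_typ =
  match in_typ with
  | TyCon _ -> in_typ
  | TyVar (name, _) -> if name = tyvar then to_typ else in_typ
  | TyArr (args, ret, loc) ->
      TyArr (List.map (subst_var_typ (tyvar, to_typ)) args, subst_var_typ (tyvar, to_typ) ret, loc)
  | TyApp (typ, args, loc) ->
      TyApp (subst_var_typ (tyvar, to_typ) typ, List.map (subst_var_typ (tyvar, to_typ)) args, loc)

(* Replace each quantified variable by a brand-new type variable; variables
   that the scheme does not quantify are left untouched. *)
let instantiate (s : 'a scheme) : 'a typ =
  match s with
  | SForall (vars, typ, loc) ->
      List.fold_left (fun t v -> subst_var_typ (v, TyVar (gensym v, loc)) t) typ vars
```

Calling instantiate twice on print's scheme `Forall ['X], ('X -> 'X)` yields two arrows over two different fresh variables. That is what lets print be used at Int in one place and at Bool in another.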

We can now use this function to infer types for prim1 and prim2 constructs. Populate the initial funenv with type schemes for all the primitives by giving them unique names. Each primitive should have some arrow type, to record its argument type(s) and return type. Then, in the EPrim1 case:

• Look up the relevant type scheme.

• Instantiate it to a type.

• Infer a type for the argument(s) of the primitive.

• Make up a brand-new type variable for the return type of the operation.

• Construct a new arrow type using the inferred types of the argument(s) and the made-up return type variable.

• Unify the instantiated arrow type of the operator with the constructed arrow type.

• If all goes well, return the newly-constructed return type variable, and the substitution obtained from the recursive unification call.
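The seven steps above can be sketched on a hypothetical mini-language with only literals, variables, and EPrim1. Several pieces are assumptions for illustration: the `expr` and `prim1` types, the strings `"isnum"`/`"isbool"`/`"print"` used as funenv keys, and the `TypeError` exception. To stay self-contained, the block repeats compact versions of substitution, unification, gensym, and instantiate, and it drops `TyApp` from the type representation.

```ocaml
module StringMap = Map.Make (String)

(* Assumed, cut down: TyApp is omitted because this sketch never builds one. *)
type 'a typ = TyCon of string * 'a | TyVar of string * 'a | TyArr of 'a typ list * 'a typ * 'a
type 'a scheme = SForall of string list * 'a typ * 'a

(* Hypothetical mini-language, just large enough for the EPrim1 case. *)
type prim1 = IsNum | IsBool | Print
type 'a expr =
  | ENumber of int * 'a
  | EBool of bool * 'a
  | EId of string * 'a
  | EPrim1 of prim1 * 'a expr * 'a

exception TypeError of string

let rec subst_var (v, to_t) t =
  match t with
  | TyCon _ -> t
  | TyVar (n, _) -> if n = v then to_t else t
  | TyArr (args, ret, l) -> TyArr (List.map (subst_var (v, to_t)) args, subst_var (v, to_t) ret, l)

let apply_subst s t = List.fold_left (fun t b -> subst_var b t) t s
let compose s1 s2 = s1 @ List.map (fun (n, t) -> (n, apply_subst s1 t)) s2

let rec occurs n t =
  match t with
  | TyCon _ -> false
  | TyVar (m, _) -> n = m
  | TyArr (args, ret, _) -> List.exists (occurs n) args || occurs n ret

let rec unify t1 t2 =
  match (t1, t2) with
  | TyVar (a, _), TyVar (b, _) when a = b -> []
  | TyVar (a, _), t | t, TyVar (a, _) ->
      if occurs a t then raise (TypeError "occurs check") else [ (a, t) ]
  | TyCon (c1, _), TyCon (c2, _) when c1 = c2 -> []
  | TyArr (a1, r1, _), TyArr (a2, r2, _) when List.length a1 = List.length a2 ->
      List.fold_left2
        (fun s x y -> compose s (unify (apply_subst s x) (apply_subst s y)))
        [] (a1 @ [ r1 ]) (a2 @ [ r2 ])
  | _ -> raise (TypeError "cannot unify")

let gensym = let n = ref 0 in fun s -> incr n; Printf.sprintf "%s_%d" s !n

let instantiate (SForall (vars, t, l)) =
  List.fold_left (fun t v -> subst_var (v, TyVar (gensym v, l)) t) t vars

(* Hypothetical names under which the primitives' schemes are stored. *)
let prim1_name = function IsNum -> "isnum" | IsBool -> "isbool" | Print -> "print"

let rec infer_exp funenv env e =
  match e with
  | ENumber (_, l) -> ([], TyCon ("Int", l))
  | EBool (_, l) -> ([], TyCon ("Bool", l))
  | EId (x, _) -> ([], StringMap.find x env)
  | EPrim1 (op, arg, l) ->
      (* look up the operator's scheme and instantiate it *)
      let op_typ = instantiate (StringMap.find (prim1_name op) funenv) in
      (* infer a type for the argument *)
      let arg_subst, arg_typ = infer_exp funenv env arg in
      (* a brand-new return type, and an arrow built from what we inferred *)
      let ret = TyVar (gensym "prim1result", l) in
      let constructed = TyArr ([ arg_typ ], ret, l) in
      (* unify the operator's type with the constructed arrow *)
      let unif = unify op_typ constructed in
      (* return the rewritten return type and the combined substitution *)
      let final = compose arg_subst unif in
      (final, apply_subst final ret)
```

With `isnum`, `isbool`, and `print` given the schemes from Example 2, inferring `isnum(print(5))` yields `Bool`. Along the way, the inner print is instantiated at `Int`.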

A similar process applies for function-application expressions.

For conditionals, recursively infer types for the condition and the two branches. Then unify the condition’s type with tBool, and unify the two branches’ types with each other. The resulting type is the type of a branch, and the resulting substitution is the composition of five smaller substitutions: the three subexpressions, plus the two unification calls:

```ocaml
let rec infer_exp funenv env e ... =
  match e with
  ...
  | EIf(c, t, f, _) ->
    (* After each inference, update the type environment *)
    let (c_subst, c_typ) = infer_exp funenv env c ... in
    let env = apply_subst_typenv c_subst env in
    let (t_subst, t_typ) = infer_exp funenv env t ... in
    let env = apply_subst_typenv t_subst env in
    let (f_subst, f_typ) = infer_exp funenv env f ... in
    (* Compose the substitutions together *)
    let subst_so_far = compose_subst (compose_subst c_subst t_subst) f_subst in
    (* rewrite the types *)
    let c_typ = apply_subst_typ subst_so_far c_typ in
    let t_typ = apply_subst_typ subst_so_far t_typ in
    let f_typ = apply_subst_typ subst_so_far f_typ in
    (* unify condition with Bool *)
    let unif_subst1 = unify c_typ tBool ... in
    (* apply what we just learned before unifying the branches *)
    let t_typ = apply_subst_typ unif_subst1 t_typ in
    let f_typ = apply_subst_typ unif_subst1 f_typ in
    (* unify two branches *)
    let unif_subst2 = unify t_typ f_typ ... in
    (* compose all substitutions *)
    let final_subst = compose_subst (compose_subst subst_so_far unif_subst1) unif_subst2 in
    let final_typ = apply_subst_typ final_subst t_typ in
    (final_subst, final_typ)
  | ...
```

Handling let-bindings requires careful management of the environment. Infer a type for the first binding. Add the variable and that type to the environment, and recursively process the remaining bindings. When no bindings remain, process the body. The resulting type is the type of the body, and the resulting substitution is the composition of all the intermediate substitutions.
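Here is a minimal sketch of the let case, using another hypothetical mini-language with numbers, booleans, variables, +, and multi-binding let. To keep it short, the types are cut down to constants and variables. Plus is handled by unifying each operand with Int directly rather than through an instantiated arrow type. The important part is the ELet case: each binding's substitution rewrites the environment before the next variable is added.

```ocaml
module StringMap = Map.Make (String)

(* Assumed, cut down: without functions, only constants and variables are needed. *)
type 'a typ = TyCon of string * 'a | TyVar of string * 'a

(* Hypothetical mini-language: numbers, booleans, variables, +, and let. *)
type 'a expr =
  | ENumber of int * 'a
  | EBool of bool * 'a
  | EId of string * 'a
  | EPlus of 'a expr * 'a expr * 'a
  | ELet of (string * 'a expr) list * 'a expr * 'a

exception TypeError of string

let apply_subst s t =
  List.fold_left
    (fun t (v, to_t) -> match t with TyVar (n, _) when n = v -> to_t | _ -> t)
    t s

let apply_subst_env s env = StringMap.map (apply_subst s) env
let compose s1 s2 = s1 @ List.map (fun (n, t) -> (n, apply_subst s1 t)) s2

(* No arrows exist here, so no occurs check is needed. *)
let unify t1 t2 =
  match (t1, t2) with
  | TyVar (a, _), TyVar (b, _) when a = b -> []
  | TyVar (a, _), t | t, TyVar (a, _) -> [ (a, t) ]
  | TyCon (c1, _), TyCon (c2, _) when c1 = c2 -> []
  | _ -> raise (TypeError "cannot unify")

let rec infer_exp env e =
  match e with
  | ENumber (_, l) -> ([], TyCon ("Int", l))
  | EBool (_, l) -> ([], TyCon ("Bool", l))
  | EId (x, _) -> ([], StringMap.find x env)
  | EPlus (e1, e2, l) ->
      (* Plus is monomorphic here, so each operand is simply unified with Int. *)
      let s1, t1 = infer_exp env e1 in
      let s2, t2 = infer_exp (apply_subst_env s1 env) e2 in
      let s = compose s1 s2 in
      let s = compose s (unify (apply_subst s t1) (TyCon ("Int", l))) in
      let s = compose s (unify (apply_subst s t2) (TyCon ("Int", l))) in
      (s, TyCon ("Int", l))
  | ELet ([], body, _) -> infer_exp env body
  | ELet ((x, rhs) :: rest, body, l) ->
      (* Infer the first binding, rewrite the environment with what we learned,
         then bind x and process the remaining bindings (and finally the body). *)
      let s1, t1 = infer_exp env rhs in
      let env = StringMap.add x t1 (apply_subst_env s1 env) in
      let s2, t2 = infer_exp env (ELet (rest, body, l)) in
      let s = compose s1 s2 in
      (s, apply_subst s t2)
```

Consider `let a = y, b = a + 1 in b` in an environment where `y : 'Y`. The returned substitution maps `'Y` to `Int`. The + in the second binding constrains `a`, and `a` was bound to `y`'s type.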

#### 2.4 Inference and Generalization

It stands to reason that if using a function (in a function call) instantiates its type scheme, then defining a function must somehow generalize the type of the body into a type scheme. This process is straightforward: collect all the free type variables in the type of the function body, subtract away any type variables that appear free in the type environment, and whatever is leftover can be generalized into a type scheme.

Exercise

Define a function

```ocaml
let generalize (env : 'a typ envt) (t : 'a typ) : 'a scheme = ...
```
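Here is a sketch of generalize, assuming the same reconstructed `typ`/`scheme` declarations (repeated here). The exercise's signature does not say where the scheme's annotation should come from. `annot_of` is a hypothetical helper that reuses the type's own annotation (a sourcespan in the real code).

```ocaml
module StringMap = Map.Make (String)
module StringSet = Set.Make (String)

(* Assumed declarations, matching the constructors used in ftv_type. *)
type 'a typ =
  | TyCon of string * 'a
  | TyVar of string * 'a
  | TyArr of 'a typ list * 'a typ * 'a
  | TyApp of 'a typ * 'a typ list * 'a

type 'a scheme = SForall of string list * 'a typ * 'a
type 'a envt = 'a StringMap.t

let rec ftv_type (t : 'a typ) : StringSet.t =
  match t with
  | TyCon _ -> StringSet.empty
  | TyVar (name, _) -> StringSet.singleton name
  | TyArr (args, ret, _) ->
      List.fold_left (fun acc t -> StringSet.union acc (ftv_type t)) (ftv_type ret) args
  | TyApp (typ, args, _) ->
      List.fold_left (fun acc t -> StringSet.union acc (ftv_type t)) (ftv_type typ) args

let ftv_env (env : 'a typ envt) : StringSet.t =
  StringMap.fold (fun _ t acc -> StringSet.union (ftv_type t) acc) env StringSet.empty

(* Hypothetical helper: reuse the type's own annotation for the scheme. *)
let annot_of (t : 'a typ) : 'a =
  match t with TyCon (_, a) | TyVar (_, a) | TyArr (_, _, a) | TyApp (_, _, a) -> a

(* Quantify exactly the variables that are free in t but not free in env. *)
let generalize (env : 'a typ envt) (t : 'a typ) : 'a scheme =
  let quantified = StringSet.diff (ftv_type t) (ftv_env env) in
  SForall (StringSet.elements quantified, t, annot_of t)
```

Variables that are still free in the environment stay unquantified, because code that has not been processed yet may still constrain them. For example, generalizing `('T5, 'T6 -> Bool)` in an empty environment quantifies both variables. If some variable in the environment has type `'T5`, only `'T6` is quantified.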

The challenges here mostly have to do with when to generalize which function bodies. To handle this, we will collect function definitions into groups of mutually-recursive definitions. To infer type schemes for a definition group, instantiate all the functions in the group at once, giving each a fresh, monomorphic arrow type built from new type variables. Infer types for each function body, and accumulate the substitutions that result. Finally, generalize all the remaining types all at once.

Exercise

Define two functions

```ocaml
let infer_decl funenv env (decl : sourcespan decl) (* ... bookkeeping ... *)
    : (sourcespan scheme envt * sourcespan typ) = ...

let infer_group funenv env (group : sourcespan decl list) (* ... bookkeeping ...*)
    : sourcespan scheme envt = ...
```

to accomplish this. In infer_decl, you should assume that the function you’re inferring has already been instantiated to a (potentially unknown) monomorphic type, while in infer_group, you should do the instantiation and generalization.

Consider the following two programs:

```
def f(x): # should have scheme Forall 'X, ('X -> 'X)
  print(x)

def ab_bool(a, b): # should have scheme Forall 'A, ('A, Bool -> Bool)
  isnum(f(a)) && f(b)

ab_bool(3, true) && ab_bool(true, false)
```

versus

```
def f(x):
  print(x)

and def ab_bool(a, b):
  isnum(f(a)) && f(b) # ???

ab_bool(3, true) && ab_bool(true, false)
```

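In this version of the program, since f and ab_bool are mutually recursive, f can only have one type within the body of ab_bool. The line with the question marks therefore forces a and b to have the same type, and because f(b) is an operand of &&, that type must be Bool. So ab_bool gets the type (Bool, Bool -> Bool), and the call ab_bool(3, true) fails to typecheck.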

On the other hand, consider a simple pair of mutually recursive functions:

```
def even(n):
  not(odd(n))

and def odd(n):
  if n == 0: false
  else:
    if n == 1: true
    else:
      even(n - 1)

odd(5)
```

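Because of the and keyword, even and odd form a single mutually recursive group, so their types are inferred together and generalized only at the end. odd uses n as an Int (in n - 1), and the call odd(n) passes that constraint on to even, so both functions get the type Int -> Bool. If even were not part of this group, and were inferred and generalized on its own, it would have type Forall 'X, 'X -> Bool, since its body never uses n in any particular way.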

### 3 Putting it all together

We’re nearly finished. Define a function

```ocaml
let infer_prog (funenv : sourcespan scheme envt) (env : sourcespan typ envt)
               (p : sourcespan program) : sourcespan program = ...
```

that either returns the program if it successfully infers a type, or raises an exception if it fails. Finally, define

```ocaml
let type_synth (p : sourcespan program) : sourcespan program fallible =
  try
    Ok(infer_prog initial_env StringMap.empty p)
  with e -> Error([e])
;;
```

And we’re done!

#### 3.1 Challenges and hints

At the end of every inference call, be sure to apply the resulting substitution to the resulting type, to ensure that the types are as fully-rewritten as possible. This helps ensure that we only need to apply each substitution once to get the correct results.

Also be sure never to reuse type variables. When in doubt, gensym an extra one and unify it wherever needed. But be sure never to discard the resulting substitutions!

Be careful with the occurs check. Right now it won’t impact us, because we don’t have higher-order functions or other type constructors, but once we do, this can be the source of some very subtle bugs.

Test thoroughly! You will need to write larger programs than before, to test how unification constraints are propagated throughout your inference process.

Avail yourself of debug-printing output, at many stages of the process. You will always want more information than you have, in order to debug why something went wrong during inference. Once you’ve implemented this inference process, you should have a much greater sympathy for compiler authors who provide inscrutable unification error messages! Try to record as many bits of debugging information as possible.

