Lecture 11: Type Inference
In the last lecture, we designed a system for
categorizing the expressions in our programs with types, and created an
algorithm for checking whether our programs had a self-consistent type.
However, that system ran aground in terms of usability, when it came to primitives like isbool or ==, which accepted arguments of “any type” and produced boolean answers, or worse yet, print, which took arguments of “any type” and produced answers of that type. In
order to handle this polymorphism, we had to resort to hacks: either explicitly
annotating the program everywhere there was ambiguity, or splitting the
primitives into separate, monomorphic versions that all behaved the same way.
Can we do better?
1 Type inference: guessing correctly, every time
The key observation we made last time was that our type checker was syntax-directed: at every expression of our program, we had enough local information to know how to recur on each subexpression, and we never had to “guess and check” our results. However, when we write down a polymorphic type and informally use it, we walk through a reasoning process that says “In this case, assume that this type variable means this particular type, then check that everything is self-consistent.” In other words, we mentally perform a kind of substitution of types for type variables, then continue with the same type-checking algorithm as before.
But this is in fact overkill. The self-consistency requirement alone is enough to type-check our program, if we exploit it correctly. The key is to reframe the question from “Does this program have this particular type?” to “Does there exist a type for this program such that the program is self-consistent?” We may not know what that type is, yet, but we can deduce certain facts about it. This process is called type-inference, and we will follow the classic Hindley-Milner algorithm to deduce types for our programs.
The trick to getting started is to say, “Sure! Every expression can have a type – let’s just make up a fresh type variable 'X and claim that as the type.” Of course, there may be constraints on 'X based on how we use the value in our program; these are called type constraints, and type inference is simply the process of collecting and solving these constraints. The easiest way to see this is by example.
In order to make our type inference tractable, we are going to split our notion of types and type schemes. Intuitively, a type is a single thing, while a type scheme can quantify over type variables. Our goal will be to infer a type scheme for each function in our program, and a type for each expression. (When we get to higher-order languages, this split is not quite so straightforward, and we’ll need to revisit it slightly.)
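To make the split concrete, here is one possible OCaml representation. The constructor names TyCon, TyVar, TyArr, TyApp and SForall are the ones used by the code later in these notes; the exact field layout, and the use of unit as the tag type, are assumptions made for this sketch.

```ocaml
(* A type is a single thing; the last component of each constructor is a tag
   (for example a source location). The field layout is an assumption. *)
type 'a typ =
  | TyCon of string * 'a                 (* a base type such as Int or Bool *)
  | TyVar of string * 'a                 (* a type variable such as 'X *)
  | TyArr of 'a typ list * 'a typ * 'a   (* argument types -> return type *)
  | TyApp of 'a typ * 'a typ list * 'a   (* a type constructor applied to arguments *)

(* A type scheme quantifies over a list of type variables. *)
type 'a scheme =
  | SForall of string list * 'a typ * 'a

(* Forall [], (Int, Int -> Int) *)
let plus_scheme : unit scheme =
  SForall ([], TyArr ([ TyCon ("Int", ()); TyCon ("Int", ()) ], TyCon ("Int", ()), ()), ())

(* Forall ['X], ('X -> 'X) *)
let print_scheme : unit scheme =
  SForall ([ "X" ], TyArr ([ TyVar ("X", ()) ], TyVar ("X", ()), ()), ())
```

The first scheme quantifies over nothing, so it behaves exactly like a plain type; the second is genuinely polymorphic.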
1.1 Example 1
Suppose we have the program
def f(x):
  x + 6

f(38)
Let’s start by inferring a type for f. We have no annotations, so we simply create new type variables as needed, and assert that f has type 'T1 -> 'T2, for two unknown type variables 'T1 and 'T2. We begin to infer a type for the body of f by binding the arguments to their types in our type environment – here, x : 'T1.
We recur into the body of f, and examine x + 6. This is an EPrim2 expression, so we look up the type scheme for Plus, and see that it is Forall [], (Int, Int -> Int). We look at the arguments to the operator, and infer their types. We look up x in the environment and obtain 'T1, and determine that all numbers have type Int. We know that this operation will produce some result type, so we make up a new type variable 'prim2result. We can now construct a type ('T1, Int -> 'prim2result), and unify it with the type (Int, Int -> Int). This produces a substitution ['T1 = Int, 'prim2result = Int]. We apply that substitution, and deduce that x : Int. The result of our expression has type 'prim2result, and since it is the entirety of the body of f, we unify that with the result type 'T2, making sure to preserve the substitution where ['prim2result = Int]. Our net result is f : Int -> Int, which we can generalize to the type scheme Forall [], Int -> Int, and update our function environment accordingly.
We next need to infer types for f(38). We repeat a similar process: we look up the type scheme for f, infer a type for 38, make up a new result type 'app_result, and unify Int -> Int (which is the type of f) with Int -> 'app_result. We deduce that ['app_result = Int], and this is the final type of our program.
1.2 Example 2
Suppose we have the following trickier program:
def f(x, y):
  isnum(print(x)) && isbool(y)

def g(z):
  f(z, 5)

g(7)
Since our functions are not mutually recursive, we can handle them independently.
Start by guessing type variables for f, namely ('T1, 'T2 -> 'T3).
Next, we bind the arguments in our environment, [x : 'T1, y : 'T2], and recur into the body.
Again we encounter an EPrim2, so we look up the type scheme for the operator and get && : Forall [], (Bool, Bool -> Bool). We then infer types for both arguments:
  Our first argument is an EPrim1, so we look up the type scheme for the operator and find isnum : Forall ['X], ('X -> Bool). We cannot use this directly, so we instantiate the type by creating new type variables and substituting them: isnum : ('T4 -> Bool). We recur on the argument:
    Our argument is another EPrim1, so we look up the type scheme for the operator and find print : Forall ['X], ('X -> 'X). Again we instantiate the type to get print : ('T5 -> 'T5). Then we recur on the argument, and look it up to find x : 'T1. We create a return type 'prim1result1, and unify 'T1 -> 'prim1result1 with 'T5 -> 'T5, to obtain the substitution ['T1 = 'T5, 'prim1result1 = 'T5]. Our result type is 'prim1result1.
  We create a new result type for this primitive, and unify 'prim1result1 -> 'prim1result2 with 'T4 -> Bool, to obtain the substitution ['prim1result1 = 'T4, 'prim1result2 = Bool]. We combine that with the existing substitution to get ['T1 = 'T5, 'prim1result1 = 'T5, 'T4 = 'T5, 'prim1result2 = Bool], and return a result of 'prim1result2.
  Our second argument is another EPrim1, and the type inference for this is analogous to that for isnum. We look up the type scheme, instantiate it to isbool : 'T6 -> Bool, and infer a type for the nested expression. We look that expression up to obtain y : 'T2, generate a new result type 'prim1result3, and unify 'T2 -> 'prim1result3 with 'T6 -> Bool, to obtain ['T2 = 'T6, 'prim1result3 = Bool].
We instantiate the type of the operator to (Bool, Bool -> Bool). We create a result type 'prim2result1, and unify the inferred type ('prim1result2, 'prim1result3 -> 'prim2result1) with (Bool, Bool -> Bool), to get ['prim1result2 = Bool, 'prim1result3 = Bool, 'prim2result1 = Bool]. We combine this with all our prior substitutions, and return the result type 'prim2result1.
We unify 'prim2result1 with 'T3, then apply our overall substitution, and we obtain ['T1 = 'T5, 'T2 = 'T6, 'T3 = Bool] (along with other substitutions). Finally, we generalize this type, to get the final type scheme f : Forall ['T5, 'T6], ('T5, 'T6 -> Bool).
We can repeat the process for g, guessing an initial type of 'T7 -> 'T8. When we get to the function call, we instantiate the type for f to f : ('T9, 'T10 -> Bool), and proceed. We wind up with the substitution ['T7 = 'T9, 'T10 = Int, 'T8 = Bool], and generalizing, we get a final type scheme of g : Forall ['T9], ('T9 -> Bool).
Finally we proceed to g(7). Again we instantiate the scheme for g, unify its argument type with Int, and obtain a final type of Bool. Our program is self-consistent.
2 Type inference implementation
2.1 Preparation
To implement type inference ourselves, we’ll need lots of preparatory definitions:
Exercise
Do these.
Define a type environment as a mapping from names to types, and a type scheme environment as a mapping from names to type schemes. Our initial type scheme environment will contain type schemes for all our primitives. Our initial type environment will be empty. We will use ML’s Maps, rather than association lists, because they are more efficient and our environments are unordered.
type 'a envt = 'a StringMap.t
Define the notion of substituting a type for a type variable, within another type. This converts each occurrence of the given type variable tyvar into the desired type to_typ in the target type in_typ.
let rec subst_var_typ ((tyvar : string), (to_typ : 'a typ)) (in_typ : 'a typ) : 'a typ = ...
Define the notion of substituting a type for a type variable, within a type scheme. This is the same idea as above, except that if the type variable is quantified by the scheme, we do not substitute. (This is exactly analogous to capture-avoiding substitution of expressions for variables, underneath lambda or let bindings.)
let rec subst_var_scheme ((tyvar : string), (to_typ : 'a typ)) (in_scheme : 'a scheme) : 'a scheme = ...
Define a type substitution as an ordered mapping from names to types. (We use lists here to ensure an ordering, if it is needed.)
type 'a subst = (string * 'a) list
Define a related set of functions that apply a substitution to a type, a type scheme, or environments, by applying each individual substitution from left-to-right. You may not need all of these functions, but they’re good to have available.
let apply_subst_typ (subst : 'a typ subst) (t : 'a typ) : 'a typ = ...
let apply_subst_scheme (subst : 'a typ subst) (s : 'a scheme) : 'a scheme = ...
let apply_subst_typenv (subst : 'a typ subst) (env : 'a typ envt) : 'a typ envt = ...
let apply_subst_schenv (subst : 'a typ subst) (env : 'a scheme envt) : 'a scheme envt = ...
let apply_subst_subst (subst : 'a typ subst) (sub : 'a typ subst) : 'a typ subst = ...
Define a function to combine two substitutions, by first applying one to the other, then concatenating the two.
let compose_subst (sub1 : 'a typ subst) (sub2 : 'a typ subst) : 'a typ subst = sub1 @ (apply_subst_subst sub1 sub2)
Define the free type variables of a type as the set of any type variables that appear within it. Define the free type variables of a type scheme as the type variables that appear un-quantified within the scheme. Finally, define the free type variables of a type environment as simply the union of all the type variables in all the types within it.
let rec ftv_type (t : 'a typ) : StringSet.t =
  match t with
  | TyCon _ -> StringSet.empty
  | TyVar(name, _) -> StringSet.singleton name
  | TyArr(args, ret, _) ->
    List.fold_right (fun t ftvs -> StringSet.union (ftv_type t) ftvs) args (ftv_type ret)
  | TyApp(typ, args, _) ->
    List.fold_right (fun t ftvs -> StringSet.union (ftv_type t) ftvs) args (ftv_type typ)
;;
let ftv_scheme (s : 'a scheme) : StringSet.t =
  match s with
  | SForall(args, typ, _) -> StringSet.diff (ftv_type typ) (StringSet.of_list args)
let ftv_env (e : 'a typ envt) : StringSet.t = ...
Do Now!
You are welcome to clean this up and make it more efficient, by passing along a StringSet.t as an accumulator parameter, of all the free type variables seen so far.
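As a rough illustration of the substitution helpers described above, the following sketch uses a simplified typ with no tags and no type applications (an assumption made for brevity; the definitions in these notes carry a tag and a TyApp case). It shows substituting for a single variable, applying a substitution from left to right, and composing two substitutions in the same shape as compose_subst.

```ocaml
(* Simplified, untagged types: an assumption made to keep the sketch short. *)
type typ =
  | TyCon of string
  | TyVar of string
  | TyArr of typ list * typ

type subst = (string * typ) list

(* Replace every occurrence of the variable tyvar with to_typ inside in_typ. *)
let rec subst_var_typ ((tyvar, to_typ) : string * typ) (in_typ : typ) : typ =
  match in_typ with
  | TyCon _ -> in_typ
  | TyVar name -> if name = tyvar then to_typ else in_typ
  | TyArr (args, ret) ->
    TyArr (List.map (subst_var_typ (tyvar, to_typ)) args,
           subst_var_typ (tyvar, to_typ) ret)

(* Apply each binding in order, from left to right. *)
let apply_subst_typ (s : subst) (t : typ) : typ =
  List.fold_left (fun acc binding -> subst_var_typ binding acc) t s

(* Rewrite the right-hand side of every binding in sub using s. *)
let apply_subst_subst (s : subst) (sub : subst) : subst =
  List.map (fun (name, t) -> (name, apply_subst_typ s t)) sub

(* Same shape as compose_subst in the notes. *)
let compose_subst (sub1 : subst) (sub2 : subst) : subst =
  sub1 @ apply_subst_subst sub1 sub2
```

For instance, applying ['T1 = 'T5, 'prim1result1 = 'T5] from Example 2 to 'T1 -> 'prim1result1 yields 'T5 -> 'T5.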
2.2 Unification
The central operation of type inference is unification, which requests that we make two types become equal by computing a substitution for their free type variables under which the types really become identical:
let rec unify (t1 : 'a typ) (t2 : 'a typ) (* other bookkeeping *) : 'a typ subst = ...
For example, we can unify 'X -> Int and Bool -> 'Y under the substitution ['X = Bool, 'Y = Int]. If we apply this substitution to both types, they both become Bool -> Int and are identical.
However, there is no substitution that can unify Int -> 'X with Bool -> 'Y. Even if we set ['X = 'Y], we cannot unify Bool with Int. We get a unification error, which we report as a type error.
Our goal is to produce as many substitutions as necessary: each substitution asserts that a particular type variable must be equal to a particular type. This is the key behind Hindley-Milner type inference. We can always guess a new type variable, and unification will steadily constrain that guess with more and more equality constraints until we have a final answer, or a contradiction.
Note that we must be careful not to unify too recklessly. We cannot unify 'A with 'A -> 'B, because we would get the absurd substitution ['A = 'A -> 'B]. It is immediately clear that these two types cannot be equal: the left side has “one fewer arrows” than the right side. To rule this out, we must perform an occurs check, to see whether the variable we’re trying to constrain appears within its constraint:
let occurs (name : string) (t : 'a typ) =
  StringSet.mem name (ftv_type t)
Our unification proceeds by cases:
If either t1 or t2 is a type variable, and it does not occur in the other type, then produce the substitution that binds it to the other type.
If we have two type constants (like Int or Bool) that are equal, produce the empty substitution. Otherwise, unification fails.
If we have two arrow types of the same arity, then unify the corresponding arguments, and unify the return types. The overall substitution that unifies the two arrows is just the composition of all the resulting substitutions together.
If we have two type applications of the same arity, then unify the types and the corresponding type arguments. The overall substitution that unifies the two applications is just the composition of all the resulting substitutions together.
All other cases result in a unification failure.
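The cases above can be sketched as follows. This is only an illustration under stated assumptions: it uses simplified, untagged types with no TyApp case, the exception Unify_error is a hypothetical name, and it adds one case the list does not spell out, namely that a type variable unifies with itself under the empty substitution. Each later pair is unified only after applying the substitution found so far.

```ocaml
module StringSet = Set.Make (String)

(* Simplified types: no tags and no TyApp, an assumption made for brevity. *)
type typ =
  | TyCon of string
  | TyVar of string
  | TyArr of typ list * typ

type subst = (string * typ) list

(* Hypothetical exception; the notes leave error reporting as an exercise. *)
exception Unify_error of typ * typ

let rec ftv_type (t : typ) : StringSet.t =
  match t with
  | TyCon _ -> StringSet.empty
  | TyVar name -> StringSet.singleton name
  | TyArr (args, ret) ->
    List.fold_right (fun t ftvs -> StringSet.union (ftv_type t) ftvs) args (ftv_type ret)

let occurs (name : string) (t : typ) : bool = StringSet.mem name (ftv_type t)

let rec subst_var_typ ((tyvar, to_typ) : string * typ) (in_typ : typ) : typ =
  match in_typ with
  | TyCon _ -> in_typ
  | TyVar name -> if name = tyvar then to_typ else in_typ
  | TyArr (args, ret) ->
    TyArr (List.map (subst_var_typ (tyvar, to_typ)) args, subst_var_typ (tyvar, to_typ) ret)

(* Apply each binding in order, from left to right. *)
let apply_subst_typ (s : subst) (t : typ) : typ =
  List.fold_left (fun acc binding -> subst_var_typ binding acc) t s

let rec unify (t1 : typ) (t2 : typ) : subst =
  match (t1, t2) with
  (* Added case: a variable trivially unifies with itself. *)
  | TyVar a, TyVar b when a = b -> []
  (* A variable on either side: bind it, unless the occurs check fails. *)
  | TyVar a, t | t, TyVar a ->
    if occurs a t then raise (Unify_error (t1, t2)) else [ (a, t) ]
  (* Equal type constants need no substitution. *)
  | TyCon a, TyCon b when a = b -> []
  (* Arrows of the same arity: unify arguments pairwise, then the return types. *)
  | TyArr (args1, ret1), TyArr (args2, ret2) when List.length args1 = List.length args2 ->
    List.fold_left2
      (fun sub a1 a2 -> sub @ unify (apply_subst_typ sub a1) (apply_subst_typ sub a2))
      [] (args1 @ [ ret1 ]) (args2 @ [ ret2 ])
  | _ -> raise (Unify_error (t1, t2))
```

Because later pairs are rewritten with the substitution found so far, plain concatenation is enough to combine the pieces in this sketch; the compose_subst function from the preparation section could be used in its place.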
Exercise
In the implementation of unify, it is quite helpful to pass around diagnostic information, such as a sourcespan of the expression that triggered the unification request (and possibly more information) so that your error message can direct the programmer to where the problem occurred. Consider what information might be useful, and extend the signature and implementation of unify accordingly.
2.3 Inference and Instantiation
Our goal in inferring types for our program is to compute a type for all subexpressions, for which the program is entirely self-consistent. To do this, our signature will benefit from returning not only the computed type, but also the substitution that we’ve accumulated so far. (If we had monadic notation, as in Haskell, this syntactic overhead could be avoided.) Our primary function for type-inference of expressions will be
let rec infer_exp (funenv : sourcespan scheme envt) (* environment for functions *)
(env : sourcespan typ envt) (* environment for variables *)
(e : sourcespan expr)
(* ... some bookkeeping ... *)
: ( sourcespan typ subst (* the resulting substitution *)
* sourcespan typ) (* the resulting type *) =
match e with
...
Suppose we construct two base types,
let tInt = TyCon("Int", dummy_sourcespan)
let tBool = TyCon("Bool", dummy_sourcespan)
Then the first two cases of infer_exp are easy:
let rec infer_exp funenv env e ... =
  match e with
  | ENumber _ -> ([], tInt)
  | EBool _ -> ([], tBool)
  ...
Variables can be handled by looking them up in env. But how should we handle things like EPrim1(Print, e, _)? After all, polymorphic operators like these were why we built this whole new infrastructure!
The key observation is that print has the type scheme Forall 'X, ('X -> 'X). Therefore, it has the type foo -> foo, for any type foo we want. So our key step is simply to create a fresh new type variable, one that’s never been used before in our program, and substitute it into the type scheme.
let gensym =
  let count = ref 0 in
  let next () =
    count := !count + 1;
    !count
  in fun str -> sprintf "%s_%d" str (next ());;
let instantiate (s : 'a scheme) : 'a typ = ...
Define a function instantiate that accomplishes this specialization step. It will use the gensym function to create any necessary new type variables. (Yes, we’re using a stateful gensym instead of pure tags. Since we don’t know a priori how many times we’ll instantiate a type scheme, it’s hard to know in advance what tags to even use.)
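One way the specialization step might look is sketched below, again over simplified, untagged types (an assumption). The helper rename is a hypothetical name: it replaces all quantified variables at once, so a freshly generated name can never be confused with another quantified variable.

```ocaml
(* Simplified, untagged types and schemes: assumptions for this sketch. *)
type typ = TyCon of string | TyVar of string | TyArr of typ list * typ
type scheme = SForall of string list * typ

(* Same idea as the gensym in the notes, written with Printf.sprintf. *)
let gensym =
  let count = ref 0 in
  let next () =
    count := !count + 1;
    !count
  in
  fun str -> Printf.sprintf "%s_%d" str (next ())

(* Hypothetical helper: replace every variable found in mapping, simultaneously. *)
let rec rename (mapping : (string * typ) list) (t : typ) : typ =
  match t with
  | TyCon _ -> t
  | TyVar name ->
    (match List.assoc_opt name mapping with
     | Some t' -> t'
     | None -> t)
  | TyArr (args, ret) -> TyArr (List.map (rename mapping) args, rename mapping ret)

(* Pick a fresh variable for each quantified variable, then rewrite the body. *)
let instantiate (s : scheme) : typ =
  match s with
  | SForall (vars, body) ->
    let mapping = List.map (fun v -> (v, TyVar (gensym v))) vars in
    rename mapping body
```

Instantiating the scheme for print twice gives two arrow types whose variables differ from each other, which is what lets print be used at two different types in one program.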
We can now use this function to infer types for prim1 and prim2 constructs. Populate the initial funenv with type schemes for all the primitives by giving them unique names. Each primitive should have some arrow type, to record its argument type(s) and return type. Then, in the EPrim1 case:
Look up the relevant type scheme.
Instantiate it to a type.
Infer a type for the argument(s) of the primitive.
Make up a brand-new type variable for the return type of the operation.
Construct a new arrow type using the inferred types of the argument(s) and the made-up return type variable.
Recursively unify the looked-up arrow type of the operator, with the constructed arrow type.
If all goes well, return the newly-constructed return type variable, and the substitution obtained from the recursive unification call.
A similar process applies for function-application expressions.
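The EPrim1 steps can be put together in a small, self-contained sketch. Everything here is simplified and should be read as an assumption: types and expressions carry no tags, primitives are named by strings, environments are association lists instead of StringMaps, errors are reported with failwith, and initial_funenv is a hypothetical three-entry environment. The sketch also assumes that names produced by gensym never clash with names already appearing in a scheme.

```ocaml
type typ = TyCon of string | TyVar of string | TyArr of typ list * typ
type scheme = SForall of string list * typ
type subst = (string * typ) list
type expr = ENumber of int | EBool of bool | EId of string | EPrim1 of string * expr

let tInt = TyCon "Int"
let tBool = TyCon "Bool"

let gensym =
  let count = ref 0 in
  fun str -> incr count; Printf.sprintf "%s_%d" str !count

let rec subst_var ((v, to_t) : string * typ) (t : typ) : typ =
  match t with
  | TyCon _ -> t
  | TyVar n -> if n = v then to_t else t
  | TyArr (args, ret) -> TyArr (List.map (subst_var (v, to_t)) args, subst_var (v, to_t) ret)

(* Apply each binding in order, from left to right. *)
let apply_subst (s : subst) (t : typ) : typ =
  List.fold_left (fun acc b -> subst_var b acc) t s

let rec occurs (n : string) (t : typ) : bool =
  match t with
  | TyCon _ -> false
  | TyVar m -> m = n
  | TyArr (args, ret) -> List.exists (occurs n) args || occurs n ret

let rec unify (t1 : typ) (t2 : typ) : subst =
  match (t1, t2) with
  | TyVar a, TyVar b when a = b -> []
  | TyVar a, t | t, TyVar a -> if occurs a t then failwith "occurs check" else [ (a, t) ]
  | TyCon a, TyCon b when a = b -> []
  | TyArr (args1, r1), TyArr (args2, r2) when List.length args1 = List.length args2 ->
    List.fold_left2
      (fun sub x y -> sub @ unify (apply_subst sub x) (apply_subst sub y))
      [] (args1 @ [ r1 ]) (args2 @ [ r2 ])
  | _ -> failwith "type mismatch"

let instantiate (SForall (vars, body) : scheme) : typ =
  apply_subst (List.map (fun v -> (v, TyVar (gensym v))) vars) body

(* Hypothetical initial environment of primitive schemes. *)
let initial_funenv : (string * scheme) list =
  [ ("print", SForall ([ "X" ], TyArr ([ TyVar "X" ], TyVar "X")));
    ("isnum", SForall ([ "X" ], TyArr ([ TyVar "X" ], tBool)));
    ("add1", SForall ([], TyArr ([ tInt ], tInt))) ]

let rec infer_exp funenv env (e : expr) : subst * typ =
  match e with
  | ENumber _ -> ([], tInt)
  | EBool _ -> ([], tBool)
  | EId x -> ([], List.assoc x env)
  | EPrim1 (op, arg) ->
    (* Look up the scheme and instantiate it to a type. *)
    let op_typ = instantiate (List.assoc op funenv) in
    (* Infer a type for the argument. *)
    let arg_subst, arg_typ = infer_exp funenv env arg in
    (* Make up a return type, build an arrow, and unify it with the operator's type. *)
    let ret = TyVar (gensym "prim1result") in
    let unif = unify op_typ (TyArr ([ arg_typ ], ret)) in
    let final = arg_subst @ unif in
    (* Return the substitution, and the return type rewritten by it. *)
    (final, apply_subst final ret)
```

With x bound to the type variable 'T1, inferring isnum(print(x)) in this sketch produces Bool and leaves the type of x unconstrained, matching the first argument in Example 2.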
For conditionals, recursively infer types for the condition and the two branches. Then unify the condition’s type with tBool, and unify the two branches’ types with each other. The resulting type is the type of a branch, and the resulting substitution is the composition of five smaller substitutions: the three subexpressions, plus the two unification calls:
let rec infer_exp funenv env e ... =
  match e with
  ...
  | EIf(c, t, f, _) ->
    (* After each inference, update the type environment *)
    let (c_subst, c_typ) = infer_exp funenv env c ... in
    let env = apply_subst_typenv c_subst env in
    let (t_subst, t_typ) = infer_exp funenv env t ... in
    let env = apply_subst_typenv t_subst env in
    let (f_subst, f_typ) = infer_exp funenv env f ... in
    let env = apply_subst_typenv f_subst env in
    (* Compose the substitutions together *)
    let subst_so_far = compose_subst (compose_subst c_subst t_subst) f_subst in
    (* rewrite the types *)
    let c_typ = apply_subst_typ subst_so_far c_typ in
    let t_typ = apply_subst_typ subst_so_far t_typ in
    let f_typ = apply_subst_typ subst_so_far f_typ in
    (* unify condition with Bool *)
    let unif_subst1 = unify c_typ tBool ... in
    (* unify two branches *)
    let unif_subst2 = unify t_typ f_typ ... in
    (* compose all substitutions *)
    let final_subst = compose_subst (compose_subst subst_so_far unif_subst1) unif_subst2 in
    let final_typ = apply_subst_typ final_subst t_typ in
    (final_subst, final_typ)
  | ...
Handling let-bindings requires careful management of the environment. Infer a type for the first binding. Add the variable and that type to the environment, and recursively process the remaining bindings. When no bindings remain, process the body. The resulting type is the type of the body, and the resulting substitution is the composition of all the intermediate substitutions.
2.4 Inference and Generalization
It stands to reason that if using a function (in a function call) instantiates its type scheme, then defining a function must somehow generalize the type of the body into a type scheme. This process is straightforward: collect all the free type variables in the type of the function body, subtract away any type variables that appear free in the type environment, and whatever is left over can be generalized into a type scheme.
Exercise
Define a function
let generalize (env : 'a typ envt) (t : 'a typ) : 'a scheme = ...
that accomplishes this task.
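A minimal sketch of this step, under the same simplifying assumptions as before (untagged types, no TyApp), might look like the following. StringSet and StringMap are built here from the standard Set.Make and Map.Make functors.

```ocaml
module StringSet = Set.Make (String)
module StringMap = Map.Make (String)

(* Simplified, untagged types and schemes: assumptions for this sketch. *)
type typ = TyCon of string | TyVar of string | TyArr of typ list * typ
type scheme = SForall of string list * typ

let rec ftv_type (t : typ) : StringSet.t =
  match t with
  | TyCon _ -> StringSet.empty
  | TyVar name -> StringSet.singleton name
  | TyArr (args, ret) ->
    List.fold_right (fun t ftvs -> StringSet.union (ftv_type t) ftvs) args (ftv_type ret)

(* Union of the free type variables of every type in the environment. *)
let ftv_env (env : typ StringMap.t) : StringSet.t =
  StringMap.fold (fun _ t ftvs -> StringSet.union (ftv_type t) ftvs) env StringSet.empty

(* Quantify over the variables free in t but not free in env. *)
let generalize (env : typ StringMap.t) (t : typ) : scheme =
  let free = StringSet.diff (ftv_type t) (ftv_env env) in
  SForall (StringSet.elements free, t)
```

Generalizing ('T5, 'T6 -> Bool) in an empty environment quantifies over both variables, as in Example 2; if 'T6 were still mentioned by the environment, only 'T5 would be quantified.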
The challenges here mostly have to do with when to generalize which function bodies. To handle this, we will collect function definitions into groups of mutually-recursive definitions. To infer type schemes for a definition group, instantiate the type schemes for the functions all at once. Infer types for each function body, and accumulate the substitutions that result. Finally, generalize all the remaining types all at once.
Exercise
Define two functions
let infer_decl funenv env (decl : sourcespan decl) (* ... bookkeeping ... *)
  : (sourcespan scheme envt * sourcespan typ) = ...
let infer_group funenv env (group : sourcespan decl list) (* ... bookkeeping ... *)
  : sourcespan scheme envt = ...
to accomplish this. In infer_decl, you should assume that the function you’re inferring has already been instantiated to a (potentially unknown) monomorphic type, while in infer_group, you should do the instantiation and generalization.
Consider the following two programs:
def f(x): # should have scheme Forall 'X, ('X -> 'X)
  print(x)

def ab_bool(a, b): # should have scheme Forall 'A, ('A, Bool -> Bool), since f(b) is an argument to &&
  isnum(f(a)) && f(b)

ab_bool(3, true) && ab_bool(true, false)
versus
def f(x):
  print(x)

and def ab_bool(a, b):
  isnum(f(a)) && f(b) # ???

ab_bool(3, true) && ab_bool(true, false)
In this version of the program, since f and ab_bool are mutually recursive, f can only have one type within the body of ab_bool, and so the line with the question marks cannot typecheck.
On the other hand, consider a simple pair of mutually recursive functions:
def even(n):
  not(odd(n))

and def odd(n):
  if n == 0: false
  else: if n == 1: true
  else:
    even(n - 1)

odd(5)
If these two functions were not mutually recursive (marked by the and keyword), then even would have type Forall 'X, 'X -> Bool, since it never uses n in any particular way.
3 Putting it all together
We’re nearly finished. Define a function
let infer_prog (funenv : sourcespan scheme envt) (env : sourcespan typ envt)
(p : sourcespan program) : sourcespan program = ...
that either returns the program if it successfully infers a type, or raises an exception if it fails. Finally, define
let type_synth (p : sourcespan program) : sourcespan program fallible =
  try
    Ok(infer_prog initial_env StringMap.empty p)
  with e -> Error([e])
;;
And we’re done!
3.1 Challenges and hints
At the end of every inference call, be sure to apply the resulting substitution to the resulting type, to ensure that the types are as fully-rewritten as possible. This helps ensure that we only need to apply each substitution once to get the correct results.
Also be sure never to reuse type variables. When in doubt, gensym an extra one and unify it wherever needed. But be sure never to discard the resulting substitutions!
Be careful with the occurs check. Right now it won’t impact us, because we don’t have higher-order functions or other type constructors, but once we do, this can be the source of some very subtle bugs.
Test thoroughly! You will need to write larger programs than before, to test how unification constraints are propagated throughout your inference process.
Avail yourself of debug-printing output, at many stages of the process. You will always want more information than you have, in order to debug why something went wrong during inference. Once you’ve implemented this inference process, you should have a much greater sympathy for compiler authors who provide inscrutable unification error messages! Try to record as many bits of debugging information as possible.