Type Inference for Simply-Typed Lambda Calculus

Type inference automatically computes the type of an expression that carries no explicit type annotations, or only partial ones. It is also called type reconstruction. Given an expression in a language, type inference determines its principal (most general) type if one exists; otherwise it rejects the expression, just as a type checker would.

The simply typed lambda calculus (STLC) is one of the simplest type systems: its only type constructor is the function type ($\rightarrow$). Because it is so simple and has no support for general recursion, type inference for STLC is decidable even without any annotations. This post presents a functional implementation of a type inference algorithm for STLC in Racket. The idea is based on Chapter 16 of Programming Languages and Lambda Calculi.

For example, for the term λx. (+ x 1) the type inferencer reports the type int -> int; for the term λx.λy.x y it gives (a -> b) -> a -> b, leaving a and b uninstantiated.

We first define a small language based on STLC, extended with numbers and arithmetic operations (addition and multiplication), together with a parser that translates S-expressions into the structures we define.

#lang racket

;; Type Inference for Simply Typed Lambda Calculus
;; Guannan Wei <guannanwei@outlook.com>

(require rackunit)
(require racket/set)

;; Expressions

(struct NumE (n) #:transparent)
(struct IdE (id) #:transparent)
(struct PlusE (l r) #:transparent)
(struct MultE (l r) #:transparent)
(struct LamE (arg body) #:transparent)
(struct AppE (fun arg) #:transparent)

;; Types

(struct NumT () #:transparent)
(struct VarT (name) #:transparent)
(struct ArrowT (arg result) #:transparent)

;; Values

(struct NumV (n) #:transparent)
(struct ClosureV (arg body env) #:transparent)

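;; Environment & Type Environment

;; make-lookup is not shown in the original post; this is a minimal
;; reconstruction. It builds a lookup function over a list of bindings
;; and raises an error tagged with `who` when the name is unbound.
(define (make-lookup who binding? binding-name binding-val)
  (λ (name env)
    (define found
      (findf (λ (b) (and (binding? b) (equal? (binding-name b) name))) env))
    (if found
        (binding-val found)
        (error who "unbound identifier: ~a" name))))

(struct Binding (name val) #:transparent)
(define lookup (make-lookup 'lookup Binding? Binding-name Binding-val))
(define ext-env cons)

(struct TypeBinding (name type) #:transparent)
(define type-lookup (make-lookup 'type-lookup TypeBinding? TypeBinding-name TypeBinding-type))
(define ext-tenv cons)
(define mt-tenv empty) ; the empty type environment used by the tests

;; Fresh type-variable names (also assumed, not shown in the original post):
;; (counter) returns a generator that yields 1, 2, 3, ... on successive calls.
(define (counter)
  (define n 0)
  (λ () (set! n (add1 n)) n))
(define fresh-n (counter))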

;; Parsers

(define (parse s)
  (match s
    [(? number? x) (NumE x)]
    [(? symbol? x) (IdE x)]
    [`(+ ,l ,r) (PlusE (parse l) (parse r))]
    [`(* ,l ,r) (MultE (parse l) (parse r))]
    [`(let ([,var ,val]) ,body)
     (AppE (LamE var (parse body)) (parse val))]
    [`(λ (,var) ,body) (LamE var (parse body))]
    [`(,fun ,arg) (AppE (parse fun) (parse arg))]
    [_ (error 'parse "invalid expression: ~a" s)]))

The type inference works as follows. We first add a special kind of type, the type variable VarT. The function type-infer determines the type of an expression and collects the constraints on it, returning two values: the type and the set of constraints. We then solve the constraints with unify and unify/helper. If unification succeeds, the expression has a type; if it fails, the expression is not typable. Because the type returned by type-infer may be, or may contain, a type variable that unification has instantiated to a more concrete type, we need to reify this information to produce the actual type. This is done by the function reify.
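As a compact summary, the constraint-generation step can be written as rules of the form $\Gamma \vdash e : \tau \mid C$, read as: under type environment $\Gamma$, expression $e$ has type $\tau$ provided the constraints $C$ hold. This notation is an assumption of the summary and does not appear in the code; the rules themselves are read off from the cases of type-infer, with $\alpha$ and $\beta$ standing for fresh type variables.

$$
\Gamma \vdash n : int \mid \emptyset
\qquad\qquad
\Gamma \vdash x : \Gamma(x) \mid \emptyset
$$

$$
\frac{\Gamma \vdash e_1 : \tau_1 \mid C_1 \qquad \Gamma \vdash e_2 : \tau_2 \mid C_2}
     {\Gamma \vdash e_1 + e_2 : int \mid C_1 \cup C_2 \cup \{\tau_1 = int,\ \tau_2 = int\}}
$$

$$
\frac{\Gamma, x : \alpha \vdash e : \tau \mid C \qquad \alpha\ \text{fresh}}
     {\Gamma \vdash \lambda x.\, e : \alpha \rightarrow \tau \mid C}
$$

$$
\frac{\Gamma \vdash e_1 : \tau_1 \mid C_1 \qquad \Gamma \vdash e_2 : \tau_2 \mid C_2 \qquad \beta\ \text{fresh}}
     {\Gamma \vdash e_1\ e_2 : \beta \mid C_1 \cup C_2 \cup \{\tau_1 = \tau_2 \rightarrow \beta\}}
$$

Multiplication follows the same rule as addition. Unification then looks for a substitution of type variables that makes both sides of every equation in $C$ identical.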

;; Type Inference
(struct Eq (fst snd) #:transparent)

(define (type-subst in src dst)
  (match in
    [(NumT) in]
    [(VarT x) (if (equal? src in) dst in)]
    [(ArrowT t1 t2) (ArrowT (type-subst t1 src dst)
                            (type-subst t2 src dst))]))

(define (unify/subst eqs src dst)
  (cond [(empty? eqs) eqs]
        [else (define eq (first eqs))
              (define eqfst (Eq-fst eq))
              (define eqsnd (Eq-snd eq))
              (cons (Eq (type-subst eqfst src dst)
                        (type-subst eqsnd src dst))
                    (unify/subst (rest eqs) src dst))]))

(define (occurs? t in)
  (match in
    [(NumT) #f]
    [(ArrowT at rt) (or (occurs? t at) (occurs? t rt))]
    [(VarT x) (equal? t in)]))

(define not-occurs? (compose not occurs?))

(define (unify-error t1 t2)
  (error 'type-error "can not unify: ~a and ~a" t1 t2))

(define (unify/helper substs result)
  (match substs
    ['() result]
    [(list (Eq fst snd) rest ...)
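     (match* (fst snd)
       ;; Equal types (including a variable equated with itself) add no
       ;; information. This case must come first; otherwise a = a would
       ;; reach the variable case and fail the occurs check.
       [(_ _) #:when (equal? fst snd) (unify/helper rest result)]
       [((VarT x) t)
        (if (not-occurs? fst snd)
            (unify/helper (unify/subst rest fst snd) (cons (Eq fst snd) result))
            (unify-error fst snd))]
       [(t (VarT x))
        (if (not-occurs? snd fst)
            (unify/helper (unify/subst rest snd fst) (cons (Eq snd fst) result))
            (unify-error snd fst))]
       ;; Two arrows unify when their arguments and their results unify.
       [((ArrowT t1 t2) (ArrowT t3 t4))
        (unify/helper `(,(Eq t1 t3) ,(Eq t2 t4) ,@rest) result)]
       [(_ _) (unify-error fst snd)])]))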

(define (unify substs) (unify/helper (set->list substs) (list)))

(define (type-infer exp tenv const)
  (match exp
    [(NumE n) (values (NumT) const)]
    [(PlusE l r)
     (define-values (lty lconst) (type-infer l tenv (set)))
     (define-values (rty rconst) (type-infer r tenv (set)))
     (values (NumT)
             (set-add (set-add (set-union lconst rconst) (Eq lty (NumT))) (Eq rty (NumT))))]
    [(MultE l r)
     (define-values (lty lconst) (type-infer l tenv (set)))
     (define-values (rty rconst) (type-infer r tenv (set)))
     (values (NumT)
             (set-add (set-add (set-union lconst rconst) (Eq lty (NumT))) (Eq rty (NumT))))]
    [(IdE x)
     (values (type-lookup x tenv) const)]
    [(LamE arg body)
     (define new-tvar (VarT (fresh-n)))
     (define-values (bty bconst)
       (type-infer body (ext-tenv (TypeBinding arg new-tvar) tenv) const))
     (values (ArrowT new-tvar bty) bconst)]
    [(AppE fun arg)
     (define-values (funty funconst) (type-infer fun tenv (set)))
     (define-values (argty argconst) (type-infer arg tenv (set)))
     (define new-tvar (VarT (fresh-n)))
     (values new-tvar (set-add (set-union funconst argconst) (Eq funty (ArrowT argty new-tvar))))]))

(define (reify substs ty)
  (define (lookup/default x sts)
    (match sts
      ['() x]
      [(list (Eq fst snd) rest ...)
       (if (equal? fst x)
           (lookup/default snd substs)
           (lookup/default x rest))]))

  (match ty
    [(NumT) (NumT)]
    [(VarT x)
     (define ans (lookup/default ty substs))
     (if (ArrowT? ans) (reify substs ans) ans)]
    [(ArrowT t1 t2)
     (ArrowT (reify substs t1) (reify substs t2))]))

(define (typecheck exp tenv)
  (set! fresh-n (counter))
  (define-values (ty constraints) (type-infer exp tenv (set)))
  (reify (unify constraints) ty))

The type inference algorithm works with concrete types. For example, for the program $\lambda x. \lambda y. x + y$, it infers the type $int \rightarrow int \rightarrow int$.

(check-equal? (typecheck (parse '{λ {x} {λ {y} {+ x y}}}) mt-tenv)
              (ArrowT (NumT) (ArrowT (NumT) (NumT))))

It also works on polymorphic functions. The following code shows type inference on $\lambda f. \lambda u. u (f u)$, which has type $((a \rightarrow b) \rightarrow a) \rightarrow (a \rightarrow b) \rightarrow b$.

(check-equal? (typecheck (parse '{λ {f} {λ {u} {u {f u}}}}) mt-tenv)
              (ArrowT (ArrowT (ArrowT (VarT 3) (VarT 4)) (VarT 3))
                      (ArrowT (ArrowT (VarT 3) (VarT 4)) (VarT 4))))
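To see where VarT 3 and VarT 4 come from, the constraints can be worked out by hand. This is a sketch that assumes fresh type variables are numbered 1, 2, 3, ... in the order type-infer creates them; they are written $t_1, t_2, \dots$ here.

$$
\begin{aligned}
f &: t_1, \qquad u : t_2 \\
f\ u &: t_3 \quad \text{with} \quad t_1 = t_2 \rightarrow t_3 \\
u\ (f\ u) &: t_4 \quad \text{with} \quad t_2 = t_3 \rightarrow t_4
\end{aligned}
$$

Before solving, the whole term has type $t_1 \rightarrow t_2 \rightarrow t_4$. Unification gives $t_2 := t_3 \rightarrow t_4$ and then $t_1 := (t_3 \rightarrow t_4) \rightarrow t_3$, so reifying yields

$$
((t_3 \rightarrow t_4) \rightarrow t_3) \rightarrow (t_3 \rightarrow t_4) \rightarrow t_4,
$$

which is the type above with $a = t_3$ and $b = t_4$.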

We can also show that the Omega combinator is not typable in STLC.

;; Unification fails, so typecheck raises an error.
(check-exn exn:fail?
           (λ () (typecheck (parse '{{λ {x} {x x}} {λ {x} {x x}}}) mt-tenv)))
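
The rejection comes from the occurs check. For the self-application x x with x assigned a type variable t1, the application case produces the constraint t1 = t1 -> t2, which no finite type can satisfy. The following is a minimal, self-contained sketch of that check. It uses a hypothetical S-expression encoding of types (a symbol for a type variable or for int, and a list (-> arg result) for an arrow) rather than the structs defined above, and the name occurs-in? is likewise an assumption made for illustration.

```racket
#lang racket
(require rackunit)

;; Hypothetical encoding for this sketch only: a type is 'int, a symbol
;; naming a type variable (e.g. 't1), or a list (-> arg result).
;; occurs-in? reports whether the variable `var` appears anywhere in `ty`.
(define (occurs-in? var ty)
  (match ty
    [(list '-> arg res) (or (occurs-in? var arg) (occurs-in? var res))]
    [_ (equal? var ty)]))

;; Self-application x x with x : t1 yields the constraint t1 = (-> t1 t2);
;; t1 occurs in its own solution, so unification must reject it.
(check-true (occurs-in? 't1 '(-> t1 t2)))

;; An ordinary constraint such as t1 = (-> t2 t3) passes the check.
(check-false (occurs-in? 't1 '(-> t2 t3)))
```

Without this check, substituting t1 := t1 -> t2 would describe an infinite type, and repeatedly expanding it would never terminate.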