Type Inference for STLC

Type inference automatically computes the type of an expression that carries no type annotations, or only partial ones. It is also called type reconstruction. Given an expression in a language, type inference determines its principal type if it has one; otherwise it rejects the expression, just as a type checker would.

The Simply Typed Lambda Calculus (STLC) is one of the simplest type systems, with a single type constructor: the function type ($\rightarrow$). Because the system is so simple and has no support for general recursion, type inference for STLC is decidable even without any annotations. This post presents a functional implementation of a type inference algorithm for STLC in Racket. The idea is based on Chapter 16 of Programming Languages and Lambda Calculi.

We first define a small language based on STLC, extended with numbers and arithmetic operations such as addition and multiplication, together with a parser that translates S-expressions into the structures we define.

#lang racket

;; Type Inference for Simply Typed Lambda Calculus
;; Guannan Wei <guannanwei@outlook.com>

(require rackunit)
(require racket/set)

;; Expressions

(struct NumE (n) #:transparent)
(struct IdE (id) #:transparent)
(struct PlusE (l r) #:transparent)
(struct MultE (l r) #:transparent)
(struct LamE (arg body) #:transparent)
(struct AppE (fun arg) #:transparent)

;; Types

(struct NumT () #:transparent)
(struct VarT (name) #:transparent)
(struct ArrowT (arg result) #:transparent)

;; Values

(struct NumV (n) #:transparent)
(struct ClosureV (arg body env) #:transparent)

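;; Environment & Type Environment

;; make-lookup is not defined in the original listing; this is an assumed
;; minimal version. It builds a lookup function over a list of bindings
;; and raises an error tagged with `who` when the name is unbound.
(define ((make-lookup who binding? binding-name binding-val) name env)
  (define b (findf (λ (b) (and (binding? b) (equal? name (binding-name b)))) env))
  (if b (binding-val b) (error who "unbound identifier: ~a" name)))

(struct Binding (name val) #:transparent)
(define lookup (make-lookup 'lookup Binding? Binding-name Binding-val))
(define ext-env cons)
(define mt-env empty)

(struct TypeBinding (name type) #:transparent)
(define type-lookup (make-lookup 'type-lookup TypeBinding? TypeBinding-name TypeBinding-type))
(define ext-tenv cons)
;; Empty type environment, used by the tests at the end of the post.
(define mt-tenv empty)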

;; Parsers

(define (parse s)
  (match s
    [(? number? x) (NumE x)]
    [(? symbol? x) (IdE x)]
    [`(+ ,l ,r) (PlusE (parse l) (parse r))]
    [`(* ,l ,r) (MultE (parse l) (parse r))]
    [`(let ([,var ,val]) ,body)
     (AppE (LamE var (parse body)) (parse val))]
    [`(λ (,var) ,body) (LamE var (parse body))]
    [`(,fun ,arg) (AppE (parse fun) (parse arg))]
    [_ (error 'parse "invalid expression: ~a" s)]))

Type inference works as follows. We first add a special kind of type, the type variable VarT. The function type-infer computes the type of an expression and collects the constraints on it along the way. It returns three values: the type, the set of constraints, and a counter recording how many type variables have been allocated so far. We then solve the constraints with unify and unify-helper. If unification succeeds, the expression has a type; if it fails, the expression is not typable. Because the type returned by type-infer may itself be, or may contain, a type variable that unification has bound to a more concrete type, we still need to reify this information to produce the actual type. This is done by the function reify.
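As a small worked example of these three steps, here is a hand trace for $\lambda x. x + 1$, writing $\alpha_1$ for the type variable VarT 1 and $int$ for NumT. The notation is only illustrative; the listing itself works on the Racket structs.

$$
\begin{aligned}
\text{type-infer:}\quad & \alpha_1 \rightarrow int, \quad \text{constraints } \{\alpha_1 = int,\ int = int\} \\
\text{unify:}\quad & [\alpha_1 \mapsto int] \\
\text{reify:}\quad & int \rightarrow int
\end{aligned}
$$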

;; Type Inference

(struct Pair (fst snd) #:transparent)

(define (type-subst in src dst)
  (match in
    [(NumT) in]
    [(VarT x) (if (equal? src in) dst in)]
    [(ArrowT t1 t2) (ArrowT (type-subst t1 src dst)
                            (type-subst t2 src dst))]))

(define (subst pairs src dst)
  (cond [(empty? pairs) pairs]
        [else (define p (first pairs))
              (define pf (Pair-fst p))
              (define ps (Pair-snd p))
              (cons (Pair (type-subst pf src dst)
                          (type-subst ps src dst))
                    (subst (rest pairs) src dst))]))

(define (occurs? t in)
  (match in
    [(NumT) #f]
    [(ArrowT at rt) (or (occurs? t at) (occurs? t rt))]
    [(VarT x) (equal? t in)]))

(define not-occurs? (compose not occurs?))

(define (unify-error t1 t2)
  (error 'type-error "can not unify: ~a and ~a" t1 t2))

(define (unify-helper substs result)
  (match substs
    ['() result]
    [(list (Pair fst snd) rest ...)
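     (match* (fst snd)
       ;; Identical types (including the same type variable on both sides)
       ;; unify trivially; without this clause a = a fails the occurs check.
       [(_ (== fst)) (unify-helper rest result)]
       [((NumT) (NumT)) (unify-helper rest result)]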
       [((VarT x) t)
        (if (not-occurs? fst snd)
            (unify-helper (subst rest fst snd) (cons (Pair fst snd) result))
            (unify-error fst snd))]
       [(t (VarT x))
        (if (not-occurs? snd fst)
            (unify-helper (subst rest snd fst) (cons (Pair snd fst) result))
            (unify-error snd fst))]
       [((ArrowT t1 t2) (ArrowT t3 t4))
        (unify-helper `(,(Pair t1 t3) ,(Pair t2 t4) ,@rest) result)]
       [(_ _)  (unify-error fst snd)])]))

(define (unify substs) (unify-helper (set->list substs) (list)))

(define (type-infer exp tenv const vars)
  (match exp
    [(NumE n) (values (NumT) const vars)]
    [(PlusE l r)
     (define-values (lty lconst lvars) (type-infer l tenv const vars))
     (define-values (rty rconst rvars) (type-infer r tenv lconst lvars))
     (values (NumT) (set-add (set-add rconst (Pair lty (NumT))) (Pair rty (NumT))) rvars)]
    [(MultE l r)
     (define-values (lty lconst lvars) (type-infer l tenv const vars))
     (define-values (rty rconst rvars) (type-infer r tenv lconst lvars))
     (values (NumT) (set-add (set-add rconst (Pair lty (NumT))) (Pair rty (NumT))) rvars)]
    [(IdE x)
     (values (type-lookup x tenv) const vars)]
    [(LamE arg body)
     (define new-tvar (VarT (add1 vars)))
     (define-values (bty bconst bvars)
       (type-infer body (ext-tenv (TypeBinding arg new-tvar) tenv) const (add1 vars)))
     (values (ArrowT new-tvar bty) bconst bvars)]
    [(AppE fun arg)
     (define-values (funty funconst funvars) (type-infer fun tenv const vars))
     (define-values (argty argconst argvars) (type-infer arg tenv funconst funvars))
     (define new-tvar (VarT (add1 argvars)))
     (values new-tvar (set-add (set-union funconst argconst)
                               (Pair funty (ArrowT argty new-tvar))) (add1 argvars))]))

(define (reify substs ty)
  (define (lookup/default x substs)
    (cond [(empty? substs) x]
          [(equal? (Pair-fst (first substs)) x)
           (Pair-snd (first substs))]
          [else (lookup/default x (rest substs))]))

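  (match ty
    [(NumT) (NumT)]
    [(VarT x)
     (define ans (lookup/default ty substs))
     ;; An unbound variable stays as it is; otherwise keep resolving, since
     ;; the binding may itself mention variables that were bound later.
     (if (equal? ans ty) ty (reify substs ans))]
    [(ArrowT t1 t2)
     (ArrowT (reify substs t1) (reify substs t2))]))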

(define (typecheck exp tenv)
  (define-values (ty constraints vars) (type-infer exp tenv (set) 0))
  (reify (unify constraints) ty))
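To look at the two helpers that unification relies on in isolation, the following is a minimal, self-contained sketch. It repeats the type structs and the definitions of type-subst and occurs? from the listing so that it runs on its own; the symbolic variable names 'a and 'b are chosen only for illustration and are not used by the inference code, which numbers its variables.

```racket
#lang racket
(require rackunit)

;; Types, repeated from the listing so this sketch runs on its own.
(struct NumT () #:transparent)
(struct VarT (name) #:transparent)
(struct ArrowT (arg result) #:transparent)

;; Replace every occurrence of the type variable `src` inside `in` with `dst`.
(define (type-subst in src dst)
  (match in
    [(NumT) in]
    [(VarT _) (if (equal? src in) dst in)]
    [(ArrowT t1 t2) (ArrowT (type-subst t1 src dst)
                            (type-subst t2 src dst))]))

;; Does the type variable `t` appear anywhere inside the type `in`?
(define (occurs? t in)
  (match in
    [(NumT) #f]
    [(VarT _) (equal? t in)]
    [(ArrowT at rt) (or (occurs? t at) (occurs? t rt))]))

;; Illustrative type variables (hypothetical names).
(define a (VarT 'a))
(define b (VarT 'b))

;; Substituting num for a in a -> b gives num -> b.
(check-equal? (type-subst (ArrowT a b) a (NumT))
              (ArrowT (NumT) b))

;; a occurs in a -> b, so the constraint a = a -> b has no finite solution.
(check-true (occurs? a (ArrowT a b)))
(check-false (occurs? a (ArrowT b (NumT))))
```

The occurs check is what keeps unification from building infinite types: a variable may only be bound to a type that does not contain that same variable.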

The inference algorithm resolves type variables to concrete types whenever the constraints force it. For example, for the program $\lambda x. \lambda y. x + y$, it infers the type $int \rightarrow int \rightarrow int$.

(check-equal? (typecheck (parse '{λ {x} {λ {y} {+ x y}}}) mt-tenv)
              (ArrowT (NumT) (ArrowT (NumT) (NumT))))

It also handles functions whose inferred types still contain free type variables, that is, types that read as polymorphic. The following code runs type inference on $\lambda f. \lambda u. u (f u)$, which has type $((a \rightarrow b) \rightarrow a) \rightarrow (a \rightarrow b) \rightarrow b$.

(check-equal? (typecheck (parse '{λ {f} {λ {u} {u {f u}}}}) mt-tenv)
              (ArrowT (ArrowT (ArrowT (VarT 3) (VarT 4)) (VarT 3))
                      (ArrowT (ArrowT (VarT 3) (VarT 4)) (VarT 4))))
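Tracing the algorithm by hand on this term shows where VarT 3 and VarT 4 come from. Writing $\alpha_n$ for VarT n (an illustrative notation), $f$ receives $\alpha_1$, $u$ receives $\alpha_2$, the application $f\ u$ receives $\alpha_3$, and $u\ (f\ u)$ receives $\alpha_4$:

$$
\begin{aligned}
\text{constraints:}\quad & \alpha_1 = \alpha_2 \rightarrow \alpha_3, \qquad \alpha_2 = \alpha_3 \rightarrow \alpha_4 \\
\text{type before solving:}\quad & \alpha_1 \rightarrow \alpha_2 \rightarrow \alpha_4 \\
\text{solution:}\quad & \alpha_2 \mapsto \alpha_3 \rightarrow \alpha_4, \qquad \alpha_1 \mapsto (\alpha_3 \rightarrow \alpha_4) \rightarrow \alpha_3 \\
\text{result:}\quad & ((\alpha_3 \rightarrow \alpha_4) \rightarrow \alpha_3) \rightarrow (\alpha_3 \rightarrow \alpha_4) \rightarrow \alpha_4
\end{aligned}
$$

Reading $\alpha_3$ as $a$ and $\alpha_4$ as $b$ gives the type stated above.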

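We can also show that the Omega combinator is not typable in STLC: type inference fails with a unification error, because the self-application x x would require a type variable to equal a function type that contains that same variable.

(check-exn exn:fail?
           (λ () (typecheck (parse '{{λ {x} {x x}} {λ {x} {x x}}}) mt-tenv)))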
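The failure can be traced by hand. Writing $\alpha_n$ for VarT n (an illustrative notation), the first $\lambda x. x\ x$ assigns $\alpha_1$ to $x$ and $\alpha_2$ to the application $x\ x$, which contributes the constraint

$$
\alpha_1 = \alpha_1 \rightarrow \alpha_2 .
$$

Since $\alpha_1$ occurs on the right-hand side, the occurs check rejects this constraint. Any solution would have to be an infinite type, so unification fails regardless of the order in which the remaining constraints are processed.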