Representing pattern-matching with GADTs

by Matthias Puech

Here is a little programming pearl. I’ve been wanting to work on pattern-matching for a while now, and it seems like I will finally have this opportunity here at my new (academic) home, McGill.

Encoding some simply-typed languages with GADTs is now routine for a lot of OCaml programmers. You can even take advantage of (some form of) convenient binding representation, like (weak) HOAS, and use OCaml variables to denote your language’s variables. But what about pattern-matching? Patterns are possibly “deep”, i.e. they bind several variables at a time, and they don’t follow the usual discipline where the scope of a variable is exactly the subterm under its binder in the AST.

It turns out that there is an adequate encoding that relies on two simple ideas. The first is to treat variables in patterns as nameless placeholders bound by λ-abstractions on the right side of the arrow (this is how e.g. Coq encodes matches: match E₁ with (y, z) -> E₂ is actually sugar for match E₁ with (_, _) -> fun y z -> E₂); the second is to thread and accumulate type arguments in a GADT, much like we demonstrated in our printf example recently.
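To make the first idea concrete in plain OCaml, here is a minimal sketch; the helper match_pair is hypothetical and only serves as an illustration. The pattern it stands for is just “a pair of two holes”, and the variables are bound by the fun on the right-hand side rather than by the pattern:

```ocaml
(* Hypothetical helper: the shallow pattern (_, _), whose right-hand side
   [rhs] is a function receiving the contents of the two holes, in order. *)
let match_pair (scrutinee : 'a * 'b) (rhs : 'a -> 'b -> 'c) : 'c =
  let (first, second) = scrutinee in
  rhs first second

(* Corresponds to: match (1, 2) with (y, z) -> y + z *)
let three = match_pair (1, 2) (fun y z -> y + z)
```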

The idea probably extends seamlessly to De Bruijn indices, by threading an explicit environment throughout the term. It stemmed from a discussion with Francisco over lunch yesterday about LF encodings of pattern-matching: what I will show also makes it possible to represent pattern-matching adequately in LF, which I do not think has been done this way before.

Let’s start with two basic type definitions:

type ('a, 'b) union = Inl of 'a | Inr of 'b
type ('a, 'b) pair = Pair of 'a * 'b

The encoding

First, I encode simply-typed λ-expressions with sums and products, in the very standard way with GADTs: I annotate every constructor by the (OCaml) type of its denotation.

type 'a exp =
  | Lam : ('a exp -> 'b exp) -> ('a -> 'b) exp
  | App : ('a -> 'b) exp * 'a exp -> 'b exp
  | Var : 'a -> 'a exp
  | Pair : 'a exp * 'b exp -> ('a, 'b) pair exp
  | Inl : 'a exp -> ('a, 'b) union exp
  | Inr : 'b exp -> ('a, 'b) union exp
  | Unit : unit exp

So far, I have only included the data type constructors, not their destructors. These are replaced by a single pattern-matching construct: it takes a scrutinee of type 's and a list of branches, each returning a value of the same type 'c.

  | Match : 's exp * ('s, 'c) branch list -> 'c exp

Now, each branch is the pair of a pattern, possibly deep, possibly containing variables, and an expression where all these variables are bound.

(* 's = type of scrutinee; 'c = type of return *)
and ('s, 'c) branch =
  | Branch : ('s, 'a, 'c exp) patt * 'a -> ('s, 'c) branch

To account for these bindings, I use a trick when defining patterns that is similar to the one used for printf with GADTs. In the type of the Branch constructor, the type 'a is an “accumulator” for all variables appearing in the pattern, eventually returning a 'c exp. For instance, for a pattern that binds two variables of types 'u -> 'v and 'u (in this order), the accumulator 'a would be ('u -> 'v) exp -> 'u exp -> 'c exp.
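For readers who have not seen the printf trick, here is a minimal sketch of the same accumulator idea on format strings. The type fmt and the names Lit, Int, Str, End, go and sprintf are hypothetical, not necessarily those of our earlier printf post. Each Int or Str node stacks one more argument onto the accumulator 'a, just as a pattern variable will, and End, like PUnit, makes the accumulator equal to the result type:

```ocaml
(* [('a, 'r) fmt]: 'a is the curried type of the arguments the format
   still needs, ending in the final result type 'r. *)
type ('a, 'r) fmt =
  | Lit : string * ('a, 'r) fmt -> ('a, 'r) fmt  (* literal text, no argument *)
  | Int : ('a, 'r) fmt -> (int -> 'a, 'r) fmt     (* stacks an int argument *)
  | Str : ('a, 'r) fmt -> (string -> 'a, 'r) fmt  (* stacks a string argument *)
  | End : ('r, 'r) fmt                            (* accumulator = result *)

(* Walk the format, collecting text in [acc] and consuming one argument
   per Int/Str node. *)
let rec go : type a. string -> (a, string) fmt -> a = fun acc -> function
  | Lit (s, k) -> go (acc ^ s) k
  | Int k -> fun n -> go (acc ^ string_of_int n) k
  | Str k -> fun s -> go (acc ^ s) k
  | End -> acc

let sprintf fmt = go "" fmt

(* The type checker computes int -> string from the shape of the format. *)
let greeting = sprintf (Lit ("x = ", Int End)) 42
```

The type of sprintf applied to a format is read off the format, exactly as the type of a branch's right-hand side will be read off its pattern.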

Let’s define the type patt. Note that it also carries and checks the annotation 's for the type of the scrutinee. The first three cases are quite easy:

(* 's = type of scrutinee;
   'a = accumulator for the variables to bind;
   'c = type of return *)
and ('s, 'a, 'c) patt =
  | PUnit : (unit, 'c, 'c) patt
  | PInl : ('s, 'a, 'c) patt -> (('s, 't) union, 'a, 'c) patt
  | PInr : ('t, 'a, 'c) patt -> (('s, 't) union, 'a, 'c) patt

Now, the variable case is just a nameless dummy that “stacks up” one more argument in the “accumulator”, i.e. what will be the type of the right-hand side of the branch:

  | X : ('s, 's exp -> 'c, 'c) patt

Finally, the pair case is the only binary node. It needs to thread the accumulator through the left sub-pattern, then through the right one.

  | PPair : ('s, 'a, 'b) patt * ('t, 'b, 'c) patt 
    -> (('s, 't) pair, 'a, 'c) patt
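As a sanity check of the accumulator discipline, the following self-contained sketch restates patt (with exp reduced to Var, an assumption made only to keep it short) and lets the compiler verify that a pattern binding two variables, of types int -> bool and int, gets the accumulator (int -> bool) exp -> int exp -> 'c. The helper arity is hypothetical; it counts the variables a pattern binds, i.e. the arguments it stacks onto the accumulator:

```ocaml
type ('a, 'b) union = Inl of 'a | Inr of 'b
type ('a, 'b) pair = Pair of 'a * 'b

(* Assumption for brevity: exp reduced to Var; patt is as above. *)
type 'a exp = Var : 'a -> 'a exp

type ('s, 'a, 'c) patt =
  | PUnit : (unit, 'c, 'c) patt
  | PInl : ('s, 'a, 'c) patt -> (('s, 't) union, 'a, 'c) patt
  | PInr : ('t, 'a, 'c) patt -> (('s, 't) union, 'a, 'c) patt
  | X : ('s, 's exp -> 'c, 'c) patt
  | PPair : ('s, 'a, 'b) patt * ('t, 'b, 'c) patt
      -> (('s, 't) pair, 'a, 'c) patt

(* The compiler checks this annotation: the left variable comes first. *)
let fun_and_arg :
    ((int -> bool, int) pair, (int -> bool) exp -> int exp -> 'c, 'c) patt =
  PPair (X, X)

(* Hypothetical helper: number of variables bound by a pattern. *)
let rec arity : type s a c. (s, a, c) patt -> int = function
  | PUnit -> 0
  | PInl p -> arity p
  | PInr p -> arity p
  | X -> 1
  | PPair (p, q) -> arity p + arity q
```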

Note that we could just as well thread the accumulator through the right sub-pattern first; variables would then be bound in the opposite order on the right-hand side of the branch.
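Here is a minimal sketch of that alternative, assuming a hypothetical constructor PPairR next to PPair (and, again, exp reduced to Var). PPairR keeps the sub-patterns in the same positions but threads the accumulator through the right one first, so its case in the matching function simply swaps the two nested calls:

```ocaml
type ('a, 'b) pair = Pair of 'a * 'b

(* Reduced exp: only Var, which is all [branch] needs here. *)
type 'a exp = Var : 'a -> 'a exp

type ('s, 'a, 'c) patt =
  | X : ('s, 's exp -> 'c, 'c) patt
  (* As in the post: accumulator threaded left, then right. *)
  | PPair : ('s, 'a, 'b) patt * ('t, 'b, 'c) patt
      -> (('s, 't) pair, 'a, 'c) patt
  (* Hypothetical: same sub-patterns, threaded right, then left. *)
  | PPairR : ('s, 'b, 'c) patt * ('t, 'a, 'b) patt
      -> (('s, 't) pair, 'a, 'c) patt

let rec branch : type s a c. a -> (s, a, c) patt * s -> c = fun e -> function
  | X, v -> e (Var v)
  | PPair (p, q), Pair (v, w) -> branch (branch e (p, v)) (q, w)
  | PPairR (p, q), Pair (v, w) -> branch (branch e (q, w)) (p, v)

let get : type a. a exp -> a = fun (Var x) -> x

(* Same scrutinee and result; the right-hand sides take their arguments
   in opposite orders. *)
let left_first = branch (fun x y -> (get x, get y)) (PPair (X, X), Pair (1, "a"))
let right_first = branch (fun y x -> (get x, get y)) (PPairR (X, X), Pair (1, "a"))
```

Both calls return (1, "a"), but with PPairR the right-hand side receives the variable of the right component first.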

That’s the encoding. Note that it only ensures that terms are well-typed, not that patterns are exhaustive (which is another story that I would like to tell in a future post).

Examples

Here are a couple of example terms: first the shallow OCaml value, then its encoded representation:

let ex1 : 'a 'b. ('a, 'b) union -> ('b, 'a) union = fun x -> match x with
  | Inl x -> Inr x
  | Inr x -> Inl x

let ex1_encoded : 'a 'b. (('a, 'b) union -> ('b, 'a) union) exp =
  Lam (fun x -> Match (x, [
      Branch (PInl X, fun x -> Inr x);
      Branch (PInr X, fun x -> Inl x);
    ]))

let ex2 : 'a 'b. ((('a, 'b) union, ('a, 'b) union) pair
    -> ('a, 'b) union) =
  fun x -> match x with
    | Pair (x, Inl _) -> x
    | Pair (Inr _, x) -> x
    | Pair (_, Inr x) -> Inr x

let ex2_encoded : 'a 'b. ((('a, 'b) union, ('a, 'b) union) pair 
    -> ('a, 'b) union) exp =
  Lam (fun x -> Match (x, [
      Branch (PPair (X, PInl X), (fun x _ -> x));
      Branch (PPair (PInr X, X), (fun _ x -> x));
      Branch (PPair (X, PInr X), (fun _ x -> Inr x));
    ]))

An interpreter

Finally, we can code an evaluator for this language. It takes an expression to its (well-typed) denotation. The first few lines are standard:

let rec eval : type a. a exp -> a = function
  | Lam f -> fun x -> eval (f (Var x))
  | App (m, n) -> eval m (eval n)
  | Var x -> x
  | Pair (m, n) -> Pair (eval m, eval n)
  | Inl m -> Inl (eval m)
  | Inr m -> Inr (eval m)
  | Unit -> ()
  | Match (m, bs) -> branches (eval m) bs

Now for pattern-matching, we call an auxiliary function branches that tries each branch in turn:

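and branches : type s a. s -> (s, a) branch list -> a = fun v -> function
  | [] -> failwith "pattern-matching failure"
  | Branch (p, e) :: bs ->
    (* only a failed match falls through; the body is evaluated outside the handler *)
    (match branch e (p, v) with
     | body -> eval body
     | exception Not_found -> branches v bs)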

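A branch is tested by the function branch, which is where the magic happens: it matches the pattern against the value of the scrutinee and returns the right-hand side applied to the values of the variables the pattern binds (a result that may still be a function when the pattern is only a sub-pattern of the branch), or raises Not_found if the value does not match. The first cases are self-explanatory: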

and branch : type s a c. a -> (s, a, c) patt * s -> c = fun e -> function
  | PUnit, () -> e
  | PInl p, Inl v -> branch e (p, v)
  | PInr p, Inr v -> branch e (p, v)
  | PInl _, Inr _ -> raise Not_found
  | PInr _, Inl _ -> raise Not_found

In the variable case, we know from the type of X that e is a function expecting one more argument: the (sub-)value v of the scrutinee, injected back into expressions with Var.

  | X, v -> e (Var v)

The pair case is simple and beautiful: we just compose the applications of branch to both sub-patterns, the left one consuming the first arguments of e and the right one the remaining ones.

  | PPair (p, q), Pair (v, w) -> branch (branch e (p, v)) (q, w)
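Putting the pieces together, here is the whole development as one self-contained file, which should compile as is with a recent OCaml (4.02 or later, for match ... with exception). Two details are our own choices rather than necessarily the post's: branches evaluates the selected body outside the exception handler, so that a Not_found raised while evaluating the body is not mistaken for a failed match; and the helpers inl, inr and pair are hypothetical, with type annotations that select the value-level constructors over the exp constructors that shadow them. The last check illustrates that a non-exhaustive match is accepted by the type checker and only fails at run time:

```ocaml
(* Value-level sums and products: the denotations. *)
type ('a, 'b) union = Inl of 'a | Inr of 'b
type ('a, 'b) pair = Pair of 'a * 'b

(* Typed HOAS expressions; their Pair, Inl and Inr shadow the ones above. *)
type 'a exp =
  | Lam : ('a exp -> 'b exp) -> ('a -> 'b) exp
  | App : ('a -> 'b) exp * 'a exp -> 'b exp
  | Var : 'a -> 'a exp
  | Pair : 'a exp * 'b exp -> ('a, 'b) pair exp
  | Inl : 'a exp -> ('a, 'b) union exp
  | Inr : 'b exp -> ('a, 'b) union exp
  | Unit : unit exp
  | Match : 's exp * ('s, 'c) branch list -> 'c exp
and ('s, 'c) branch =
  | Branch : ('s, 'a, 'c exp) patt * 'a -> ('s, 'c) branch
and ('s, 'a, 'c) patt =
  | PUnit : (unit, 'c, 'c) patt
  | PInl : ('s, 'a, 'c) patt -> (('s, 't) union, 'a, 'c) patt
  | PInr : ('t, 'a, 'c) patt -> (('s, 't) union, 'a, 'c) patt
  | X : ('s, 's exp -> 'c, 'c) patt
  | PPair : ('s, 'a, 'b) patt * ('t, 'b, 'c) patt
      -> (('s, 't) pair, 'a, 'c) patt

let rec eval : type a. a exp -> a = function
  | Lam f -> fun x -> eval (f (Var x))
  | App (m, n) -> eval m (eval n)
  | Var x -> x
  | Pair (m, n) -> Pair (eval m, eval n)
  | Inl m -> Inl (eval m)
  | Inr m -> Inr (eval m)
  | Unit -> ()
  | Match (m, bs) -> branches (eval m) bs

(* Try the branches in order; only a failed match falls through. *)
and branches : type s a. s -> (s, a) branch list -> a = fun v -> function
  | [] -> failwith "pattern-matching failure"
  | Branch (p, e) :: bs ->
    (match branch e (p, v) with
     | body -> eval body
     | exception Not_found -> branches v bs)

(* Match a pattern against a value, feeding the bound values to [e]. *)
and branch : type s a c. a -> (s, a, c) patt * s -> c = fun e -> function
  | PUnit, () -> e
  | PInl p, Inl v -> branch e (p, v)
  | PInr p, Inr v -> branch e (p, v)
  | PInl _, Inr _ -> raise Not_found
  | PInr _, Inl _ -> raise Not_found
  | X, v -> e (Var v)
  | PPair (p, q), Pair (v, w) -> branch (branch e (p, v)) (q, w)

let ex1_encoded : 'a 'b. (('a, 'b) union -> ('b, 'a) union) exp =
  Lam (fun x -> Match (x, [
      Branch (PInl X, fun x -> Inr x);
      Branch (PInr X, fun x -> Inl x);
    ]))

let ex2_encoded : 'a 'b. ((('a, 'b) union, ('a, 'b) union) pair
    -> ('a, 'b) union) exp =
  Lam (fun x -> Match (x, [
      Branch (PPair (X, PInl X), (fun x _ -> x));
      Branch (PPair (PInr X, X), (fun _ x -> x));
      Branch (PPair (X, PInr X), (fun _ x -> Inr x));
    ]))

(* Hypothetical helpers: the annotations select the value-level constructors. *)
let inl x : ('a, 'b) union = Inl x
let inr x : ('a, 'b) union = Inr x
let pair x y : ('a, 'b) pair = Pair (x, y)

let () =
  assert (eval ex1_encoded (inl 3) = inr 3);
  assert (eval ex2_encoded (pair (inl 1) (inr 2)) = inr 2);
  (* Well-typed but non-exhaustive: fails only at run time. *)
  let partial = Match (Inr Unit, [Branch (PInl X, fun x -> x)]) in
  assert (match eval partial with _ -> false | exception Failure _ -> true)
```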

That’s it. Nice, eh? There are two obvious questions that I leave for future posts: can we compile this encoding down to simple case statements, with a guarantee of type preservation? And could we enhance the encoding so as to guarantee exhaustiveness statically?

See you soon!