Representing pattern-matching with GADTs
by Matthias Puech
Here is a little programming pearl. I’ve been wanting to work on pattern-matching for a while now, and it seems like I will finally have this opportunity here at my new (academic) home, McGill.
Encoding simply-typed languages with GADTs is now routine for many OCaml programmers. You can even take advantage (of sorts) of a convenient binding representation such as (weak) HOAS, using OCaml variables to denote your language’s variables. But what about pattern-matching? Patterns can be “deep”, i.e. they bind several variables at once, and they break the usual discipline that a binder scopes over exactly its own subterm in the AST.
It turns out that there is an adequate encoding that relies on two simple ideas. The first is to treat variables in patterns as nameless placeholders, bound by λ-abstractions on the right-hand side of the arrow (this is how e.g. Coq encodes matches: match E₁ with (y, z) -> E₂ is actually sugar for match E₁ with (_, _) -> fun y z -> E₂); the second is to thread and accumulate type arguments in a GADT, much like we demonstrated in our printf example recently.
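For readers who have not seen the printf post, here is a minimal sketch of the accumulator idea on its own. It is our own reconstruction, not the code from that post: the constructors End, Int, Str, Lit and the functions kprintf and sprintf are illustrative names. The first type parameter of fmt records the curried function type still expected, and each directive stacks one more argument onto it, the same way the variable pattern X below stacks up one more 's exp -> 'c.

```ocaml
(* ('a, 'r) fmt: 'a is the curried type of the arguments still expected,
   'r is the final result type. Int and Str each stack up one argument. *)
type ('a, 'r) fmt =
  | End : ('r, 'r) fmt
  | Int : ('a, 'r) fmt -> (int -> 'a, 'r) fmt
  | Str : ('a, 'r) fmt -> (string -> 'a, 'r) fmt
  | Lit : string * ('a, 'r) fmt -> ('a, 'r) fmt

(* Continuation-passing printer: k receives the text produced by the
   rest of the format and prepends whatever came before. *)
let rec kprintf : type a r. (string -> r) -> (a, r) fmt -> a =
  fun k -> function
  | End -> k ""
  | Int rest -> fun n -> kprintf (fun s -> k (string_of_int n ^ s)) rest
  | Str rest -> fun s0 -> kprintf (fun s -> k (s0 ^ s)) rest
  | Lit (lit, rest) -> kprintf (fun s -> k (lit ^ s)) rest

let sprintf fmt = kprintf (fun s -> s) fmt
```

For example, sprintf (Lit ("x = ", Int (Lit (", name = ", Str End)))) has type int -> string -> string: the argument types are read off the format, just as a branch's right-hand side type will be read off its pattern.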
The idea probably extends seamlessly to De Bruijn indices, by threading an explicit environment through the term. It stemmed from a discussion with Francisco over lunch yesterday about LF encodings of pattern-matching: what I show here also makes it possible to represent pattern-matching adequately in LF, which I do not think has been done this way before.
Let’s start with two basic type definitions:
type ('a, 'b) union = Inl of 'a | Inr of 'b
type ('a, 'b) pair = Pair of 'a * 'b
The encoding
First, I encode simply-typed λ-expressions with sums and products, in the very standard way with GADTs: I annotate every constructor with the (OCaml) type of its denotation.
type 'a exp =
  | Lam : ('a exp -> 'b exp) -> ('a -> 'b) exp
  | App : ('a -> 'b) exp * 'a exp -> 'b exp
  | Var : 'a -> 'a exp
  | Pair : 'a exp * 'b exp -> ('a, 'b) pair exp
  | Inl : 'a exp -> ('a, 'b) union exp
  | Inr : 'b exp -> ('a, 'b) union exp
  | Unit : unit exp
At this point, I only included data type constructors, not their destructors. These are replaced by a pattern-matching construct: it takes a scrutinee of type 's and a list of branches, each returning a value of the same type 'c.
| Match : 's exp * ('s, 'c) branch list -> 'c exp
Now, each branch is the pair of a pattern, possibly deep, possibly containing variables, and an expression where all these variables are bound.
(* 's = type of scrutinee; 'c = type of return *)
and ('s, 'c) branch =
  | Branch : ('s, 'a, 'c exp) patt * 'a -> ('s, 'c) branch
To account for these bindings, I use a trick when defining patterns that is similar to the one used for printf with GADTs. In the type of the Branch constructor, the type 'a is an “accumulator” for all variables appearing in the pattern, eventually returning a 'c exp. For instance, the annotation 'a for a pattern that binds two variables of types t -> u and t would be (t -> u) exp -> t exp -> 'c exp.
Let’s define type patt. Note that it also carries and checks the annotation 's for the type of the scrutinee. The first three cases are quite easy:
(* 's = type of scrutinee; 'a = accumulator for the variables to bind; 'c = type of return *)
and ('s, 'a, 'c) patt =
  | PUnit : (unit, 'c, 'c) patt
  | PInl : ('s, 'a, 'c) patt -> (('s, 't) union, 'a, 'c) patt
  | PInr : ('t, 'a, 'c) patt -> (('s, 't) union, 'a, 'c) patt
Now, the variable case is just a nameless dummy that “stacks up” one more argument in the “accumulator”, i.e. what will be the type of the right-hand side of the branch:
| X : ('s, 's exp -> 'c, 'c) patt
Finally, the pair case is the only binary node. It needs to thread the accumulator through the left sub-pattern, then through the right one.
| PPair : ('s, 'a, 'b) patt * ('t, 'b, 'c) patt -> (('s, 't) pair, 'a, 'c) patt
Note that we could just as well thread the accumulator through the right sub-pattern first; the right-hand side of the branch would then receive the variables in the opposite order.
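A cut-down sketch makes the accumulator types and the remark about swapping concrete. To stay self-contained it assumes a one-constructor stand-in for 'a exp, uses native OCaml tuples instead of the pair type, and adds a hypothetical constructor PPairR that threads the accumulator through the right component first. The two annotated patterns bind variables of types int -> bool and int, as in the accumulator example above, and OCaml accepts exactly the annotated accumulator types.

```ocaml
(* Hypothetical stand-in for the post's 'a exp, so the sketch compiles alone. *)
type 'a exp = Var of 'a

(* Cut-down patt: only variables and pairs (native tuples). *)
type ('s, 'a, 'c) patt =
  | X : ('s, 's exp -> 'c, 'c) patt
  (* left component's variables first, as in the post *)
  | PPair : ('s, 'a, 'b) patt * ('t, 'b, 'c) patt -> ('s * 't, 'a, 'c) patt
  (* hypothetical swapped variant: right component's variables first *)
  | PPairR : ('s, 'b, 'c) patt * ('t, 'a, 'b) patt -> ('s * 't, 'a, 'c) patt

(* Apply e to the values bound by the pattern, in accumulator order. *)
let rec branch : type s a c. a -> (s, a, c) patt * s -> c = fun e -> function
  | X, v -> e (Var v)
  | PPair (p, q), (v, w) -> branch (branch e (p, v)) (q, w)
  | PPairR (p, q), (v, w) -> branch (branch e (q, w)) (p, v)

(* Same pattern shape, two accumulator orders. *)
let left_first : ((int -> bool) * int, (int -> bool) exp -> int exp -> 'c, 'c) patt =
  PPair (X, X)
let right_first : ((int -> bool) * int, int exp -> (int -> bool) exp -> 'c, 'c) patt =
  PPairR (X, X)
```

With PPairR, the right-hand side must take the int first and the function second; the values reaching it are the same, only their order changes.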
That’s the encoding. Note that it only ensures that terms are well-typed, not that patterns are exhaustive (which is another story that I would like to tell in a future post).
Examples
Here are a couple of example terms: for each, first the shallow OCaml value, then its encoding:
let ex1 : 'a 'b. ('a, 'b) union -> ('b, 'a) union =
  fun x -> match x with
  | Inl x -> Inr x
  | Inr x -> Inl x

let ex1_encoded : 'a 'b. (('a, 'b) union -> ('b, 'a) union) exp =
  Lam (fun x -> Match (x, [
    Branch (PInl X, fun x -> Inr x);
    Branch (PInr X, fun x -> Inl x);
  ]))

let ex2 : 'a 'b. (('a, 'b) union, ('a, 'b) union) pair -> ('a, 'b) union =
  fun x -> match x with
  | Pair (x, Inl _) -> x
  | Pair (Inr _, x) -> x
  | Pair (_, Inr x) -> Inr x

let ex2_encoded : 'a 'b. ((('a, 'b) union, ('a, 'b) union) pair -> ('a, 'b) union) exp =
  Lam (fun x -> Match (x, [
    Branch (PPair (X, PInl X), (fun x _ -> x));
    Branch (PPair (PInr X, X), (fun _ x -> x));
    Branch (PPair (X, PInr X), (fun _ x -> Inr x));
  ]))
An interpreter
Finally, we can code an evaluator for this language. It takes an expression to its (well-typed) denotation. The first few lines are standard:
let rec eval : type a. a exp -> a = function
  | Lam f -> fun x -> eval (f (Var x))
  | App (m, n) -> eval m (eval n)
  | Var x -> x
  | Pair (m, n) -> Pair (eval m, eval n)
  | Inl m -> Inl (eval m)
  | Inr m -> Inr (eval m)
  | Unit -> ()
  | Match (m, bs) -> branches (eval m) bs
Now for pattern-matching, we call an auxiliary function branches that tries each branch in turn:
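and branches : type s a. s -> (s, a) branch list -> a = fun v -> function
  | [] -> failwith "pattern-matching failure"
  | Branch (p, e) :: bs ->
    (* only the matching step is guarded: a Not_found escaping from the
       evaluation of the right-hand side must not select the next branch *)
    (match branch e (p, v) with
     | rhs -> eval rhs
     | exception Not_found -> branches v bs)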
A branch is tested by the function branch, which is where the magic happens: it matches the pattern against the value of the scrutinee and returns the right-hand side applied to the values of the variables bound so far (in recursive calls, this application can be partial). The first cases are self-explanatory:
and branch : type s a c. a -> (s, a, c) patt * s -> c = fun e -> function
  | PUnit, () -> e
  | PInl p, Inl v -> branch e (p, v)
  | PInr p, Inr v -> branch e (p, v)
  | PInl _, Inr _ -> raise Not_found
  | PInr _, Inl _ -> raise Not_found
In the variable case, we know that e is a function that expects an argument: the value v of the scrutinee, which we wrap with Var to turn it back into an expression.
| X, v -> e (Var v)
The pair case is simple and beautiful: we just compose the applications of branch on both sub-patterns.
| PPair (p, q), Pair (v, w) -> branch (branch e (p, v)) (q, w)
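Putting the pieces together, here is a sketch of the whole development as a single file that can be loaded in the toplevel. It assumes OCaml 4.02 or later (for match ... with exception). Unlike the post, it places union and pair in a module V (a name chosen here) so that their constructors do not compete with the expression constructors Inl, Inr and Pair; the post's version relies on type-directed disambiguation instead. As in the revised branches above, the selected right-hand side is evaluated outside the exception handler.

```ocaml
(* Value types live in a module so that their constructors do not clash
   with the expression constructors Inl, Inr and Pair below. *)
module V = struct
  type ('a, 'b) union = Inl of 'a | Inr of 'b
  type ('a, 'b) pair = Pair of 'a * 'b
end

type 'a exp =
  | Lam : ('a exp -> 'b exp) -> ('a -> 'b) exp
  | App : ('a -> 'b) exp * 'a exp -> 'b exp
  | Var : 'a -> 'a exp
  | Pair : 'a exp * 'b exp -> ('a, 'b) V.pair exp
  | Inl : 'a exp -> ('a, 'b) V.union exp
  | Inr : 'b exp -> ('a, 'b) V.union exp
  | Unit : unit exp
  | Match : 's exp * ('s, 'c) branch list -> 'c exp
(* 'a is existential: the curried type of the right-hand side *)
and ('s, 'c) branch =
  | Branch : ('s, 'a, 'c exp) patt * 'a -> ('s, 'c) branch
and ('s, 'a, 'c) patt =
  | PUnit : (unit, 'c, 'c) patt
  | PInl : ('s, 'a, 'c) patt -> (('s, 't) V.union, 'a, 'c) patt
  | PInr : ('t, 'a, 'c) patt -> (('s, 't) V.union, 'a, 'c) patt
  | X : ('s, 's exp -> 'c, 'c) patt
  | PPair : ('s, 'a, 'b) patt * ('t, 'b, 'c) patt -> (('s, 't) V.pair, 'a, 'c) patt

let rec eval : type a. a exp -> a = function
  | Lam f -> fun x -> eval (f (Var x))
  | App (m, n) -> eval m (eval n)
  | Var x -> x
  | Pair (m, n) -> V.Pair (eval m, eval n)
  | Inl m -> V.Inl (eval m)
  | Inr m -> V.Inr (eval m)
  | Unit -> ()
  | Match (m, bs) -> branches (eval m) bs

(* Try each branch in order; only the matching step is guarded. *)
and branches : type s a. s -> (s, a) branch list -> a = fun v -> function
  | [] -> failwith "pattern-matching failure"
  | Branch (p, e) :: bs ->
    (match branch e (p, v) with
     | rhs -> eval rhs
     | exception Not_found -> branches v bs)

(* Feed the values bound by the pattern to e, one argument per variable. *)
and branch : type s a c. a -> (s, a, c) patt * s -> c = fun e -> function
  | PUnit, () -> e
  | PInl p, V.Inl v -> branch e (p, v)
  | PInr p, V.Inr v -> branch e (p, v)
  | PInl _, V.Inr _ -> raise Not_found
  | PInr _, V.Inl _ -> raise Not_found
  | X, v -> e (Var v)
  | PPair (p, q), V.Pair (v, w) -> branch (branch e (p, v)) (q, w)

let ex1_encoded : 'a 'b. (('a, 'b) V.union -> ('b, 'a) V.union) exp =
  Lam (fun x -> Match (x, [
    Branch (PInl X, fun x -> Inr x);
    Branch (PInr X, fun x -> Inl x);
  ]))

let ex2_encoded :
  'a 'b. ((('a, 'b) V.union, ('a, 'b) V.union) V.pair -> ('a, 'b) V.union) exp =
  Lam (fun x -> Match (x, [
    Branch (PPair (X, PInl X), (fun x _ -> x));
    Branch (PPair (PInr X, X), (fun _ x -> x));
    Branch (PPair (X, PInr X), (fun _ x -> Inr x));
  ]))
```

For instance, eval ex2_encoded applied to V.Pair (V.Inr 5, V.Inr 6) fails on the first branch, matches the second and returns V.Inr 6, just like ex2 would.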
That’s it. Nice, eh? There are two obvious questions that I leave for future posts: can we compile this encoding down to simple case statements, with a guarantee of type preservation? And could we enhance the encoding so as to guarantee exhaustiveness statically?
See you soon!
Very nice, indeed!
It would be nice to extend this with union and intersection of patterns (as in the calculus of Neelk in “Focusing on Pattern Matching”, the latter one with a notion of dynamic failure).
Hi Gabriel, thank you for your suggestion!
Extending this with union (POr) and intersection (PAnd) is actually not that hard, considering their operational content: a POr first tries the left sub-pattern and falls back on the right one if it fails, so both sides must bind the same variables; a PAnd is like matching on a pair, except that the same value is matched twice. This gives us the following cases for the evaluation:
| PAnd (p, q), v -> branch (branch e (p, v)) (q, v)
| POr (p, q), v -> (try branch e (p, v) with Not_found -> branch e (q, v))
which suggests the following constructor definitions:
| POr : ('s, 'a, 'c) patt * ('s, 'a, 'c) patt -> ('s, 'a, 'c) patt
| PAnd : ('s, 'a, 'b) patt * ('s, 'b, 'c) patt -> ('s, 'a, 'c) patt
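A minimal sketch of these two constructors in action, assuming a simplified, shallow variant of the encoding in which variables bind plain OCaml values (e v instead of e (Var v)) and only sums are kept, so that it runs on its own. The or-pattern either (a name chosen here) binds the payload of whichever side is present, and PAnd (X, PInl X) binds both the whole value and its Inl payload, much like OCaml's (Inl y as x).

```ocaml
type ('a, 'b) union = Inl of 'a | Inr of 'b

(* Shallow variant (assumption): variables bind plain values, not 'a exp. *)
type ('s, 'a, 'c) patt =
  | PInl : ('s, 'a, 'c) patt -> (('s, 't) union, 'a, 'c) patt
  | PInr : ('t, 'a, 'c) patt -> (('s, 't) union, 'a, 'c) patt
  | X : ('s, 's -> 'c, 'c) patt
  (* both sides bind the same variables, in the same order *)
  | POr : ('s, 'a, 'c) patt * ('s, 'a, 'c) patt -> ('s, 'a, 'c) patt
  (* the same value is matched by both sides; bindings are concatenated *)
  | PAnd : ('s, 'a, 'b) patt * ('s, 'b, 'c) patt -> ('s, 'a, 'c) patt

let rec branch : type s a c. a -> (s, a, c) patt * s -> c = fun e -> function
  | PInl p, Inl v -> branch e (p, v)
  | PInr p, Inr v -> branch e (p, v)
  | PInl _, Inr _ -> raise Not_found
  | PInr _, Inl _ -> raise Not_found
  | X, v -> e v
  | PAnd (p, q), v -> branch (branch e (p, v)) (q, v)
  | POr (p, q), v -> (try branch e (p, v) with Not_found -> branch e (q, v))

(* Or-pattern binding the payload of either side of an ('a, 'a) union. *)
let either = POr (PInl X, PInr X)
```

Note the parentheses around the try in the POr case: without them, any case written after it would be parsed as an extra handler of the try.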
Are you convinced?
Oh, I had no doubt that it could be done, but I think it makes the presentation (even) more convincing.