Week 3: Techniques for Defining Recursive Functions
- Author
Nate Foster, with some materials from CS6115
- License
No redistribution allowed
And now for something completely different... We've seen how we can define some recursive functions in Coq, provided that they pass the termination checker. But what if we have a more complicated function? As an example, let's look at implementing merge sort for lists of natural numbers, using the standard ordering.
Definition nat_lte := Compare_dec.le_gt_dec.
First, we need to define a function to merge two (already sorted) lists. Normally, we'd write this as follows:
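The listing for this first attempt did not survive in this copy of the notes, so what follows is a reconstruction rather than the original code. It assumes nat_lte is defined as above, and it is wrapped in Fail so that a file containing it still compiles:

```coq
Require Import List Arith.

(* Reconstruction, not the original listing: neither argument decreases
   on every recursive call, so Coq finds no structural argument. *)
Fail Fixpoint merge (xs ys : list nat) : list nat :=
  match xs, ys with
  | nil, _ => ys
  | _, nil => xs
  | x :: xs', y :: ys' =>
      if nat_lte x y then x :: merge xs' ys   (* xs shrinks, ys does not *)
      else y :: merge xs ys'                  (* ys shrinks, xs does not *)
  end.
```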
Unfortunately, Coq will reject this because it's neither the case that xs is always getting smaller, nor the case that ys is always getting smaller. Of course, one of them is always getting smaller, so eventually, this will terminate. In this case, we can hack around the problem by simply reorganizing the function as follows: an outer recursion on xs and a nested inner recursion on ys.
Fixpoint merge (xs:list nat) : list nat -> list nat :=
  match xs with
  | nil => fun ys => ys
  | x::xs' =>
      (fix inner_merge (ys:list nat) : list nat :=
         match ys with
         | nil => x::xs'
         | y::ys' =>
             if nat_lte x y then
               x :: (merge xs' ys)
             else
               y :: (inner_merge ys')
         end)
  end.
We can write some examples to test that merge works as expected.
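The original examples are missing from this copy, so here are a few illustrative ones. They assume merge and nat_lte as defined above; the inputs are arbitrary.

```coq
Require Import List.
Import ListNotations.

(* Both inputs are already sorted, as merge expects. *)
Compute merge [1; 3; 5] [2; 4; 6].   (* = [1; 2; 3; 4; 5; 6] *)
Compute merge [] [1; 2].             (* = [1; 2] *)
Compute merge [4; 7] [].             (* = [4; 7] *)
```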
Now let's write a pair of helper functions, one that takes a list of lists of naturals and merges consecutive elements, and another that transforms a list into a list of singleton lists.
Fixpoint merge_pairs (xs:list (list nat)) : list (list nat) :=
  match xs with
  | h1::h2::t => (merge h1 h2) :: (merge_pairs t)
  | xs' => xs'
  end.

Definition make_lists (xs:list nat) : list (list nat) :=
  List.map (fun x => x::nil) xs.
Here are some examples to illustrate what these helper functions do.
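The original examples are not present in this copy; the following are stand-ins with arbitrarily chosen inputs, assuming merge_pairs and make_lists as defined above.

```coq
Require Import List.
Import ListNotations.

(* make_lists wraps every element in its own singleton list. *)
Compute make_lists [1; 2; 3].
(* = [[1]; [2]; [3]] *)

(* merge_pairs merges neighbours two at a time; a leftover list is kept. *)
Compute merge_pairs [[1; 3]; [2; 4]; [5]].
(* = [[1; 2; 3; 4]; [5]] *)
```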
Next we will prove an important lemma that relates the length of the list computed by merge_pairs to the length of its argument. Note that we prove a conjunction to get a strong enough induction hypothesis. (As a fun exercise, try to prove just the first conjunct by induction and see where you get stuck.)
forall xs : list (list nat),
  (forall h1 h2 : list nat,
     Datatypes.length (merge_pairs (h1 :: h2 :: xs)) < Datatypes.length (h1 :: h2 :: xs)) /\
  (forall h : list nat,
     Datatypes.length (merge_pairs (h :: xs)) <= Datatypes.length (h :: xs))

The proof is by induction on xs. In the base case both conjuncts simplify to 1 < 2 and 1 <= 1, which split; lia discharges. In the inductive case, with induction hypotheses H1 and H2 for xs, the first conjunct follows from H2 instantiated at x, and the second conjunct follows from H1 instantiated at x' and x; lia closes both.
Next, we can prove a corollary that captures the case we will need to define merge sort.
forall (h1 h2 : list nat) (xs : list (list nat)),
  Datatypes.length (merge_pairs (h1 :: h2 :: xs)) < Datatypes.length (h1 :: h2 :: xs)

The proof introduces h1, h2, and xs, instantiates the previous lemma at xs, names its first conjunct H, and concludes with apply H.
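Since the proof scripts themselves are missing from this copy, here is one way the lemma and its corollary could be written out. This is a sketch: the name merge_pairs_length' and the exact tactics are assumptions, although the corollary's name merge_pairs_length matches the one used later in these notes. It assumes merge_pairs as defined above.

```coq
Require Import List Lia.

(* Reconstruction; the name merge_pairs_length' is an assumption. *)
Lemma merge_pairs_length' :
  forall xs : list (list nat),
    (forall h1 h2 : list nat,
        Datatypes.length (merge_pairs (h1 :: h2 :: xs))
        < Datatypes.length (h1 :: h2 :: xs)) /\
    (forall h : list nat,
        Datatypes.length (merge_pairs (h :: xs))
        <= Datatypes.length (h :: xs)).
Proof.
  induction xs as [| x xs [H1 H2]].
  - (* base case: both sides compute to concrete numbers *)
    simpl. split; lia.
  - split.
    + (* strict bound: use the weak bound for x :: xs *)
      specialize (H2 x). simpl in *. lia.
    + (* weak bound: implied by the strict bound for x' :: x :: xs *)
      intros x'. specialize (H1 x' x). lia.
Qed.

Lemma merge_pairs_length :
  forall (h1 h2 : list nat) (xs : list (list nat)),
    Datatypes.length (merge_pairs (h1 :: h2 :: xs))
    < Datatypes.length (h1 :: h2 :: xs).
Proof.
  intros h1 h2 xs.
  (* keep only the first conjunct of the stronger lemma *)
  destruct (merge_pairs_length' xs) as [H _].
  apply H.
Qed.
```

The qualified name Datatypes.length is used because, once strings are in scope, a bare length may refer to the string version.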
Recursive Functions with Explicit Termination Measures
Now we are ready to define our merge sort. We'll do it by iterating merge_pairs until we have a singleton list. Since this iteration function is not structurally recursive, but rather, defined using a measure, we first need to import a couple of libraries.
Require Import Program.
Require Import Program.Wf.
Require Import Recdef.
The details of how the constructs in these modules work are not important for now. The main thing to know is that the Function construct is similar to Fixpoint, but we must give a {measure ...} clause to tell Coq what is decreasing. In this case, the length of the argument list decreases at every recursive call.
Coq then asks us to prove that the measure decreases at the recursive call:

forall (xs : list (list nat)) (x1 : list nat) (l : list (list nat)) (x2 : list nat) (xs' : list (list nat)),
  l = x2 :: xs' ->
  xs = x1 :: x2 :: xs' ->
  Datatypes.length (merge_pairs (x1 :: x2 :: xs')) < Datatypes.length (x1 :: x2 :: xs')

Once the hypotheses are introduced, this is exactly the corollary, so apply merge_pairs_length followed by Defined completes the definition.
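The definition of merge_iter itself is missing from this copy. Judging from the obligation and the surrounding text, it presumably looked roughly like the sketch below. The binder names and the result for the empty list are assumptions; merge_pairs and merge_pairs_length are the definitions from earlier in these notes.

```coq
Require Import List Recdef.

(* Reconstruction; returning nil on empty input is an assumption. *)
Function merge_iter (xs : list (list nat)) {measure Datatypes.length xs}
  : list nat :=
  match xs with
  | nil => nil
  | x :: nil => x
  | x1 :: x2 :: xs' => merge_iter (merge_pairs (x1 :: x2 :: xs'))
  end.
Proof.
  (* single obligation: the argument of the recursive call is shorter *)
  intros. apply merge_pairs_length.
Defined.
```

Ending the proof with Defined rather than Qed keeps the termination argument transparent, which matters if we later want to evaluate the function inside Coq.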
Note that we had to use merge_pairs_length before Coq was willing to accept our recursive definition. It is also interesting to examine the objects that are generated by this definition.
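The commands used for this in the original notes are not preserved here. As an illustration, and assuming merge_iter was defined with Function as described, two of the generated objects can be inspected like this:

```coq
(* Unfolding equation: rewrites merge_iter xs to its body. *)
Check merge_iter_equation.

(* Functional induction principle following the shape of the definition. *)
Check merge_iter_ind.
```

The equation lemma is handy for rewriting, since a function defined by a measure does not unfold by simpl the way a structural Fixpoint does.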
With these pieces in hand, we can finally define our merge_sort function.
Definition merge_sort (xs:list nat) :=
  merge_iter (make_lists xs).
Let's test that it is working as expected:
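The original test is missing from this copy; the following is a stand-in with an arbitrary input. Whether Compute fully reduces the term depends on the termination proof being transparent (closed with Defined), so treat this as a sketch rather than a guaranteed transcript.

```coq
Require Import List.
Import ListNotations.

(* If evaluation goes through, the result should be the input in
   ascending order. *)
Compute merge_sort [8; 3; 2; 7; 1; 5; 4; 6].
```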
Recursive Functions with Fuel (And, Type Classes!)
Another approach to dealing with tricky general recursive functions is to approximate them using "fuel." To illustrate this approach, we'll introduce Coq's type classes, which are similar to Haskell's, and develop some constructs for pretty-printing various data types.
A type class defines an interface for some type. In this case, we say that types A that implement the Show interface have a method named show that will convert them to a string.
Class Show (A:Type) := {
  show : A -> string
}.
Here is an instance for booleans.
Instance boolShow : Show bool := {
  show := fun (b:bool) => if b then "true" else "false"
}.
Note that we cannot yet use this for natural numbers:
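The command that demonstrated this is not preserved in this copy. A small sketch of what it might have looked like, assuming the Show class and boolShow instance above with string notation in scope:

```coq
(* Instance resolution finds boolShow. *)
Compute (show true).     (* = "true" *)

(* No Show nat instance exists yet, so resolution fails;
   Fail turns the expected error into a successful command. *)
Fail Compute (show 42).
```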
To define a Show instance for natural numbers, let's first define a helper that shows a single digit.
Definition digit2string (d:nat) : string :=
  match d with
  | 0 => "0" | 1 => "1" | 2 => "2" | 3 => "3"
  | 4 => "4" | 5 => "5" | 6 => "6" | 7 => "7"
  | 8 => "8" | _ => "9"
  end.
Alas, it can be difficult to convince Coq that the iterated version of this helper function terminates. So, we can simply give it fuel.
Fixpoint digits' (fuel n:nat) (accum : string) : string :=
  match fuel with
  | 0 => accum
  | S fuel' =>
      match n with
      | 0 => accum
      | _ => let d := digit2string (n mod 10) in
             digits' fuel' (n / 10) (d ++ accum)
      end
  end.
Fortunately, it's sufficient to use n as the fuel for itself, since we know we won't need to divide n by 10 more than n times. We could of course use the log of n (base 10) instead, but there's no benefit to being this precise, so we'll just use n.
Definition digits (n:nat) : string :=
  match digits' n n "" with
  | "" => "0"
  | ds => ds
  end.
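A few quick checks, added here for illustration and assuming digits as defined above with string notation in scope, show how the fuel-based definition behaves, including the special case for zero:

```coq
(* digits' 0 0 "" returns "", which digits maps to "0". *)
Compute digits 0.     (* = "0" *)

(* One division by 10 is enough for a single digit. *)
Compute digits 7.     (* = "7" *)

(* Digits are prepended to the accumulator, so the order comes out right. *)
Compute digits 42.    (* = "42" *)
```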
Here is another version written using Function and with an explicit termination measure. Its termination obligation is

forall n : nat, string -> forall n0 : nat, n = S n0 -> S n0 / 10 < S n0

which, once the hypotheses are introduced, is discharged by apply Nat.div_lt; lia.
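The definition itself is missing from this copy. A sketch consistent with the obligation above might look as follows; the name digits_wf and the choice of the identity function as the measure are assumptions, and digit2string is the helper defined earlier.

```coq
Require Import Arith Lia Recdef String.
Open Scope string_scope.

(* Reconstruction; digits_wf is a hypothetical name. *)
Function digits_wf (n : nat) (accum : string) {measure (fun x => x) n}
  : string :=
  match n with
  | 0 => accum
  | _ => digits_wf (n / 10) (digit2string (n mod 10) ++ accum)
  end.
Proof.
  (* obligation: S n0 / 10 < S n0; Nat.div_lt needs 0 < S n0 and 1 < 10 *)
  intros. apply Nat.div_lt; lia.
Qed.
```

Compared with the fuel version, no extra argument is threaded through, at the cost of proving the decrease explicitly.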
Now we can define the Show instance for nat.
Instance natShow : Show nat := { show := digits }.
Importantly, because we used a type class, we can still show booleans.
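For illustration (the original commands are not preserved here), and assuming the boolShow and natShow instances above, the same show method now works at both types:

```coq
(* Resolved with natShow, which calls digits. *)
Compute (show 42).     (* = "42" *)

(* Resolved with boolShow, exactly as before. *)
Compute (show true).   (* = "true" *)
```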
We can also define parameterized Show instances, which allow us to show data structures.
Instance pairShow (A B:Type) (showA : Show A) (showB : Show B) : Show (A*B) := {
  show := (fun p => "(" ++ (show (fst p)) ++ "," ++ (show (snd p)) ++ ")")
}.

Definition show_list {A:Type} {showA:Show A} (xs:list A) : string :=
  ("[" ++ ((fix loop (xs:list A) : string :=
              match xs with
              | nil => "]"
              | h::nil => show h ++ "]"
              | h::t => show h ++ "," ++ loop t
              end) xs)).

Instance listShow (A:Type) (showA:Show A) : Show (list A) := {
  show := @show_list A showA
}.
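To see how these parameterized instances compose, here are a few illustrative queries. They are not from the original notes and assume all the Show instances above are in scope.

```coq
(* pairShow is instantiated with natShow and boolShow. *)
Compute (show (3, true)).               (* = "(3,true)" *)

(* listShow is instantiated with natShow. *)
Compute (show (1 :: 2 :: 3 :: nil)).    (* = "[1,2,3]" *)

(* Instances nest: a list of pairs uses listShow over pairShow. *)
Compute (show ((1, false) :: nil)).     (* = "[(1,false)]" *)
```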
Type classes are a powerful abstraction tool that fits very well with some proof developments. If you decide to use them, one tip is to make type classes as small and simple as possible. For example, you probably want to separate out the operations you want to use from the properties that the implementation is expected to satisfy.
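As a minimal sketch of that separation, the following splits a boolean equality test from one law about it. Every name here (BoolEq, beq, BoolEqLaws, beq_refl, and the two instances) is hypothetical and chosen only for illustration.

```coq
Require Import Bool.

(* Operations only: code that merely computes depends on this class. *)
Class BoolEq (A : Type) := {
  beq : A -> A -> bool
}.

(* Properties only: proofs that need the law ask for this class too. *)
Class BoolEqLaws (A : Type) {E : BoolEq A} := {
  beq_refl : forall x : A, beq x x = true
}.

Instance boolBoolEq : BoolEq bool := { beq := Bool.eqb }.

Instance boolBoolEqLaws : BoolEqLaws bool.
Proof.
  (* one field to prove; both boolean cases hold by computation *)
  constructor. intros []; reflexivity.
Qed.
```

Keeping the law in its own class means an instance of the operations can be declared and used before, or even without, the corresponding proof.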