Systems and Formalisms Lab

Week 3: Techniques for Defining Recursive Functions

Author

Nate Foster, with some materials from CS6115

License

No redistribution allowed

And now for something completely different... We've seen how we can define some recursive functions in Coq, provided that they pass the termination checker. But what if we have a more complicated function? As an example, let's look at implementing merge sort for lists of natural numbers, using the standard ordering.

Definition nat_lte := Compare_dec.le_gt_dec.

First, we need to define a function to merge two (already sorted) lists. Normally, we'd write this as follows:

Fixpoint merge (xs ys:list nat) : list nat :=
  match xs with
  | nil => ys
  | x::xs' =>
      match ys with
      | nil => xs
      | y::ys' =>
          if nat_lte x y then
            x :: (merge xs' ys)
          else
            y :: (merge xs ys')
      end
  end.

Recursive definition of merge is ill-formed.
In environment
merge : list nat -> list nat -> list nat
xs : list nat
ys : list nat
x : nat
xs' : list nat
y : nat
ys' : list nat
g : x > y
Recursive call to merge has principal argument equal to "xs" instead of "xs'".

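Unfortunately, Coq rejects this definition, because it is neither the case that xs always gets smaller nor the case that ys always gets smaller: the first recursive call shrinks xs but passes ys unchanged, and the second shrinks ys but passes xs unchanged. Of course, one of them always gets smaller, so the function does terminate, but the termination checker requires a single argument that decreases structurally on every call. In this case, we can hack around the problem by re-organizing the function: the outer Fixpoint recurses structurally on xs, while a nested fix, inner_merge, recurses structurally on ys with x::xs' held fixed.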

Fixpoint merge (xs:list nat) : list nat -> list nat :=
  match xs with
  | nil => fun ys => ys
  | x::xs' =>
      (fix inner_merge (ys:list nat) : list nat :=
         match ys with
         | nil => x::xs'
         | y::ys' =>
             if nat_lte x y then
               x :: (merge xs' ys)
             else
               y :: (inner_merge ys')
         end)
  end.

We can write some examples to test that merge works as expected.

= [1; 2; 3; 4; 5; 6] : list nat
= [1; 3; 4] : list nat
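The two outputs above show merge on sample inputs. As a self-contained sketch (the inputs below are our own choice, not necessarily the ones used for those outputs), the script repeats nat_lte and merge and states the tests as Examples proved by reflexivity, so that Coq re-checks them whenever the file is compiled:

```coq
Require Import List Arith.
Import ListNotations.

(* le_gt_dec x y : {x <= y} + {x > y}; `if` takes the first branch when x <= y. *)
Definition nat_lte := Compare_dec.le_gt_dec.

(* The outer Fixpoint is structural on xs; the nested fix inner_merge is
   structural on ys, while x::xs' stays fixed. *)
Fixpoint merge (xs:list nat) : list nat -> list nat :=
  match xs with
  | nil => fun ys => ys
  | x::xs' =>
      (fix inner_merge (ys:list nat) : list nat :=
         match ys with
         | nil => x::xs'
         | y::ys' =>
             if nat_lte x y then x :: merge xs' ys   (* xs shrinks *)
             else y :: inner_merge ys'               (* ys shrinks *)
         end)
  end.

(* reflexivity succeeds only if both sides evaluate to the same list. *)
Example merge_interleaved : merge [1;3;5] [2;4;6] = [1;2;3;4;5;6].
Proof. reflexivity. Qed.

Example merge_empty_left : merge [] [1;2] = [1;2].
Proof. reflexivity. Qed.
```

Because reflexivity evaluates both sides, a regression in merge makes the file fail to compile instead of just printing a different value.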

Now let's write a pair of helper functions: one that takes a list of lists of naturals and merges each pair of consecutive elements, and another that transforms a list into a list of singleton lists.

Fixpoint merge_pairs (xs:list (list nat)) : list (list nat) :=
  match xs with
  | h1::h2::t => (merge h1 h2) :: (merge_pairs t)
  | xs' => xs'
  end.

Definition make_lists (xs:list nat) : list (list nat) :=
  List.map (fun x => x::nil) xs.

Here are some examples to illustrate what these helper functions do.

= [[5]; [1]; [4]; [2]; [3]] : list (list nat)
= [[1; 2; 3; 4; 9]; [0; 2; 3]] : list (list nat)
= [[1; 2; 3; 4; 9]; [0]] : list (list nat)
= [[1; 2; 3; 4; 5]] : list (list nat)
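To see how the helpers combine, here is a self-contained sketch (repeating merge from above; the input list is our own) that applies merge_pairs repeatedly to the singleton lists produced by make_lists. Each round merges neighboring sorted runs, so after enough rounds a single sorted list remains, which is the idea behind the merge sort defined below:

```coq
Require Import List Arith.
Import ListNotations.

Definition nat_lte := Compare_dec.le_gt_dec.

Fixpoint merge (xs:list nat) : list nat -> list nat :=
  match xs with
  | nil => fun ys => ys
  | x::xs' =>
      (fix inner_merge (ys:list nat) : list nat :=
         match ys with
         | nil => x::xs'
         | y::ys' =>
             if nat_lte x y then x :: merge xs' ys
             else y :: inner_merge ys'
         end)
  end.

(* Merge neighbors pairwise; a leftover last run is passed through. *)
Fixpoint merge_pairs (xs:list (list nat)) : list (list nat) :=
  match xs with
  | h1::h2::t => (merge h1 h2) :: (merge_pairs t)
  | xs' => xs'
  end.

(* Every one-element list is trivially sorted. *)
Definition make_lists (xs:list nat) : list (list nat) :=
  List.map (fun x => x::nil) xs.

(* Five runs, then three, then two, then one. *)
Example round1 :
  merge_pairs (make_lists [5;1;4;2;3]) = [ [1;5]; [2;4]; [3] ].
Proof. reflexivity. Qed.

Example round3 :
  merge_pairs (merge_pairs (merge_pairs (make_lists [5;1;4;2;3])))
  = [ [1;2;3;4;5] ].
Proof. reflexivity. Qed.
```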

Next we will prove an important lemma that relates the length of the list computed by merge_pairs to the length of its argument. Note that we prove a conjunction to get a strong enough induction hypothesis. (As a fun exercise, try to prove just the first conjunct by induction and see where you get stuck.)


The lemma states:

forall xs : list (list nat),
  (forall h1 h2 : list nat,
     Datatypes.length (merge_pairs (h1 :: h2 :: xs)) < Datatypes.length (h1 :: h2 :: xs)) /\
  (forall h : list nat,
     Datatypes.length (merge_pairs (h :: xs)) <= Datatypes.length (h :: xs))

The proof goes by induction on xs. In the base case, simplification reduces the two conjuncts to 1 < 2 and 1 <= 1, and split; lia closes both. In the inductive case for x :: xs, we name the two halves of the induction hypothesis H1 and H2 and split the goal. For the first conjunct, specializing H2 to x and simplifying leaves a linear inequality that lia solves. For the second conjunct, after introducing the new head x', specializing H1 to x' and x gives a strict bound on exactly the quantity we need, so lia finishes the proof.

Next, we can prove a corollary that captures the case we will need to define merge sort.


The corollary, merge_pairs_length, states:

forall (h1 h2 : list nat) (xs : list (list nat)),
  Datatypes.length (merge_pairs (h1 :: h2 :: xs)) < Datatypes.length (h1 :: h2 :: xs)

After introducing h1, h2, and xs, we instantiate the previous lemma at xs, keep its first conjunct as H, and finish with apply H.
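The full proof scripts for the lemma and the corollary are not reproduced in these notes. One reconstruction consistent with the steps described above is sketched below. It is self-contained (it repeats merge and merge_pairs), writes length for Datatypes.length because String is not imported here, and uses merge_pairs_length' as a hypothetical name for the conjunctive lemma, whose original name does not appear in these notes:

```coq
Require Import List Arith Lia.
Import ListNotations.

Definition nat_lte := Compare_dec.le_gt_dec.

Fixpoint merge (xs:list nat) : list nat -> list nat :=
  match xs with
  | nil => fun ys => ys
  | x::xs' =>
      (fix inner_merge (ys:list nat) : list nat :=
         match ys with
         | nil => x::xs'
         | y::ys' =>
             if nat_lte x y then x :: merge xs' ys
             else y :: inner_merge ys'
         end)
  end.

Fixpoint merge_pairs (xs:list (list nat)) : list (list nat) :=
  match xs with
  | h1::h2::t => merge h1 h2 :: merge_pairs t
  | xs' => xs'
  end.

(* merge_pairs_length' is a hypothetical name. Proving both conjuncts at once
   gives an induction hypothesis strong enough for either shape of input. *)
Lemma merge_pairs_length' : forall xs : list (list nat),
  (forall h1 h2 : list nat,
     length (merge_pairs (h1 :: h2 :: xs)) < length (h1 :: h2 :: xs)) /\
  (forall h : list nat,
     length (merge_pairs (h :: xs)) <= length (h :: xs)).
Proof.
  induction xs as [| x xs [H1 H2]].
  - (* Base case: the conjuncts simplify to 1 < 2 and 1 <= 1. *)
    simpl. split; intros; lia.
  - split.
    + (* h1 :: h2 :: x :: xs: use the <= half of the IH at x. *)
      intros h1 h2. specialize (H2 x). simpl in *. lia.
    + (* h :: x :: xs: this is the < half of the IH at h and x. *)
      intros h. specialize (H1 h x). lia.
Qed.

(* The corollary used as the termination argument for merge_iter. *)
Lemma merge_pairs_length : forall (h1 h2 : list nat) (xs : list (list nat)),
  length (merge_pairs (h1 :: h2 :: xs)) < length (h1 :: h2 :: xs).
Proof.
  intros h1 h2 xs. destruct (merge_pairs_length' xs) as [H _]. apply H.
Qed.
```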

Recursive Functions with Explicit Termination Measures

Now we are ready to define our merge sort. We'll do it by iterating merge_pairs until we have a singleton list. Since this iteration function is not structurally recursive but is instead defined using a measure, we first need to import a couple of libraries.

Require Import Program.
Require Import Program.Wf.
Require Import Recdef.

The details of how the constructs in these modules work are not important for now. The main thing to know is that the Function construct is similar to Fixpoint, but we must give a {measure ...} clause to tell Coq which quantity decreases, and then prove that it actually does. In this case, the length of the argument list decreases with every recursive call.


Function merge_iter (xs:list (list nat)) {measure Datatypes.length xs} : list nat :=
  match xs with
  | nil => nil
  | x1::nil => x1
  | x1::x2::xs' => merge_iter (merge_pairs (x1::x2::xs'))
  end.

Function then asks us to prove that the measure decreases at the recursive call:

forall (xs : list (list nat)) (x1 : list nat) (l : list (list nat)) (x2 : list nat) (xs' : list (list nat)),
  l = x2 :: xs' ->
  xs = x1 :: x2 :: xs' ->
  Datatypes.length (merge_pairs (x1 :: x2 :: xs')) < Datatypes.length (x1 :: x2 :: xs')

After introducing the hypotheses, we discharge it with the corollary:

  intros.
  apply merge_pairs_length.
Defined.

Note that we had to use merge_pairs_length before Coq was willing to accept our recursive definition. It is also interesting to examine the objects that are generated by this definition.

merge_iter =
fun xs : list (list nat) => let (a, _) := merge_iter_terminate xs in a
     : list (list nat) -> list nat

Arguments merge_iter xs%list_scope
merge_iter_terminate = (fun h : forall (xs : list (list nat)) (x1 : list nat) (l : list (list nat)) (x2 : list nat) (xs' : list (list nat)), l = x2 :: xs' -> xs = x1 :: x2 :: xs' -> Datatypes.length (merge_pairs (x1 :: x2 :: xs')) < Datatypes.length (x1 :: x2 :: xs') => let H : forall (xs : list (list nat)) (x1 : list nat) (l : list (list nat)) (x2 : list nat) (xs' : list (list nat)), l = x2 :: xs' -> xs = x1 :: x2 :: xs' -> Datatypes.length (merge_pairs (x1 :: x2 :: xs')) < Datatypes.length (x1 :: x2 :: xs') := h in (fun (_ : forall (xs : list (list nat)) (x1 : list nat) (l : list (list nat)) (x2 : list nat) (xs' : list (list nat)), l = x2 :: xs' -> xs = x1 :: x2 :: xs' -> Datatypes.length (merge_pairs (x1 :: x2 :: xs')) < Datatypes.length (x1 :: x2 :: xs')) (xs : list (list nat)) => let Acc_xs : Acc (ltof (list (list nat)) (fun xs0 : list (list nat) => Datatypes.length xs0)) xs := let wf_R : well_founded (ltof (list (list nat)) (fun xs0 : list (list nat) => Datatypes.length xs0)) := well_founded_ltof (list (list nat)) (fun xs0 : list (list nat) => Datatypes.length xs0) in wf_R xs in (fix hrec (xs0 : list (list nat)) (Acc_xs0 : Acc (ltof (list (list nat)) (fun xs1 : list (list nat) => Datatypes.length xs1)) xs0) {struct Acc_xs0} : {v : list nat | exists p : nat, forall k : nat, p < k -> forall def : list (list nat) -> list nat, iter (list (list nat) -> list nat) k merge_iter_F def xs0 = v} := match xs0 as l return (xs0 = l -> {v : list nat | exists p : nat, forall k : nat, p < k -> forall def : ... -> ..., iter (...) k merge_iter_F def l = v}) with | [] => fun _ : xs0 = [] => exist (fun v : list nat => exists p : nat, forall k : nat, p < k -> forall def : list ... -> list nat, iter (... -> ...) k merge_iter_F def [] = v) [] (ex_intro (fun p : nat => forall k : nat, p < k -> forall def : ... -> ..., iter (...) k merge_iter_F def [] = []) 1 (fun k : nat => match k as n return (... -> ...) with | 0 => fun h0 : 1 < 0 => False_ind (..., ...) 
(Nat.nlt_0_r 1 h0) | S n => (fun ... ... ... => ... : ...) n end)) | l :: l0 => (fun (x1 : list nat) (l1 : list (list nat)) (teq : xs0 = x1 :: l1) => match l1 as l2 return (l1 = l2 -> xs0 = ... -> {v : ... | ...}) with | [] => fun (_ : l1 = []) (_ : xs0 = [x1]) => exist (fun v : ... => exists p : nat, ..., ...) x1 (ex_intro (... => ...) 1 (... => ...)) | l2 :: l3 => (fun (x2 : list nat) (xs' : list ...) (teq0 : l1 = ...) (teq1 : xs0 = ...) => sig_rec (... => ...) (... => ...) (hrec ... ...)) l2 l3 end eq_refl teq) l l0 end eq_refl) xs Acc_xs) H) merge_iter_tcc
     : forall xs : list (list nat), {v : list nat | exists p : nat, forall k : nat, p < k -> forall def : list (list nat) -> list nat, iter (list (list nat) -> list nat) k merge_iter_F def xs = v}

Arguments merge_iter_terminate xs%list_scope

With these pieces in hand, we can finally define our merge_sort function.

Definition merge_sort (xs:list nat) :=
  merge_iter (make_lists xs).

Let's test that it is working as expected:

= [1; 2; 3; 4; 5; 6; 7; 8] : list nat
= [2; 3; 7; 8] : list nat
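Putting the pieces together, the following self-contained sketch assembles the whole merge sort. The definition of merge_iter is our reading of the generated merge_iter_terminate term printed above (three cases: empty input, a single run, and two or more runs), and the termination lemma is restated compactly, with its conjunctive helper proved inside an assert. The termination proof ends with Defined rather than Qed so that merge_iter stays transparent and merge_sort can still be evaluated:

```coq
Require Import List Arith Lia Recdef.
Import ListNotations.

Definition nat_lte := Compare_dec.le_gt_dec.

Fixpoint merge (xs:list nat) : list nat -> list nat :=
  match xs with
  | nil => fun ys => ys
  | x::xs' =>
      (fix inner_merge (ys:list nat) : list nat :=
         match ys with
         | nil => x::xs'
         | y::ys' =>
             if nat_lte x y then x :: merge xs' ys
             else y :: inner_merge ys'
         end)
  end.

Fixpoint merge_pairs (xs:list (list nat)) : list (list nat) :=
  match xs with
  | h1::h2::t => merge h1 h2 :: merge_pairs t
  | xs' => xs'
  end.

Definition make_lists (xs:list nat) : list (list nat) :=
  List.map (fun x => x::nil) xs.

(* Termination argument: merging pairs strictly shortens a list of two or
   more runs. The conjunctive helper is proved inline with assert. *)
Lemma merge_pairs_length : forall (h1 h2 : list nat) (xs : list (list nat)),
  length (merge_pairs (h1 :: h2 :: xs)) < length (h1 :: h2 :: xs).
Proof.
  assert (H : forall xs : list (list nat),
    (forall h1 h2 : list nat,
       length (merge_pairs (h1 :: h2 :: xs)) < length (h1 :: h2 :: xs)) /\
    (forall h : list nat,
       length (merge_pairs (h :: xs)) <= length (h :: xs))).
  { induction xs as [| x xs [H1 H2]].
    - simpl. split; intros; lia.
    - split.
      + intros h1 h2. specialize (H2 x). simpl in *. lia.
      + intros h. specialize (H1 h x). lia. }
  intros h1 h2 xs. apply (proj1 (H xs)).
Qed.

(* Measure: the number of runs. Only the last branch recurses, so Function
   generates a single obligation, discharged by merge_pairs_length. *)
Function merge_iter (xs:list (list nat)) {measure length xs} : list nat :=
  match xs with
  | nil => nil
  | x1::nil => x1
  | x1::x2::xs' => merge_iter (merge_pairs (x1::x2::xs'))
  end.
Proof.
  intros. apply merge_pairs_length.
Defined. (* Defined, not Qed, so that merge_iter can be evaluated. *)

Definition merge_sort (xs:list nat) : list nat :=
  merge_iter (make_lists xs).

Example sort_example : merge_sort [5;1;4;2;3] = [1;2;3;4;5].
Proof. reflexivity. Qed.
```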

Recursive Functions with Fuel (And Type Classes!)

Another approach to dealing with tricky general recursive functions is to approximate them using "fuel." To illustrate this approach, we'll introduce Coq's type classes, which are similar to Haskell's, and develop some constructs for pretty-printing various data types.

A type class defines an interface for some type. In this case, a type A implements the Show interface by providing a method named show that converts values of type A to strings.

Class Show (A:Type) := {
  show : A -> string
}.

Here is an instance for booleans.

Instance boolShow : Show bool := {
    show := fun (b:bool) => if b then "true" else "false"
  }.

= "true" : string
= "false" : string

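Note that we cannot yet use this for natural numbers: there is no Show nat instance, so Coq cannot resolve the instance argument and leaves it as an unsolved existential variable: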

= ?s : Show nat where ?s : [ |- Show nat]
= (let (show) := ?Show in show) 3 : string where ?Show : [ |- Show nat]

To define a Show instance for natural numbers, let's first define a helper that shows a single digit.

Definition digit2string(d:nat) : string :=
  match d with
  | 0 => "0" | 1 => "1" | 2 => "2" | 3 => "3"
  | 4 => "4" | 5 => "5" | 6 => "6" | 7 => "7"
  | 8 => "8" | _ => "9"
  end.

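Alas, it is difficult to convince Coq that the iterated version of this helper, which repeatedly emits the last digit and divides n by 10, terminates: n / 10 is not a structural subterm of n, so Fixpoint rejects the direct recursion. Instead, we can simply give the function fuel, an extra argument that decreases structurally on every call and stops the recursion when it reaches 0.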

Fixpoint digits'(fuel n:nat) (accum : string) : string :=
  match fuel with
  | 0 => accum
  | S fuel' =>
      match n with
      | 0 => accum
      | _ => let d := digit2string(n mod 10) in
             digits' fuel' (n / 10) (d ++ accum)
      end
  end.

Fortunately, it's sufficient to use n as the fuel for itself: each step divides n by 10, and n has at most n decimal digits, so we never need more than n steps. We could of course use the base-10 logarithm of n instead, but there's no benefit to being this precise, so we'll just use n.

Definition digits (n:nat) : string :=
  match digits' n n "" with
  | "" => "0"
  | ds => ds
  end.
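Here is a self-contained version of the fuel-based printer for experimentation. It follows the definitions above, except that it inlines the let, adds comments, and matches on the EmptyString constructor rather than the "" pattern (the two denote the same string). The last Example shows what happens when the fuel runs out too early: the recursion stops and the partial result is returned without any error, which is why the fuel must be a safe upper bound:

```coq
Require Import String Arith.
Open Scope string_scope.

Definition digit2string (d:nat) : string :=
  match d with
  | 0 => "0" | 1 => "1" | 2 => "2" | 3 => "3"
  | 4 => "4" | 5 => "5" | 6 => "6" | 7 => "7"
  | 8 => "8" | _ => "9"
  end.

(* Structural recursion on fuel; n itself only shrinks via n / 10. *)
Fixpoint digits' (fuel n:nat) (accum:string) : string :=
  match fuel with
  | 0 => accum                 (* out of fuel: return what we have so far *)
  | S fuel' =>
      match n with
      | 0 => accum             (* all digits consumed *)
      | _ => digits' fuel' (n / 10) (digit2string (n mod 10) ++ accum)
      end
  end.

(* n is enough fuel for n; the only special case is n = 0. *)
Definition digits (n:nat) : string :=
  match digits' n n "" with
  | EmptyString => "0"
  | ds => ds
  end.

Example digits_907 : digits 907 = "907".
Proof. reflexivity. Qed.

(* With too little fuel the result is silently truncated. *)
Example short_fuel : digits' 1 123 "" = "3".
Proof. reflexivity. Qed.
```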

Here is another version written using Function and with an explicit termination measure.


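Function generates the following termination obligation, stating that the measure decreases at the recursive call:

forall n : nat, string -> forall n0 : nat, n = S n0 -> S n0 / 10 < S n0

After introducing the hypotheses (n, accum, n0, and teq : n = S n0), Nat.div_lt reduces the goal to 0 < S n0 and 1 < 10, which lia discharges:

  intros.
  apply Nat.div_lt; lia.
Qed.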
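The Function definition itself does not appear in these notes. Below is a sketch that yields a termination obligation of the shape shown above. The name digits_fn is hypothetical, the measure (fun m => m) n (that is, n itself) is our assumption, and unlike the notes the proof ends with Defined so that the function can be evaluated:

```coq
Require Import String Arith Lia Recdef.
Open Scope string_scope.

Definition digit2string (d:nat) : string :=
  match d with
  | 0 => "0" | 1 => "1" | 2 => "2" | 3 => "3"
  | 4 => "4" | 5 => "5" | 6 => "6" | 7 => "7"
  | 8 => "8" | _ => "9"
  end.

(* digits_fn is a hypothetical name; the measure is n itself. No fuel is
   needed, because Function asks us to prove that n / 10 < n for n <> 0. *)
Function digits_fn (n:nat) (accum:string) {measure (fun m => m) n} : string :=
  match n with
  | 0 => accum
  | _ => digits_fn (n / 10) (digit2string (n mod 10) ++ accum)
  end.
Proof.
  (* The obligation: S n0 / 10 < S n0, given teq : n = S n0. *)
  intros. apply Nat.div_lt; lia.
Defined.

Example digits_fn_907 : digits_fn 907 "" = "907".
Proof. reflexivity. Qed.
```

Like digits', this version returns the empty string for 0, so it would still need the same "0" special case as digits.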

Now we can define the Show instance for nat.

Instance natShow : Show nat := {
  show := digits
}.

= "42" : string
= "12" : string

Importantly, because we used a type class, we can still show booleans.

= "true" : string

We can also define parameterized Show instances, which allow us to show data structures.

Instance pairShow (A B:Type) (showA : Show A) (showB : Show B)
  : Show (A*B) := {
    show := (fun p => "(" ++
                        (show (fst p)) ++ "," ++
                        (show (snd p)) ++
                        ")")
  }.

= "(3,4)" : string
= "(true,42)" : string
pairShow =
fun (A B : Type) (showA : Show A) (showB : Show B) =>
{| show := fun p : A * B => "(" ++ show (fst p) ++ "," ++ show (snd p) ++ ")" |}
     : forall A B : Type, Show A -> Show B -> Show (A * B)

Arguments pairShow (A B)%type_scope showA showB
Similarly, we can define a Show instance for lists, using a helper show_list that prints the elements separated by commas:

Definition show_list {A:Type} {showA:Show A} (xs:list A) : string :=
  ("[" ++
     ((fix loop (xs:list A) : string :=
         match xs with
         | nil => "]"
         | h::nil => show h ++ "]"
         | h::t => show h ++ "," ++ loop t
         end) xs)).

Instance listShow (A:Type) (showA:Show A) : Show (list A) := {
  show := @show_list A showA
}.
= "[1,2,3]" : string
= "[true,false]" : string
= "[(1,true),(42,false)]" : string
= "[[1,2,3],[4,5],[6]]" : string
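Parameterized instances compose: resolving Show (list (option bool)) chains listShow, an option instance, and boolShow. The option instance below, optionShow, is our own hypothetical addition for illustration; the rest repeats the class and instances above so that the script is self-contained:

```coq
Require Import String List.
Import ListNotations.
Open Scope string_scope.

Class Show (A:Type) := { show : A -> string }.

#[export] Instance boolShow : Show bool :=
  { show := fun (b:bool) => if b then "true" else "false" }.

(* Hypothetical instance (not in the notes): needs a Show instance for A. *)
#[export] Instance optionShow (A:Type) (showA:Show A) : Show (option A) :=
  { show := fun (o:option A) =>
              match o with
              | None => "None"
              | Some a => "Some " ++ show a
              end }.

Definition show_list {A:Type} {showA:Show A} (xs:list A) : string :=
  "[" ++ (fix loop (xs:list A) : string :=
            match xs with
            | nil => "]"
            | h::nil => show h ++ "]"
            | h::t => show h ++ "," ++ loop t
            end) xs.

#[export] Instance listShow (A:Type) (showA:Show A) : Show (list A) :=
  { show := @show_list A showA }.

(* Resolution chains listShow, then optionShow, then boolShow. *)
Example show_nested : show [Some false] = "[Some false]".
Proof. reflexivity. Qed.
```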

Type classes are a powerful abstraction tool that fits very well with some proof developments. If you decide to use them, one tip is to make type classes as small and simple as possible. For example, you probably want to separate the operations you want to use from the properties that the implementation is expected to satisfy.
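As a small sketch of that tip, the classes below keep a boolean equality operation (BEq) separate from its correctness law (BEqLaws). All names here (BEq, beq, BEqLaws, bool_beq, beq_refl) are hypothetical, not from the notes or the standard library. Code that only computes depends on BEq alone, while generic lemmas such as beq_refl ask for the law class only where they need it:

```coq
(* Operations: a hypothetical class with just a boolean equality test. *)
Class BEq (A:Type) := { beq : A -> A -> bool }.

(* Laws: a separate class, parameterized by an existing BEq instance. *)
Class BEqLaws (A:Type) `{BEq A} :=
  { beq_true_iff : forall x y : A, beq x y = true <-> x = y }.

Definition bool_beq (x y : bool) : bool := if x then y else negb y.

#[export] Instance boolBEq : BEq bool := { beq := bool_beq }.

Lemma bool_beq_true_iff : forall x y : bool, bool_beq x y = true <-> x = y.
Proof. intros [] []; simpl; split; intro H; congruence. Qed.

#[export] Instance boolBEqLaws : BEqLaws bool :=
  { beq_true_iff := bool_beq_true_iff }.

(* A generic lemma that relies on the laws, not on any particular type. *)
Lemma beq_refl {A:Type} {e:BEq A} {l:BEqLaws A} : forall x : A, beq x x = true.
Proof. intros x. apply (proj2 (beq_true_iff x x)). reflexivity. Qed.
```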