Implementing vector addition in Coq

Implementing vector addition in some of the dependently typed languages (such as Idris) is fairly straightforward. As per the example on Wikipedia:
import Data.Vect
%default total
pairAdd : Num a => Vect n a -> Vect n a -> Vect n a
pairAdd Nil Nil = Nil
pairAdd (x :: xs) (y :: ys) = x + y :: pairAdd xs ys
(Note how Idris' totality checker automatically infers that addition of Nil and non-Nil vectors is a logical impossibility.)
I am trying to implement the equivalent functionality in Coq, using a custom vector implementation, albeit very similar to the one provided in the official Coq libraries:
Set Implicit Arguments.
Inductive vector (X : Type) : nat -> Type :=
| vnul : vector X 0
| vcons {n : nat} (h : X) (v : vector X n) : vector X (S n).
Arguments vnul [X].
Fixpoint vpadd {n : nat} (v1 v2 : vector nat n) : vector nat n :=
  match v1 with
  | vnul => vnul
  | vcons _ x1 v1' =>
    match v2 with
    | vnul => False_rect _ _
    | vcons _ x2 v2' => vcons (x1 + x2) (vpadd v1' v2')
    end
  end.
When Coq attempts to check vpadd, it yields the following error:
Error:
In environment
vpadd : forall n : nat, vector nat n -> vector nat n -> vector nat n
[... other types]
n0 : nat
v1' : vector nat n0
n1 : nat
v2' : vector nat n1
The term "v2'" has type "vector nat n1" while it is expected to have type "vector nat n0".
Note that I use False_rect to mark the impossible case; otherwise the totality check would not pass. However, for some reason the type checker does not manage to unify n0 with n1.
What am I doing wrong?

It's not possible to implement this function so easily in plain Coq: you need to rewrite your function using the convoy pattern. There was a similar question posted a while ago about this. The idea is that you need to make your match return a function in order to propagate the relation between the indices:
Set Implicit Arguments.
Inductive vector (X : Type) : nat -> Type :=
| vnul : vector X 0
| vcons {n : nat} (h : X) (v : vector X n) : vector X (S n).
Arguments vnul [X].
Definition vhd (X : Type) n (v : vector X (S n)) : X :=
  match v with
  | vcons _ h _ => h
  end.
Definition vtl (X : Type) n (v : vector X (S n)) : vector X n :=
  match v with
  | vcons _ _ tl => tl
  end.
Fixpoint vpadd {n : nat} (v1 v2 : vector nat n) : vector nat n :=
  match v1 in vector _ n return vector nat n -> vector nat n with
  | vnul => fun _ => vnul
  | vcons _ x1 v1' => fun v2 => vcons (x1 + vhd v2) (vpadd v1' (vtl v2))
  end v2.
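For comparison, here is a minimal sketch of the same convoy-pattern shape written against the standard library's Vector.t instead of the custom vector type. The name vpadd_std is made up for this illustration; Vector.hd and Vector.tl play the roles of vhd and vtl above.

```coq
Require Import Coq.Vectors.Vector.
Import VectorNotations.

(* The match returns a function, so the second vector is only consumed
   once its length index has been refined to 0 or to (S _). *)
Fixpoint vpadd_std {n : nat} (v1 : Vector.t nat n) {struct v1}
  : Vector.t nat n -> Vector.t nat n :=
  match v1 in Vector.t _ m return Vector.t nat m -> Vector.t nat m with
  | [] => fun _ => []
  | x1 :: v1' =>
    fun v2 => (x1 + Vector.hd v2) :: vpadd_std v1' (Vector.tl v2)
  end.

(* Pairwise sums: the entries of the result are 11, 22 and 33. *)
Compute vpadd_std [1; 2; 3] [10; 20; 30].
```

If the goal is only pairwise addition rather than practising the pattern, the standard library's Vector.map2 already packages this kind of simultaneous traversal.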

Related

How to represent kleisli composition of substitutions in abstract trees

Context: I have been trying to implement the unification algorithm (the algorithm to find the most general unifier of two abstract syntax trees). Since a unifier is a substitution, the algorithm requires defining composition of substitutions.
To be specific, given a type treeSigma dependent on another type X, a substitution is a function of type:
X -> treeSigma X
and the function substitute takes a substitution as an input and has type
substitute: (X-> (treeSigma X))-> (treeSigma X) -> (treeSigma X)
I need to define a function to compose two substitutions:
compose_kleisli (rho1 rho2: X->(treeSigma X)) : X->(treeSigma X) := ...
such that,
forall tr: treeSigma X,
substitute (compose_kleisli rho1 rho2) tr = substitute rho1 (substitute rho2 tr).
I am fairly new to Coq and have been stuck on defining this composition.
How can I define this composition?
I tried to define it using Record like this:
Record compose {X s} (rho1 rho2: X-> treeSigma X):= mkCompose{
RHO: X-> treeSigma X;
CONDITION: forall t, substitute RHO t = substitute rho2 (substitute rho1 t)
}.
but along with this, I would need to prove the result that the composition can be defined for any two substitutions. Something like:
Theorem composeTotal: forall {X s} (rho1 rho2: X-> treeSigma s X), exists rho3,
forall t, substitute rho3 t = substitute rho2 (substitute rho1 t).
Proving this would require a construction of rho3 which circles back to the same problem of defining compose.
treeSigma is defined as:
(* Signature *)
Record sigma: Type := mkSigma {
symbol : Type;
arity : symbol -> nat
}.
Record sigmaLeaf (s:sigma): Type := mkLeaf {
cLeaf: symbol s;
condLeaf: arity s cLeaf = 0
}.
Record sigmaNode (s:sigma): Type := mkNode {
fNode: symbol s;
condNode: arity s fNode <> 0
}.
(* Sigma Algebra *)
Record sigAlg (s:sigma) (X:Type) := mkAlg {
Carrier: Type;
meaning: forall f:(sigmaNode s), (Vector.t Carrier (arity s (fNode s f))) -> Carrier;
meanLeaf: forall f:(sigmaLeaf s), Vector.t Carrier 0 -> Carrier
}.
(* Abstract tree on arbitrary signature. *)
Inductive treeSigma (s:sigma) (X:Type):=
| VAR (x:X)
| LEAF (c: sigmaLeaf s)
| NODE (f: sigmaNode s) (sub: Vector.t (treeSigma s X) (arity s (fNode s f)) ).
(* Defining abstract syntax as a sigma algebra. *)
Definition meanTreeNode {s X} (f:sigmaNode s) (sub: Vector.t (treeSigma s X) (s.(arity)
(fNode s f))): treeSigma s X:= NODE s X f sub.
Definition meanTreeLeaf {s X} (c:sigmaLeaf s) (sub: Vector.t (treeSigma s X) 0) := LEAF s X c.
Definition treeSigAlg {s X} := mkAlg s X (treeSigma s X) meanTreeNode meanTreeLeaf.
The substitution function is defined as:
Fixpoint homoSigma1 {X:Type} {s} (A: sigAlg s X) (rho: X-> (Carrier s X A))
(wft: (treeSigma s X)) {struct wft}: (Carrier s X A) :=
match wft with
| VAR _ _ x => rho x
| LEAF _ _ c => meanLeaf s X A c []
| NODE _ _ f l2 => meanNode s X A f (
(fix homoSigVec k (l2:Vector.t _ k):= match l2 with
| [] => []
| t::l2s => (homoSigma1 A rho t):: (homoSigVec (vlen _ l2s) l2s)
end)
(arity s (fNode s f)) l2)
end.
Definition substitute {X s} (rho: X-> treeSigma s X) (t: treeSigma s X) := @homoSigma1 X s treeSigAlg rho t.
To be particular, a substitution is the homomorphic extension of rho (which is a variable valuation).
Definitions like this are challenging to work with because the tree type occurs recursively inside another inductive type (here, Vector.t). Coq has trouble generating useful induction principles for such types on its own, so you need to help it a little. Here is a possible solution, for a slightly simplified setup:
Require Import Coq.Vectors.Vector.
Import VectorNotations.
Set Implicit Arguments.
Unset Strict Implicit.
Unset Printing Implicit Defensive.
Section Dev.
Variable symbol : Type.
Variable arity : symbol -> nat.
Record alg := Alg {
alg_sort :> Type;
alg_op : forall f : symbol, Vector.t alg_sort (arity f) -> alg_sort;
}.
Arguments alg_op {_} f _.
(* Turn off the automatic generation of induction principles.
This tree type does not distinguish between leaves and nodes,
since they only differ in their arity. *)
Unset Elimination Schemes.
Inductive treeSigma (X:Type) :=
| VAR (x:X)
| NODE (f: symbol) (args : Vector.t (treeSigma X) (arity f)).
Arguments NODE {X} _ _.
Set Elimination Schemes.
(* Manual definition of a custom induction principle for treeSigma.
HNODE is the inductive case for the NODE constructor; the vs argument is
saying that the induction hypothesis holds for each tree in the vector of
arguments. *)
Definition treeSigma_rect (X : Type) (T : treeSigma X -> Type)
(HVAR : forall x, T (VAR x))
(HNODE : forall f (ts : Vector.t (treeSigma X) (arity f))
(vs : Vector.fold_right (fun t V => T t * V)%type ts unit),
T (NODE f ts)) :
forall t, T t :=
fix loopTree (t : treeSigma X) : T t :=
match t with
| VAR x => HVAR x
| NODE f ts =>
let fix loopVector n (ts : Vector.t (treeSigma X) n) :
Vector.fold_right (fun t V => T t * V)%type ts unit :=
match ts with
| [] => tt
| t :: ts => (loopTree t, loopVector _ ts)
end in
HNODE f ts (loopVector (arity f) ts)
end.
Definition treeSigma_ind (X : Type) (T : treeSigma X -> Prop) :=
@treeSigma_rect X T.
Definition treeSigma_alg (X:Type) : alg := {|
alg_sort := treeSigma X;
alg_op := @NODE X;
|}.
Fixpoint homoSigma {X : Type} {Y : alg} (ρ : X -> Y) (t : treeSigma X) : Y :=
match t with
| VAR x => ρ x
| NODE f xs => alg_op f (Vector.map (homoSigma ρ) xs)
end.
Definition substitute X (ρ : X -> treeSigma X) (t : treeSigma X) : treeSigma X :=
@homoSigma X (treeSigma_alg X) ρ t.
(* You can define composition simply by applying substitution. *)
Definition compose X (ρ1 ρ2 : X -> treeSigma X) : X -> treeSigma X :=
fun x => substitute ρ1 (ρ2 x).
(* The property you are looking for follows by induction on the tree. Note
that this requires a nested induction on the vector of arguments. *)
Theorem composeP X (ρ1 ρ2 : X -> treeSigma X) t :
substitute (compose ρ1 ρ2) t = substitute ρ1 (substitute ρ2 t).
Proof.
unfold compose, substitute.
induction t as [x|f ts IH]; trivial.
simpl; f_equal.
induction ts as [|n t ts IH']; trivial.
simpl.
destruct IH as [e IH].
rewrite e.
f_equal.
now apply IH'.
Qed.
End Dev.
In order to do this you need to use the operations of the monad, typically:
Set Implicit Arguments.
Unset Strict Implicit.
Unset Printing Implicit Defensive.
Section MonadKleisli.
(* Set Universe Polymorphism. // Needed for real use cases *)
Variable (M : Type -> Type).
Variable (Ma : forall A B, (A -> B) -> M A -> M B).
Variable (η : forall A, A -> M A).
Variable (μ : forall A, M (M A) -> M A).
(* Compose: o^* *)
Definition oStar A B C (f : A -> M B) (g: B -> M C) : A -> M C :=
fun x => μ (Ma g (f x)).
(* Bind *)
Definition bind A B (x : M A) (f : A -> M B) : M B := oStar (fun _ => x) f tt.
End MonadKleisli.
Depending on how you organize your definitions, proving your desired properties will likely require functional extensionality. That is usually not a big deal, but it is something to keep in mind.
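To make the Kleisli composition concrete, here is a small sketch specialised to the option monad, where the result of μ (Ma g (f x)) can be written directly as a match. The names okleisli and spred are invented for this example and are not part of the answer above.

```coq
(* Kleisli composition for option: run f, then feed its result to g.
   A None from f short-circuits the whole composition. *)
Definition okleisli {A B C : Type}
  (f : A -> option B) (g : B -> option C) : A -> option C :=
  fun x =>
    match f x with
    | Some y => g y
    | None => None
    end.

(* A partial predecessor, used only to exercise okleisli. *)
Definition spred (n : nat) : option nat :=
  match n with
  | O => None
  | S k => Some k
  end.

Compute okleisli spred spred 5.  (* Some 3 *)
Compute okleisli spred spred 1.  (* None *)
```

The composition of substitutions in the previous answer follows the same order: apply the first map, then extend the second one over the result.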

Idris: proof about concatenation of vectors

Assume I have the following idris source code:
module Source
import Data.Vect
--in order to avoid compiler confusion between Prelude.List.(++), Prelude.String.(++) and Data.Vect.(++)
infixl 0 +++
(+++) : Vect n a -> Vect m a -> Vect (n+m) a
v +++ w = v ++ w
--NB: further down in the question I'll assume this definition isn't needed because the compiler
-- will have enough context to disambiguate between these and figure out that Data.Vect.(++)
-- is the "correct" one to use.
lemma : reverse (n :: ns) +++ (n :: ns) = reverse ns +++ (n :: n :: ns)
lemma {ns = []} = Refl
lemma {ns = n' :: ns} = ?lemma_rhs
As shown, the base case for lemma is trivially Refl. But I can't seem to find a way to prove the inductive case: the REPL just prints the following
*source> :t lemma_rhs
phTy : Type
n1 : phTy
len : Nat
ns : Vect len phTy
n : phTy
-----------------------------------------
lemma_rhs : Data.Vect.reverse, go phTy
(S (S len))
(n :: n1 :: ns)
[n1, n]
ns ++
n :: n1 :: ns =
Data.Vect.reverse, go phTy (S len) (n1 :: ns) [n1] ns ++
n :: n :: n1 :: ns
I understand that phTy stands for "phantom type", the implicit type of the vectors I'm considering. I also understand that go is the name of the function defined in the where clause for the definition of the library function reverse.
Question
How can I continue the proof? Is my inductive strategy sound? Is there a better one?
Context
This came up in one of my toy projects, where I try to define arbitrary tensors; specifically, this seems to be needed in order to define "full index contraction". I'll elaborate a little on that:
I define tensors in a way that's roughly equivalent to
data Tensor : (rank : Nat) -> (shape : Vect rank Nat) -> Type -> Type where
  Scalar : a -> Tensor Z [] a
  Vector : Vect n (Tensor rank shape a) -> Tensor (S rank) (n :: shape) a
glossing over the rest of the source code (since it isn't relevant, and it's quite long and uninteresting as of now), I was able to define the following functions
contractIndex : Num a =>
                Tensor (r1 + (2 + r2)) (s1 ++ (n :: n :: s2)) a ->
                Tensor (r1 + r2) (s1 ++ s2) a
tensorProduct : Num a =>
                Tensor r1 s1 a ->
                Tensor r2 s2 a ->
                Tensor (r1 + r2) (s1 ++ s2) a
contractProduct : Num a =>
                  Tensor (S r1) s1 a ->
                  Tensor (S r2) ((last s1) :: s2) a ->
                  Tensor (r1 + r2) ((take r1 s1) ++ s2) a
and I'm working on this other one
fullIndexContraction : Num a =>
                       Tensor r (reverse ns) a ->
                       Tensor r ns a ->
                       Tensor 0 [] a
fullIndexContraction {r = Z} {ns = []} t s = t * s
fullIndexContraction {r = S r} {ns = n :: ns} t s = ?rhs
that should "iterate contractProduct as much as possible (that is, r times)"; equivalently, it could be possible to define it as tensorProduct composed with as many contractIndex as possible (again, that amount should be r).
I'm including all this because maybe it's easier to just solve this problem without proving the lemma above: if that were the case, I'd be fully satisfied as well. I just thought the "shorter" version above might be easier to deal with, since I'm pretty sure I'll be able to figure out the missing pieces myself.
The version of Idris I'm using is 1.3.2-git:PRE (that's what the REPL says when invoked from the command line).
Edit: xash's answer covers almost everything, and I was able to write the following functions
nreverse_id : (k : Nat) -> nreverse k = k
contractAllIndices : Num a =>
                     Tensor (nreverse k + k) (reverse ns ++ ns) a ->
                     Tensor Z [] a
contractAllProduct : Num a =>
                     Tensor (nreverse k) (reverse ns) a ->
                     Tensor k ns a ->
                     Tensor Z [] a
I also wrote a "fancy" version of reverse, let's call it fancy_reverse, that automatically rewrites nreverse k = k in its result. So I tried to write a function that doesn't have nreverse in its signature, something like
fancy_reverse : Vect n a -> Vect n a
fancy_reverse {n} xs =
  rewrite sym $ nreverse_id n in
  reverse xs
contract : Num a =>
           {auto eql : fancy_reverse ns1 = ns2} ->
           Tensor k ns1 a ->
           Tensor k ns2 a ->
           Tensor Z [] a
contract {eql} {k} {ns1} {ns2} t s =
  flip contractAllProduct s $
  rewrite sym $ nreverse_id k in
  ?rhs
Now, the inferred type for rhs is Tensor (nreverse k) (reverse ns2) and I have in scope a rewrite rule for k = nreverse k, but I can't work out how to rewrite the implicit eql proof to make this type check. Am I doing something wrong?
The prelude's Data.Vect.reverse is hard to reason about because, as far as I know, its go helper function does not reduce in the type checker. The usual approach is to define a simpler reverse yourself, one that does not need rewrite at the type level. For example:
%hide Data.Vect.reverse
nreverse : Nat -> Nat
nreverse Z = Z
nreverse (S n) = nreverse n + 1
reverse : Vect n a -> Vect (nreverse n) a
reverse [] = []
reverse (x :: xs) = reverse xs ++ [x]
lemma : {xs : Vect n a} -> reverse (x :: xs) = reverse xs ++ [x]
lemma = Refl
As you can see, this definition is straightforward enough that the equivalent lemma holds without further work. Thus you can probably just match on reverse ns in fullIndexContraction, as in this example:
data Foo : Vect n Nat -> Type where
  MkFoo : (x : Vect n Nat) -> Foo x
foo : Foo a -> Foo (reverse a) -> Nat
foo (MkFoo []) (MkFoo []) = Z
foo (MkFoo $ x::xs) (MkFoo $ reverse xs ++ [x]) =
  x + foo (MkFoo xs) (MkFoo $ reverse xs)
To your comment: first, len = nreverse len must sometimes be used, but if you used rewrite at the type level (through the usual n + 1 = 1 + n shenanigans) you would have the same problem (possibly with even more complicated proofs, but this is just a guess).
vectAppendAssociative is actually enough:
lemma2 : Main.reverse (n :: ns1) ++ ns2 = Main.reverse ns1 ++ (n :: ns2)
lemma2 {n} {ns1} {ns2} = sym $ vectAppendAssociative (reverse ns1) [n] ns2

mutual recursion on an inductive type and nat

Consider this example:
Inductive T :=
| foo : T
| bar : nat -> T -> T.
Fixpoint evalT (t:T) {struct t} : nat :=
match t with
| foo => 1
| bar n x => evalBar x n
end
with evalBar (x:T) (n:nat) {struct n} : nat :=
match n with
| O => 0
| S n' => (evalT x) + (evalBar x n')
end.
Coq rejects it with an error: Recursive call to evalBar has principal argument equal to "n" instead of "x".
I understand that the termination checker gets confused by the two unrelated inductive types (T and nat). However, it looks like the function I am trying to define does terminate. How can I make Coq accept it?
Another solution is to use a nested fixpoint.
Fixpoint evalT (t:T) {struct t} : nat :=
match t with
| foo => 1
| bar n x => let fix evalBar n {struct n} :=
match n with
| 0 => 0
| S n' => Nat.add (evalT x) (evalBar n')
end
in evalBar n
end.
The important point is to remove the argument x from evalBar. Thus the recursive call to evalT is done on the x from bar n x, not the x given as an argument to evalBar, and thus the termination checker can validate the definition of evalT.
This is the same idea that makes the version with nat_rec proposed in another answer work.
One solution I found is to use nat_rec instead of evalBar:
Fixpoint evalT (t:T) {struct t} : nat :=
match t with
| foo => 1
| bar n x => @nat_rec _ 0 (fun n' t' => (evalT x) + t') n
end.
It works, but I wish I could hide nat_rec behind an evalBar definition. In my real project, this construct is used several times.
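One way to get a named evalBar back is sketched below. It assumes evalBar only ever needs the value of evalT x, not x itself: in that case evalBar can be an ordinary top-level Fixpoint on nat that receives the already computed value, and evalT stays structurally recursive on t. The parameter name v is introduced here for illustration.

```coq
Inductive T :=
| foo : T
| bar : nat -> T -> T.

(* Adds v to itself n times; the only recursion is on n. *)
Fixpoint evalBar (v n : nat) {struct n} : nat :=
  match n with
  | O => 0
  | S n' => v + evalBar v n'
  end.

(* The recursive call evalT x is on a direct subterm of t, so the
   termination checker accepts it without any mutual recursion. *)
Fixpoint evalT (t : T) {struct t} : nat :=
  match t with
  | foo => 1
  | bar n x => evalBar (evalT x) n
  end.

Compute evalT (bar 3 (bar 2 foo)).  (* 6 *)
```

This computes the same numbers as the nat_rec version, and the helper can be reused wherever the construct appears.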

Well founded recursion in Coq

I am trying to write a function for computing natural division in Coq and I am having some trouble defining it since it is not structural recursion.
My code is:
Inductive N : Set :=
| O : N
| S : N -> N.
Inductive Bool : Set :=
| True : Bool
| False : Bool.
Fixpoint sum (m :N) (n : N) : N :=
match m with
| O => n
| S x => S ( sum x n)
end.
Notation "m + n" := (sum m n) (at level 50, left associativity).
Fixpoint mult (m :N) (n : N) : N :=
match m with
| O => O
| S x => n + (mult x n)
end.
Notation "m * n" := (mult m n) (at level 40, left associativity).
Fixpoint pred (m : N) : N :=
match m with
| O => S O
| S x => x
end.
Fixpoint resta (m:N) (n:N) : N :=
match n with
| O => m
| S x => pred (resta m x)
end.
Notation "m - x" := (resta m x) (at level 50, left associativity).
Fixpoint leq_nat (m : N) (n : N) : Bool :=
match m with
| O => True
| S x => match n with
| O => False
| S y => leq_nat x y
end
end.
Notation "m <= n" := (leq_nat m n) (at level 70).
Fixpoint div (m : N) (n : N) : N :=
match n with
| O => O
| S x => match m <= n with
| False => O
| True => pred (div (m-n) n)
end
end.
As you can see, Coq does not accept my function div; it says
Error: Cannot guess decreasing argument of fix.
How can I supply Coq with a termination proof for this function? I can prove that n > O /\ n <= m -> (m - n) < m.
The simplest strategy in this case is probably to use the Program extension together with a measure. You will then have to provide a proof that the arguments used in the recursive call are smaller than the top level ones according to the measure.
Require Coq.Program.Tactics.
Require Coq.Program.Wf.
Fixpoint toNat (m : N) : nat :=
match m with O => 0 | S n => 1 + (toNat n) end.
Program Fixpoint div (m : N) (n : N) {measure (toNat m)}: N :=
match n with
| O => O
| S x => match m <= n with
| False => O
| True => pred (div (m-n) n)
end
end.
Next Obligation.
(* your proof here *)
Although gallais's answer is definitely the way to go in general, I should point out that we can define division on the natural numbers in Coq as a simple fixpoint. Here, I'm using the definition of nat in the standard library for simplicity.
Fixpoint minus (n m : nat) {struct n} : nat :=
match n, m with
| S n', S m' => minus n' m'
| _, _ => n
end.
Definition leq (n m : nat) : bool :=
match minus n m with
| O => true
| _ => false
end.
Fixpoint div (n m : nat) {struct n} : nat :=
match m with
| O => O
| S m' =>
if leq (S m') n then
match n with
| O => O (* Impossible *)
| S n' => S (div (minus n' m') (S m'))
end
else O
end.
Compute div 6 3.
Compute div 7 3.
Compute div 9 3.
The definition of minus is essentially the one from the standard library. Notice that the second branch of that definition returns n itself. Thanks to this trick, Coq's termination checker can detect that minus n' m' is structurally smaller than S n', which allows us to perform the recursive call to div.
There's actually an even simpler way of doing this, although a bit harder to understand: you can check whether the divisor is smaller and perform the recursive call in a single step.
(* Divide n by m + 1 *)
Fixpoint div'_aux n m {struct n} :=
match minus n m with
| O => O
| S n' => S (div'_aux n' m)
end.
Definition div' n m :=
match m with
| O => O (* Arbitrary *)
| S m' => div'_aux n m'
end.
Compute div' 6 3.
Compute div' 7 3.
Compute div' 9 3.
Once again, because of the form of the minus function, Coq's termination checker knows that n' in the second branch of div'_aux is a valid argument to a recursive call. Notice also that div'_aux is dividing by m + 1.
Of course, this whole thing relies on a clever trick that requires understanding the termination checker in detail. In general, you have to resort to well-founded recursion, as gallais showed.

How to get an induction principle for nested fix

I am working with a function that searches through a range of values.
Require Import List.
(* Implementation of ListTest omitted. *)
Definition ListTest (l : list nat) := false.
Definition SearchCountList n :=
(fix f i l := match i with
| 0 => ListTest (rev l)
| S i1 =>
(fix g j l1 := match j with
| 0 => false
| S j1 =>
if f i1 (j :: l1)
then true
else g j1 l1
end) (n + n) (i :: l)
end) n nil
.
I want to be able to reason about this function.
However, I can't seem to get Coq's built-in induction principle facilities to work.
Functional Scheme SearchCountList := Induction for SearchCountList Sort Prop.
Error: GRec not handled
It looks like Coq is set up for handling mutual recursion, not nested recursion. In this case, I have essentially two nested for loops.
However, translating to mutual recursion isn't so easy either:
Definition SearchCountList_Loop :=
fix outer n i l {struct i} :=
match i with
| 0 => ListTest (rev l)
| S i1 => inner n i1 (n + n) (i :: l)
end
with inner n i j l {struct j} :=
match j with
| 0 => false
| S j1 =>
if outer n i (j :: l)
then true
else inner n i j1 l
end
for outer
.
but that yields the error
Recursive call to inner has principal argument equal to
"n + n" instead of "i1".
So it looks like I would need to use a measure to get Coq to accept the definition directly. The checker is confused because I sometimes reset j. But in a nested setup that makes sense, since i has decreased and i is the outer loop.
So, is there a standard way of handling nested recursion, as opposed to mutual recursion? Are there easier ways to reason about the cases, not involving making separate induction theorems? Since I haven't found a way to generate it automatically, I guess I'm stuck with writing the induction principle directly.
There's a trick for avoiding mutual recursion in this case: you can compute f i1 inside f and pass the result to g.
Fixpoint g (f_n_i1 : list nat -> bool) (j : nat) (l1 : list nat) : bool :=
match j with
| 0 => false
| S j1 => if f_n_i1 (j :: l1) then true else g f_n_i1 j1 l1
end.
Fixpoint f (n i : nat) (l : list nat) : bool :=
match i with
| 0 => ListTest (rev l)
| S i1 => g (f n i1) (n + n) (i :: l)
end.
Definition SearchCountList (n : nat) : bool := f n n nil.
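Once the loops are split like this, each function can be reasoned about with ordinary induction on its own counter. As a small sketch, here is a lemma about g alone; the lemma name g_all_false and its statement are chosen for illustration, not taken from the question.

```coq
Require Import List.

(* The inner loop from the answer above, repeated so the block stands alone. *)
Fixpoint g (f_n_i1 : list nat -> bool) (j : nat) (l1 : list nat) : bool :=
  match j with
  | 0 => false
  | S j1 => if f_n_i1 (j :: l1) then true else g f_n_i1 j1 l1
  end.

(* If the test passed to g never succeeds, the inner loop returns false. *)
Lemma g_all_false :
  forall (f : list nat -> bool) (j : nat) (l : list nat),
    (forall l', f l' = false) -> g f j l = false.
Proof.
  intros f j l H.
  induction j as [|j1 IH]; simpl.
  - reflexivity.
  - rewrite H. exact IH.
Qed.
```

A similar induction on i, using lemmas about g as building blocks, is the natural next step for statements about f and SearchCountList.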
Are you sure simple induction wouldn't have been enough in the original code? What about well founded induction?

Resources