How can I build a list of bytes from its specification in Coq


I'm trying to build a list of bytes from a specification I have in my context. The specification is a conjunction of nth_error facts that fix the byte value at an index (or the byte values over a range of indices). For example, see the goal and context below.

a_len, b_len : nat
a : seq.seq byte
b : seq.seq byte
H : Datatypes.length a = a_len /\           (* the length of list I'd like to build *)
    Datatypes.length b = b_len /\           (* the length of another list concatenated at the end *)
    is_true (b_len + 4 <= a_len) /\         (* added after edit *)
    is_true (1 < b_len) /\                  (* added after edit *)
    nth_error a 0 = Some x00 /\                    (* the value of the first byte is zero *)
    (forall i : nat,                               (* then we have a bunch of x01's *)
     is_true (0 < i) /\ is_true (i < a_len - b_len - 1) ->
     nth_error a i = Some x01) /\
    nth_error a (a_len - b_len - 1) = Some x00 /\  (* next byte is zero *)
    (forall j : nat,                               (* which ends with a list that is equal to b *)
     is_true (0 <= j) /\ is_true (j < b_len) ->
     nth_error a (a_len - b_len + j) = nth_error b j)
______________________________________(1/1)
a = [x00] ++ repeat x01 (a_len - b_len - 2) ++ [x00] ++ b

I have tried to use an existing lemma such as nth_error_split, which is stated as:

nth_error_split :
    forall [A : Type] (l : seq.seq A) (n : nat) [a : A],
    nth_error l n = Some a ->
    exists l1 l2 : seq.seq A,
      l = (l1 ++ (a :: l2)%SEQ)%list /\ Datatypes.length l1 = n

and define some lemma like this:

Lemma two_concats_equality1: 
    forall (lb1 lb1' lb2 lb2': list byte), 
           (lb1 ++ lb2) = (lb1' ++ lb2') /\ length lb1 = length lb1' -> lb1 = lb1' /\ lb2 = lb2'.

The idea was to build the list of bytes, a, from scratch using nth_error_split, carry information along using two_concats_equality1, and repeat this until the end of the list. I've had no success so far: I couldn't even prove the two_concats_equality1 lemma (I assumed it to be true for the time being), and I got stuck right at the beginning of the byte repetition, nth_error a i = Some x01.
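For reference, the helper lemma itself is provable by induction on the first list. The following is a minimal sketch in plain Coq, assuming only the standard List library (no mathcomp); the name app_eq_same_length is hypothetical, the element type is left generic rather than fixed to byte, and the two hypotheses are curried instead of joined by a conjunction.

```coq
Require Import List.

(* If two concatenations are equal and their prefixes have the same
   length, then the prefixes and the suffixes are pairwise equal. *)
Lemma app_eq_same_length {A : Type} (l1 l1' l2 l2' : list A) :
  l1 ++ l2 = l1' ++ l2' ->
  length l1 = length l1' ->
  l1 = l1' /\ l2 = l2'.
Proof.
  (* Generalize over l1' so the induction hypothesis applies to its tail. *)
  revert l1'.
  induction l1 as [| x l1 IH]; intros [| x' l1'] Happ Hlen; simpl in *.
  - (* both prefixes empty: the suffixes are equal by Happ *)
    split; [reflexivity | assumption].
  - (* lengths 0 and S _ cannot be equal *)
    discriminate Hlen.
  - discriminate Hlen.
  - (* both prefixes non-empty: peel off the heads and recurse *)
    injection Happ as Hx Happ'.
    injection Hlen as Hlen'.
    destruct (IH l1' Happ' Hlen') as [H1 H2].
    subst. split; reflexivity.
Qed.
```

The version with a conjunction in the hypothesis follows by destructing the conjunction and applying this lemma.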

I'm wondering if this is the right approach to prove this goal. If not, please let me know what you would do.

Any comments would be highly appreciated.

Edit:

As @Yves pointed out, I'm making the following edits:

  1. Fixing the typo in the context. There is no variable n, it should be a_len.
  2. I added two more constraints in the spec that describe the relations for a_len and b_len, avoiding the false statement mentioned.
  3. Below, you can find a minimal reproducible example with the libraries imported.

From mathcomp Require Import all_ssreflect ssrnat.
From Coq Require Import Lia.
Require Import Init.Byte Coq.Lists.List.
Import ListNotations. 

Lemma build_from_spec : 
    forall (a_len b_len : nat) (a b : list byte),
    Datatypes.length a = a_len /\
    Datatypes.length b = b_len /\
    a_len >= b_len + 4 /\
    b_len >= 2 /\
    nth_error a 0 = Some x00 /\
    (forall i : nat, (0 < i) /\ (i < a_len - b_len - 1) ->
    nth_error a i = Some x01) /\
    nth_error a (a_len - b_len - 1) = Some x00 /\
    (forall j : nat, (0 <= j) /\ (j < b_len) ->
    nth_error a (a_len - b_len + j) = nth_error b j) 
    ->
    a = [x00] ++ repeat x01 (a_len - b_len - 2) ++ [x00] ++ b.
Proof.
Admitted.

Answer:

The question was edited by the original poster, changing the statement. The bottom part of this message is an answer to the original question.

To solve this problem, you need a theorem that does not (yet) exist in the library, and which I include here:

Lemma eq_from_nth_error {A : Type} (l1 l2 : list A) :
  (forall i, nth_error l1 i = nth_error l2 i) -> l1 = l2.
Proof.
elim: l1 l2 => [ | a l1 IH] [ | a' l2] //.
    by move=> abs; have := (abs 0).
  by move=> abs; have := (abs 0).
move=> cmp; congr (_ :: _).
  by have := (cmp 0) => /= [[]].
apply: IH=> i; exact (cmp i.+1).
Qed.
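For readers who do not use ssreflect, the same statement can also be proved with plain Ltac. This is a minimal sketch assuming only the standard List library; the name eq_from_nth_error_vanilla is hypothetical and is not part of any library.

```coq
Require Import List.

(* Two lists that agree on nth_error at every index are equal. *)
Lemma eq_from_nth_error_vanilla {A : Type} (l1 l2 : list A) :
  (forall i, nth_error l1 i = nth_error l2 i) -> l1 = l2.
Proof.
  revert l2.
  induction l1 as [| x l1 IH]; intros [| y l2] H.
  - reflexivity.
  - (* index 0 gives None = Some y *)
    specialize (H 0). simpl in H. discriminate H.
  - (* index 0 gives Some x = None *)
    specialize (H 0). simpl in H. discriminate H.
  - f_equal.
    + (* heads: index 0 gives Some x = Some y *)
      specialize (H 0). simpl in H. injection H as Hxy. exact Hxy.
    + (* tails: shift every index by one *)
      apply IH. intros i. exact (H (S i)).
Qed.
```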

With this theorem, you can prove the equality by studying all possible indices i that you give as argument to nth_error. So you introduce your hypothesis, apply this theorem, introduce i, and then look at 5 possible cases for i:

  1. if i = 0,
  2. if 0 < i < a_len - b_len - 1
  3. if i = a_len - b_len - 1
  4. if a_len - b_len - 1 < i < a_len
  5. if a_len <= i
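The five cases above cover every natural number i. As a sanity check, here is a minimal sketch assuming plain Coq arithmetic (no mathcomp loaded, so lia applies directly); the lemma name index_cases is hypothetical.

```coq
From Coq Require Import Arith Lia.

(* Every index i falls into at least one of the five cases;
   subtraction on nat is truncated, which lia handles. *)
Lemma index_cases (i a_len b_len : nat) :
  i = 0 \/
  (0 < i /\ i < a_len - b_len - 1) \/
  i = a_len - b_len - 1 \/
  (a_len - b_len - 1 < i /\ i < a_len) \/
  a_len <= i.
Proof. lia. Qed.
```

The disjunction holds without any assumption on a_len and b_len; the constraint b_len + 4 <= a_len is only needed later, to show that the right-hand list has the expected element in each case.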

In ssreflect style, these cases are introduced by writing

have [/eqP i_is_0 | i_is_not_0] := boolP(i == 0).

have [i_lt_border | is_larger] := boolP(i < a_len - b_len - 1).

and so on. With these case splits you will be able to complete the proof. You made your life complicated because your statement is written using the mathematical components versions of arithmetic statements: (i < a_len) is a boolean expression that is not recognized by lia, so you need to perform a lot of conversions to make things work. Here is an example of the problem.

Lemma arithmetic_difficulty i j : i + 3 < j - 2 -> i <= j.
Proof.
Fail lia.
rewrite -?(minusE, plusE).
move/ltP => i3j2.
apply/leP.
lia.
Qed.

So you see, I need to rewrite with the theorems minusE and plusE and use the reflection lemmas ltP and leP to transform the "mathematical components" definitions of +, -, < and <= into the traditional versions of these operators before lia can solve the problem. Normally, lia should be improved so that this kind of transformation will not be needed in later versions of Coq (I am using Coq 8.12 and this trickery is still needed).
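For contrast, here is a minimal sketch of the same statement in a file that does not load mathcomp, so that +, -, < and <= are the standard library operators on nat. The lemma name arithmetic_vanilla is hypothetical.

```coq
From Coq Require Import Arith Lia.

(* With the standard nat operators, no conversion step is needed. *)
Lemma arithmetic_vanilla (i j : nat) : i + 3 < j - 2 -> i <= j.
Proof. lia. Qed.
```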

A previous version of the statement was false, and prompted me to produce the following counterexample:

From mathcomp Require Import all_ssreflect.

Set Implicit Arguments.
Unset Strict Implicit.
Unset Printing Implicit Defensive.

Section some_context.

Definition byte := nat.
Variables x00 x01 : byte.

Lemma Dan_problem : (forall (a_len b_len : nat)
(a b : seq byte)
(H : Datatypes.length a = a_len /\
    Datatypes.length b = b_len /\
    List.nth_error a 0 = Some x00 /\
    (forall i : nat,
    is_true (0 < i) /\ is_true (i < a_len - b_len - 1) ->
      List.nth_error a i = Some x01) /\
    List.nth_error a (a_len - b_len - 1) = Some x00 /\
    (forall j : nat,
      is_true (0 <= j) /\ is_true (j < b_len) ->
      List.nth_error a (a_len - b_len + j) = List.nth_error b j)),
a = [:: x00] ++ List.repeat x01 (a_len - b_len - 2) ++ [:: x00] ++ b) ->
   False.
Proof.
intros abs.
assert (cnd1 : forall i,  0 < i /\ i < 1 - 0 -1 ->
  List.nth_error [:: x00] i = Some x01).
  by move=> i [igt0 ilt0].
assert (cnd2 : forall j : nat,
          0 <= j /\ j < 0 ->
         List.nth_error [:: x00] (1 - 0 + j) = List.nth_error [::] j).
  by move=> j [jge0 glt0].
generalize (abs 1 0 (x00::nil) nil (conj erefl
    (conj erefl (conj erefl (conj cnd1 (conj erefl cnd2)))))).
by rewrite /=.
Qed.

This script mixes mathematical components style and vanilla Coq style, which is bad taste.