Skip to content

Commit

Permalink
Merge pull request #424 from hacspec/core-improvements
Browse files Browse the repository at this point in the history
F*: core lib: misc. improvements
  • Loading branch information
W95Psp authored Jan 9, 2024
2 parents d1acf32 + cf26e24 commit baef406
Show file tree
Hide file tree
Showing 15 changed files with 577 additions and 98 deletions.
2 changes: 1 addition & 1 deletion examples/chacha20/proofs/fstar/extraction/Chacha20.fst
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ let t_State = t_Array u32 (sz 16)
let chacha20_line (a b d: usize) (s: u32) (m: t_Array u32 (sz 16))
: Prims.Pure (t_Array u32 (sz 16))
(requires a <. sz 16 && b <. sz 16 && d <. sz 16)
(fun _ -> Prims.l_True) =
(fun (res:t_Array u32 (sz 16)) -> True ) =
let state:t_Array u32 (sz 16) = m in
let state:t_Array u32 (sz 16) =
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize state
Expand Down
2 changes: 1 addition & 1 deletion hax-lib/proofs/fstar/extraction/Hax_lib.fst
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ let v_assume (v__formula: bool) = assume v__formula

unfold let v_exists (v__f: 'a -> Type0): Type0 = exists (x: 'a). v__f x
unfold let v_forall (v__f: 'a -> Type0): Type0 = forall (x: 'a). v__f x
unfold let implies (lhs: bool) (rhs: unit -> bool): bool = (not lhs) || rhs ()
unfold let implies (lhs: bool) (rhs: (x:unit{lhs} -> bool)): bool = (not lhs) || rhs ()
2 changes: 1 addition & 1 deletion proof-libs/fstar/core/Core.Num.fsti
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ open Rust_primitives
let impl__u8__wrapping_sub: u8 -> u8 -> u8 = sub_mod
let impl__u16__wrapping_add: u16 -> u16 -> u16 = add_mod
let impl__i32__wrapping_add: i32 -> i32 -> i32 = add_mod
val impl__i32__abs: i32 -> i32
let impl__i32__abs (a:i32{minint i32_inttype < v a}) : i32 = abs_int a

let impl__u32__wrapping_add: u32 -> u32 -> u32 = add_mod
val impl__u32__rotate_left: u32 -> u32 -> u32
Expand Down
11 changes: 9 additions & 2 deletions proof-libs/fstar/core/Core.Slice.fsti
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ val impl__chunks (x: t_Slice 'a) (cs: usize): t_Chunks 'a

let impl__iter (s: t_Slice 't): t_Slice 't = s

val impl__chunks_exact (x: t_Slice 'a) (cs: usize): t_Slice (s: t_Slice 'a {length s == cs})
val impl__chunks_exact (x: t_Slice 'a) (cs: usize):
Pure (t_Slice (t_Slice 'a))
(requires True)
(ensures (fun r -> forall i. i < v (length x) ==> length x == cs))

open Core.Ops.Index

Expand All @@ -24,4 +27,8 @@ instance impl__index t n: t_Index (t_Slice t) (int_t n)

let impl__copy_from_slice #t (x: t_Slice t) (y:t_Slice t) : t_Slice t = y

val impl__split_at #t (s: t_Slice t) (mid: usize): (t_Slice t * t_Slice t)
val impl__split_at #t (s: t_Slice t) (mid: usize): Pure (t_Slice t * t_Slice t)
(requires (v mid <= Seq.length s))
(ensures (fun (x,y) -> Seq.length x == v mid /\ Seq.length y == Seq.length s - v mid /\
x == Seq.slice s 0 (v mid) /\ y == Seq.slice s (v mid) (Seq.length s) /\
s == Seq.append x y))
13 changes: 7 additions & 6 deletions proof-libs/fstar/rust_primitives/Rust_primitives.Arrays.fst
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,20 @@ module Rust_primitives.Arrays

open Rust_primitives.Integers

type t_Slice t = s:Seq.seq t{Seq.length s <= max_usize}
type t_Array t (l:usize) = s: Seq.seq t { Seq.length s == v l }

let of_list (#t:Type) (l: list t {FStar.List.Tot.length l < maxint Lib.IntTypes.U16}): t_Slice t = Seq.seq_of_list l
let to_list (#t:Type) (s: t_Slice t): list t = Seq.seq_to_list s

let to_of_list_lemma t l = Seq.lemma_list_seq_bij l
let of_to_list_lemma t l = Seq.lemma_seq_list_bij l

let length (s: t_Slice 'a): usize = sz (Seq.length s)
let contains (#t: eqtype) (s: t_Slice t) (x: t): bool = Seq.mem x s

let map_array #n (arr: t_Array 'a n) (f: 'a -> 'b): t_Array 'b n
= FStar.Seq.map_seq_len f arr;
FStar.Seq.map_seq f arr

let createi #t l f = admit() // see issue #423

let lemma_index_concat x y i = admit() // see issue #423

let lemma_index_slice x y i = admit() // see issue #423

let eq_intro a b = admit() // see issue #423
104 changes: 104 additions & 0 deletions proof-libs/fstar/rust_primitives/Rust_primitives.Arrays.fsti
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
module Rust_primitives.Arrays

open Rust_primitives.Integers

/// Rust slices and arrays are represented as sequences
type t_Slice t = s:Seq.seq t{Seq.length s <= max_usize}
type t_Array t (l:usize) = s: Seq.seq t { Seq.length s == v l }

/// Length of a slice
let length (s: t_Slice 'a): usize = sz (Seq.length s)

/// Check whether a slice contains an item
let contains (#t: eqtype) (s: t_Slice t) (x: t): bool = Seq.mem x s

/// Converts an F* list into an array
val of_list (#t:Type) (l: list t {FStar.List.Tot.length l < maxint Lib.IntTypes.U16}):
t_Array t (sz (FStar.List.Tot.length l))
/// Converts an slice into a F* list
val to_list (#t:Type) (s: t_Slice t): list t

val map_array #n (arr: t_Array 'a n) (f: 'a -> 'b): t_Array 'b n

/// Creates an array of size `l` using a function `f`
val createi #t (l:usize) (f:(u:usize{u <. l} -> t))
: Pure (t_Array t l)
(requires True)
(ensures (fun res -> (forall i. Seq.index res (v i) == f i)))

unfold let map #p
(f:(x:'a{p x} -> 'b))
(s: t_Slice 'a {forall (i:nat). i < Seq.length s ==> p (Seq.index s i)}): t_Slice 'b
= createi (length s) (fun i -> f (Seq.index s (v i)))

/// Concatenates two slices
let concat #t (x:t_Slice t) (y:t_Slice t{range (v (length x) + v (length y)) usize_inttype}) :
r:t_Array t (length x +! length y) = Seq.append x y

/// Translate indexes of `concat x y` into indexes of `x` or of `y`
val lemma_index_concat #t (x:t_Slice t) (y:t_Slice t{range (v (length x) + v (length y)) usize_inttype}) (i:usize{i <. length x +! length y}):
Lemma (if i <. length x then
Seq.index (concat x y) (v i) == Seq.index x (v i)
else
Seq.index (concat x y) (v i) == Seq.index y (v (i -! length x)))
[SMTPat (Seq.index (concat #t x y) i)]

/// Take a subslice given `x` a slice and `i` and `j` two indexes
let slice #t (x:t_Slice t) (i:usize{i <=. length x}) (j:usize{i <=. j /\ j <=. length x}):
r:t_Array t (j -! i) = Seq.slice x (v i) (v j)

/// Translate indexes for subslices
val lemma_index_slice #t (x:t_Slice t) (i:usize{i <=. length x}) (j:usize{i <=. j /\ j <=. length x})
(k:usize{k <. j -! i}):
Lemma (Seq.index (slice x i j) (v k) == Seq.index x (v (i +! k)))
[SMTPat (Seq.index (slice x i j) (v k))]

/// Introduce bitwise equality principle for sequences
val eq_intro #t (a : Seq.seq t) (b:Seq.seq t{Seq.length a == Seq.length b}):
Lemma
(requires forall i. {:pattern Seq.index a i; Seq.index b i}
i < Seq.length a ==>
Seq.index a i == Seq.index b i)
(ensures Seq.equal a b)
[SMTPat (Seq.equal a b)]

/// Split a slice in two at index `m`
let split #t (a:t_Slice t) (m:usize{m <=. length a}):
Pure (t_Array t m & t_Array t (length a -! m))
True (ensures (fun (x,y) ->
x == slice a (sz 0) m /\
y == slice a m (length a) /\
concat #t x y == a)) =
let x = Seq.slice a 0 (v m) in
let y = Seq.slice a (v m) (Seq.length a) in
assert (Seq.equal a (concat x y));
(x,y)

let lemma_slice_append #t (x:t_Slice t) (y:t_Slice t) (z:t_Slice t):
Lemma (requires (range (v (length y) + v (length z)) usize_inttype /\
length y +! length z == length x /\
y == slice x (sz 0) (length y) /\
z == slice x (length y) (length x)))
(ensures (x == concat y z)) =
assert (Seq.equal x (concat y z))

let lemma_slice_append_3 #t (x:t_Slice t) (y:t_Slice t) (z:t_Slice t) (w:t_Slice t):
Lemma (requires (range (v (length y) + v (length z) + v (length w)) usize_inttype /\
length y +! length z +! length w == length x /\
y == slice x (sz 0) (length y) /\
z == slice x (length y) (length y +! length z) /\
w == slice x (length y +! length z) (length x)))
(ensures (x == concat y (concat z w))) =
assert (Seq.equal x (Seq.append y (Seq.append z w)))

let lemma_slice_append_4 #t (x y z w u:t_Slice t) :
Lemma (requires (range (v (length y) + v (length z) + v (length w) + v (length u)) usize_inttype /\
length y +! length z +! length w +! length u == length x /\
y == slice x (sz 0) (length y) /\
z == slice x (length y) (length y +! length z) /\
w == slice x (length y +! length z) (length y +! length z +! length w) /\
u == slice x (length y +! length z +! length w) (length x)))
(ensures (x == concat y (concat z (concat w u)))) =
assert (Seq.equal x (Seq.append y (Seq.append z (Seq.append w u))))


52 changes: 52 additions & 0 deletions proof-libs/fstar/rust_primitives/Rust_primitives.BitVectors.fst
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
module Rust_primitives.BitVectors

open FStar.Mul
open Rust_primitives.Arrays
open Rust_primitives.Integers

#set-options "--fuel 0 --ifuel 1 --z3rlimit 40"

let lemma_get_bit_bounded #t x d i = admit() // see issue #423

let lemma_get_bit_bounded' #t x d = admit() // see issue #423

let pow2_minus_one_mod_lemma1 (n: nat) (m: nat {m < n})
: Lemma (((pow2 n - 1) / pow2 m) % 2 == 1)
= let d: pos = n - m in
Math.Lemmas.pow2_plus m d;
Math.Lemmas.lemma_div_plus (-1) (pow2 d) (pow2 m);
if d > 0 then Math.Lemmas.pow2_double_mult (d-1)

let pow2_minus_one_mod_lemma2 (n: nat) (m: nat {n <= m})
: Lemma (((pow2 n - 1) / pow2 m) % 2 == 0)
= Math.Lemmas.pow2_le_compat m n;
Math.Lemmas.small_div (pow2 n - 1) (pow2 m)

let get_bit_pow2_minus_one #t n nth
= reveal_opaque (`%get_bit) (get_bit (mk_int #t (pow2 n - 1)) nth);
if v nth < n then pow2_minus_one_mod_lemma1 n (v nth)
else pow2_minus_one_mod_lemma2 n (v nth)

let get_bit_pow2_minus_one_i32 x nth
= let n = Some?.v (mask_inv_opt x) in
assume (pow2 n - 1 == x); // see issue #423
mk_int_equiv_lemma #i32_inttype x;
get_bit_pow2_minus_one #i32_inttype n nth

let get_bit_pow2_minus_one_u32 x nth
= let n = Some?.v (mask_inv_opt x) in
assume (pow2 n - 1 == x); // see issue #423
mk_int_equiv_lemma #u32_inttype x;
get_bit_pow2_minus_one #u32_inttype n nth

let get_bit_pow2_minus_one_u16 x nth
= let n = Some?.v (mask_inv_opt x) in
assume (pow2 n - 1 == x); // see issue #423
mk_int_equiv_lemma #u16_inttype x;
get_bit_pow2_minus_one #u16_inttype n nth

let get_bit_pow2_minus_one_u8 t x nth
= let n = Some?.v (mask_inv_opt x) in
assume (pow2 n - 1 == x); // see issue #423
mk_int_equiv_lemma #u8_inttype x;
get_bit_pow2_minus_one #u8_inttype n nth
114 changes: 114 additions & 0 deletions proof-libs/fstar/rust_primitives/Rust_primitives.BitVectors.fsti
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
module Rust_primitives.BitVectors

open FStar.Mul
open Rust_primitives.Arrays
open Rust_primitives.Integers

/// Number of bits carried by an integer of type `t`
type bit_num t = d: nat {d > 0 /\ d <= bits t /\ (signed t ==> d <= bits t)}

/// States that `x` is a positive integer that fits in `d` bits
type bounded #t (x:int_t t) (d:bit_num t) =
v x >= 0 /\ v x < pow2 d

/// Integer of type `t` that carries at most `d` bits
type int_t_d t (d: bit_num t) =
n: int_t t {bounded n d}

/// If `x` fits in `d` bits, then upper bits are zero
val lemma_get_bit_bounded #t (x:int_t t) (d:bit_num t) (i:usize):
Lemma ((bounded x d /\ v i >= d /\ v i < bits t) ==>
get_bit x i == 0)
[SMTPat (get_bit #t x i); SMTPat (bounded x d)]

/// If upper bits of `x` are zero, then `x` is bounded accordingly
val lemma_get_bit_bounded' #t (x:int_t t) (d:bit_num t):
Lemma (requires forall i. v i > d ==> get_bit x i == 0)
(ensures bounded x d)

/// A bit vector is a partial map from indexes to bits
type bit_vec (len: nat) = i:nat {i < len} -> bit

/// Transform an array of integers to a bit vector
#push-options "--fuel 0 --ifuel 1 --z3rlimit 50"
let bit_vec_of_int_arr (#n: inttype) (#len: usize)
(arr: t_Array (int_t n) len)
(d: bit_num n): bit_vec (v len * d)
= fun i -> get_bit (Seq.index arr (i / d)) (sz (i % d))
#pop-options

/// Transform an array of `nat`s to a bit vector
#push-options "--fuel 0 --ifuel 1 --z3rlimit 50"
let bit_vec_of_nat_arr (#len: usize)
(arr: t_Array nat len)
(d: nat)
: bit_vec (v len * d)
= fun i -> get_bit_nat (Seq.index arr (i / d)) (i % d)
#pop-options

/// Bit-wise semantics of `2^n-1`
val get_bit_pow2_minus_one #t
(n: nat {pow2 n - 1 <= maxint t})
(nth: usize {v nth < bits t})
: Lemma ( get_bit (mk_int #t (pow2 n - 1)) nth
== (if v nth < n then 1 else 0))

/// Log2 table
unfold let mask_inv_opt =
function | 0 -> Some 0
| 1 -> Some 1
| 3 -> Some 2
| 7 -> Some 3
| 15 -> Some 4
| 31 -> Some 5
| 63 -> Some 6
| 127 -> Some 7
| 255 -> Some 8
| 511 -> Some 9
| 1023 -> Some 10
| 2047 -> Some 11
| 4095 -> Some 12
| _ -> None

/// Specialized `get_bit_pow2_minus_one` lemmas with SMT patterns
/// targetting machine integer literals of type `i32`
val get_bit_pow2_minus_one_i32
(x: int {x < pow2 31 /\ Some? (mask_inv_opt x)}) (nth: usize {v nth < 32})
: Lemma ( get_bit (FStar.Int32.int_to_t x) nth
== (if v nth < Some?.v (mask_inv_opt x) then 1 else 0))
[SMTPat (get_bit (FStar.Int32.int_to_t x) nth)]

/// Specialized `get_bit_pow2_minus_one` lemmas with SMT patterns
/// targetting machine integer literals of type `u32`
val get_bit_pow2_minus_one_u32
(x: int {x < pow2 32 /\ Some? (mask_inv_opt x)}) (nth: usize {v nth < 32})
: Lemma ( get_bit (FStar.UInt32.uint_to_t x) nth
== (if v nth < Some?.v (mask_inv_opt x) then 1 else 0))
[SMTPat (get_bit (FStar.UInt16.uint_to_t x) nth)]

/// Specialized `get_bit_pow2_minus_one` lemmas with SMT patterns
/// targetting machine integer literals of type `u16`
val get_bit_pow2_minus_one_u16
(x: int {x < pow2 16 /\ Some? (mask_inv_opt x)}) (nth: usize {v nth < 16})
: Lemma ( get_bit (FStar.UInt16.uint_to_t x) nth
== (if v nth < Some?.v (mask_inv_opt x) then 1 else 0))
[SMTPat (get_bit (FStar.UInt16.uint_to_t x) nth)]

/// Specialized `get_bit_pow2_minus_one` lemmas with SMT patterns
/// targetting machine integer literals of type `u8`
val get_bit_pow2_minus_one_u8
// We use `Lib.IntTypes` (Hacl*'s library): every operation on
// integers is polymorphic in integer types. There is a one to one
// correspondence between F*'s machine integers (UInt16, UInt64...)
// and `Lib.IntTypes.inttype`'s variants (U16, U32...), but for `U8`
// and `U1`. Bits (`U1`) and `u8`s are using the same
// representation: `U8`. Thus, sometimes F* picks the wrong
// `inttype`. This is the reason for the refined type `t` below,
// which appears only on this version of the specialized
// `get_bit_pow2_minus_one` lemma.
(t: _ {t == u8_inttype})
(x: int {x < pow2 8 /\ Some? (mask_inv_opt x)}) (nth: usize {v nth < 8})
: Lemma ( get_bit #t (FStar.UInt8.uint_to_t x) nth
== (if v nth < Some?.v (mask_inv_opt x) then 1 else 0))
[SMTPat (get_bit #t (FStar.UInt8.uint_to_t x) nth)]

Loading

0 comments on commit baef406

Please sign in to comment.