diff --git a/code/aes/Hacl.Impl.AES.Core.fsti b/code/aes/Hacl.Impl.AES.Core.fsti index e62804da03..ba928a916a 100644 --- a/code/aes/Hacl.Impl.AES.Core.fsti +++ b/code/aes/Hacl.Impl.AES.Core.fsti @@ -77,7 +77,7 @@ val load_block0: -> st: state m -> b: lbuffer uint8 16ul -> Stack unit - (requires (fun h -> live h st /\ live h b)) + (requires (fun h -> live h st /\ live h b /\ disjoint st b)) (ensures (fun h0 _ h1 -> modifies1 st h0 h1)) inline_for_extraction noextract diff --git a/code/aes/Hacl.Impl.AES.CoreNI.fst b/code/aes/Hacl.Impl.AES.CoreNI.fst index 6e4e09d767..9ba6c264be 100644 --- a/code/aes/Hacl.Impl.AES.CoreNI.fst +++ b/code/aes/Hacl.Impl.AES.CoreNI.fst @@ -1,15 +1,18 @@ module Hacl.Impl.AES.CoreNI +open FStar.Mul open FStar.HyperStack open FStar.HyperStack.All open Lib.IntTypes open Lib.Buffer open Lib.IntVector +open Lib.ByteSequence -module BSeq = Lib.ByteSequence module ST = FStar.HyperStack.ST module LSeq = Lib.Sequence +#set-options "--z3rlimit 30 --max_fuel 0 --max_ifuel 0" + inline_for_extraction noextract let vec128 = vec_t U128 1 inline_for_extraction noextract @@ -55,7 +58,7 @@ val load_block0: (requires (fun h -> live h st /\ live h b /\ disjoint st b)) (ensures (fun h0 _ h1 -> modifies1 st h0 h1 /\ (let st' = as_seq h1 st in - vec_v (LSeq.index st' 0) == BSeq.uints_from_bytes_le (as_seq h0 b) /\ + vec_v (LSeq.index st' 0) == uints_from_bytes_le (as_seq h0 b) /\ LSeq.index st' 1 == LSeq.index (as_seq h0 st) 1 /\ LSeq.index st' 2 == LSeq.index (as_seq h0 st) 2 /\ LSeq.index st' 3 == LSeq.index (as_seq h0 st) 3))) @@ -70,7 +73,7 @@ val load_key1: Stack unit (requires (fun h -> live h k /\ live h b)) (ensures (fun h0 _ h1 -> modifies1 k h0 h1 /\ - vec_v (LSeq.index (as_seq h1 k) 0) == BSeq.uints_from_bytes_le (as_seq h0 b))) + vec_v (LSeq.index (as_seq h1 k) 0) == uints_from_bytes_le (as_seq h0 b))) let load_key1 k b = k.(size 0) <- vec_load_le U128 1 b @@ -83,7 +86,7 @@ val load_nonce: (requires (fun h -> live h n /\ live h b)) (ensures (fun h0 _ h1 -> modifies1 n h0 h1 /\ vec_v (LSeq.index (as_seq h1 n) 0) == - BSeq.uints_from_bytes_le (LSeq.concat (as_seq h0 b) (LSeq.create 4 (u8 0))))) + uints_from_bytes_le (LSeq.concat (as_seq h0 b) (LSeq.create 4 (u8 0))))) let load_nonce n b = let h0 = ST.get() in @@ -99,6 +102,154 @@ let load_nonce n b = n.(size 0) <- vec_load_le U128 1 nb; pop_frame() +val lemma_distr_pow (a b:int) (c d:nat) : Lemma ((a + b * pow2 c) * pow2 d = a * pow2 d + b * pow2 (c + d)) +let lemma_distr_pow a b c d = + calc (==) { + (a + b * pow2 c) * pow2 d; + (==) { Math.Lemmas.distributivity_add_left a (b * pow2 c) (pow2 d) } + a * pow2 d + b * pow2 c * pow2 d; + (==) { Math.Lemmas.paren_mul_right b (pow2 c) (pow2 d); Math.Lemmas.pow2_plus c d } + a * pow2 d + b * pow2 (c + d); + } + +val lemma_distr_pow_pow (a:int) (b:nat) (c:int) (d e:nat) : + Lemma ((a * pow2 b + c * pow2 d) * pow2 e = a * pow2 (b + e) + c * pow2 (d + e)) +let lemma_distr_pow_pow a b c d e = + calc (==) { + (a * pow2 b + c * pow2 d) * pow2 e; + (==) { lemma_distr_pow (a * pow2 b) c d e } + a * pow2 b * pow2 e + c * pow2 (d + e); + (==) { Math.Lemmas.paren_mul_right a (pow2 b) (pow2 e); Math.Lemmas.pow2_plus b e } + a * pow2 (b + e) + c * pow2 (d + e); + } + +val lemma_as_nat32_horner (r0 r1 r2 r3:int) : + Lemma (r0 + r1 * pow2 32 + r2 * pow2 64 + r3 * pow2 96 == + ((r3 * pow2 32 + r2) * pow2 32 + r1) * pow2 32 + r0) + +let lemma_as_nat32_horner r0 r1 r2 r3 = + calc (==) { + r0 + pow2 32 * (r1 + pow2 32 * (r2 + pow2 32 * r3)); + (==) { Math.Lemmas.swap_mul (pow2 32) (r1 + pow2 32 * (r2 + pow2 32 * r3)) } + r0 + (r1 + pow2 32 * (r2 + pow2 32 * r3)) * pow2 32; + (==) { Math.Lemmas.swap_mul (pow2 32) (r2 + pow2 32 * r3) } + r0 + (r1 + (r2 + pow2 32 * r3) * pow2 32) * pow2 32; + (==) { lemma_distr_pow r1 (r2 + pow2 32 * r3) 32 32 } + r0 + r1 * pow2 32 + (r2 + pow2 32 * r3) * pow2 64; + (==) { Math.Lemmas.swap_mul (pow2 32) r3 } + r0 + r1 * pow2 32 + (r2 + r3 * pow2 32) * pow2 64; + (==) { lemma_distr_pow r2 r3 32 64 } + r0 + r1 * pow2 32 + r2 * pow2 64 + r3 * pow2 96; + } + +val unfold_nat_from_uint32_four: b:LSeq.lseq uint32 4 -> + Lemma (nat_from_intseq_be b == + v (LSeq.index b 3) + v (LSeq.index b 2) * pow2 32 + + v (LSeq.index b 1) * pow2 64 + v (LSeq.index b 0) * pow2 96) + +let unfold_nat_from_uint32_four b = + let b0 = v (LSeq.index b 0) in + let b1 = v (LSeq.index b 1) in + let b2 = v (LSeq.index b 2) in + let b3 = v (LSeq.index b 3) in + + let res = nat_from_intseq_be b in + nat_from_intseq_be_slice_lemma b 3; + nat_from_intseq_be_lemma0 (Seq.slice b 3 4); + assert (res == b3 + pow2 32 * (nat_from_intseq_be (Seq.slice b 0 3))); + + nat_from_intseq_be_slice_lemma #U32 #SEC #3 (Seq.slice b 0 3) 2; + nat_from_intseq_be_lemma0 (Seq.slice b 2 3); + assert (nat_from_intseq_be (Seq.slice b 0 3) == b2 + pow2 32 * (nat_from_intseq_be (Seq.slice b 0 2))); + + nat_from_intseq_be_slice_lemma #U32 #SEC #2 (Seq.slice b 0 2) 1; + nat_from_intseq_be_lemma0 (Seq.slice b 1 2); + assert (nat_from_intseq_be (Seq.slice b 0 2) == b1 + pow2 32 * (nat_from_intseq_be (Seq.slice b 0 1))); + + nat_from_intseq_be_lemma0 (Seq.slice b 0 1); + assert (res == b3 + pow2 32 * (b2 + pow2 32 * (b1 + pow2 32 * b0))); + lemma_as_nat32_horner b3 b2 b1 b0 + +val vec_set_counter_lemma: + b:LSeq.lseq uint8 16 + -> c:size_nat -> + Lemma ( + let counter = uint #U32 #SEC c in + let n = uints_from_bytes_be #U32 #SEC #4 b in + LSeq.update_sub b 12 4 (nat_to_bytes_be 4 c) == + uints_to_bytes_be (uints_from_bytes_be #U128 #SEC #1 (uints_to_bytes_be (LSeq.upd n 3 counter)))) + +let vec_set_counter_lemma b c = + let counter = uint #U32 #SEC c in + let n = uints_from_bytes_be #U32 #SEC #4 b in + let n' = LSeq.upd n 3 counter in + unfold_nat_from_uint32_four n'; + assert (nat_from_intseq_be n' == + v (LSeq.index n' 3) + v (LSeq.index n' 2) * pow2 32 + + v (LSeq.index n' 1) * pow2 64 + v (LSeq.index n' 0) * pow2 96); + assert (v (LSeq.index n' 3) == c); + assert (LSeq.index n' 2 == LSeq.index n 2); + assert (LSeq.index n' 1 == LSeq.index n 1); + assert (LSeq.index n' 0 == LSeq.index n 0); + index_uints_from_bytes_be #U32 #SEC #4 b 2; + index_uints_from_bytes_be #U32 #SEC #4 b 1; + index_uints_from_bytes_be #U32 #SEC #4 b 0; + lemma_reveal_uint_to_bytes_be #U32 #SEC (LSeq.sub b 8 4); + lemma_reveal_uint_to_bytes_be #U32 #SEC (LSeq.sub b 4 4); + lemma_reveal_uint_to_bytes_be #U32 #SEC (LSeq.sub b 0 4); + assert (nat_from_intseq_be n' == + c + nat_from_bytes_be (LSeq.sub b 8 4) * pow2 32 + + nat_from_bytes_be (LSeq.sub b 4 4) * pow2 64 + nat_from_bytes_be (LSeq.sub b 0 4) * pow2 96); + + let m = LSeq.update_sub b 12 4 (nat_to_bytes_be 4 c) in + let m' = uints_from_bytes_be #U32 #SEC #4 m in + unfold_nat_from_uint32_four m'; + assert (nat_from_intseq_be m' == + v (LSeq.index m' 3) + v (LSeq.index m' 2) * pow2 32 + + v (LSeq.index m' 1) * pow2 64 + v (LSeq.index m' 0) * pow2 96); + index_uints_from_bytes_be #U32 #SEC #4 m 3; + assert (LSeq.index m' 3 == uint_from_bytes_be (LSeq.sub m 12 4)); + lemma_reveal_uint_to_bytes_be #U32 #SEC (LSeq.sub m 12 4); + assert (nat_from_bytes_be (LSeq.sub m 12 4) == uint_v (uint_from_bytes_be #U32 #SEC (LSeq.sub m 12 4))); + assert (LSeq.sub m 12 4 == nat_to_bytes_be 4 c); + lemma_nat_to_from_bytes_be_preserves_value (nat_to_bytes_be #SEC 16 c) 16 c; + assert (v (LSeq.index m' 3) == c); + index_uints_from_bytes_be #U32 #SEC #4 m 2; + index_uints_from_bytes_be #U32 #SEC #4 m 1; + index_uints_from_bytes_be #U32 #SEC #4 m 0; + lemma_reveal_uint_to_bytes_be #U32 #SEC (LSeq.sub m 8 4); + lemma_reveal_uint_to_bytes_be #U32 #SEC (LSeq.sub m 4 4); + lemma_reveal_uint_to_bytes_be #U32 #SEC (LSeq.sub m 0 4); + LSeq.eq_intro (LSeq.sub b 8 4) (LSeq.sub m 8 4); + LSeq.eq_intro (LSeq.sub b 4 4) (LSeq.sub m 4 4); + LSeq.eq_intro (LSeq.sub b 0 4) (LSeq.sub m 0 4); + assert (v (LSeq.index m' 2) == nat_from_bytes_be (LSeq.sub m 8 4)); + assert (v (LSeq.index m' 1) == nat_from_bytes_be (LSeq.sub m 4 4)); + assert (v (LSeq.index m' 0) == nat_from_bytes_be (LSeq.sub m 0 4)); + uints_from_bytes_be_nat_lemma #U32 #SEC #4 m; + assert (nat_from_bytes_be m == + c + nat_from_bytes_be (LSeq.sub b 8 4) * pow2 32 + + nat_from_bytes_be (LSeq.sub b 4 4) * pow2 64 + nat_from_bytes_be (LSeq.sub b 0 4) * pow2 96); + + let u = uints_to_bytes_be (uints_from_bytes_be #U128 #SEC #1 (uints_to_bytes_be n')) in + lemma_nat_from_to_intseq_be_preserves_value #U128 #SEC 1 (uints_from_bytes_be #U128 #SEC #1 (uints_to_bytes_be n')); + assert (u == uints_to_bytes_be (nat_to_intseq_be 1 (nat_from_intseq_be (uints_from_bytes_be #U128 #SEC #1 (uints_to_bytes_be n'))))); + uints_from_bytes_be_nat_lemma #U128 #SEC #1 (uints_to_bytes_be n'); + assert (u == uints_to_bytes_be (nat_to_intseq_be 1 (nat_from_bytes_be (uints_to_bytes_be n')))); + uints_to_bytes_be_nat_lemma #U128 #SEC 1 (nat_from_bytes_be (uints_to_bytes_be n')); + assert (u == nat_to_bytes_be 16 (nat_from_bytes_be (uints_to_bytes_be n'))); + lemma_nat_from_to_bytes_be_preserves_value (uints_to_bytes_be n') 16; + assert (u == uints_to_bytes_be n'); + + assert (nat_from_intseq_be n' == nat_from_bytes_be m); + lemma_nat_from_to_intseq_be_preserves_value #U32 #SEC 4 n'; + assert (nat_to_intseq_be #U32 #SEC 4 (nat_from_intseq_be #U32 #SEC n') == n'); + assert (u == uints_to_bytes_be #U32 #SEC #4 (nat_to_intseq_be #U32 #SEC 4 (nat_from_intseq_be #U32 #SEC n'))); + assert (u == uints_to_bytes_be #U32 #SEC #4 (nat_to_intseq_be #U32 #SEC 4 (nat_from_bytes_be m))); + uints_to_bytes_be_nat_lemma #U32 #SEC 4 (nat_from_bytes_be m); + assert (u == nat_to_bytes_be 16 (nat_from_bytes_be m)); + lemma_nat_from_to_bytes_be_preserves_value m 16; + assert (u == m) inline_for_extraction noextract val load_state: @@ -107,7 +258,32 @@ val load_state: -> counter: uint32 -> Stack unit (requires (fun h -> live h st /\ live h nonce)) - (ensures (fun h0 _ h1 -> modifies1 st h0 h1)) + (ensures (fun h0 _ h1 -> modifies1 st h0 h1 /\ + (let counter0 = Lib.ByteBuffer.uint_to_be counter in + let counter1 = Lib.ByteBuffer.uint_to_be (counter +. u32 1) in + let counter2 = Lib.ByteBuffer.uint_to_be (counter +. u32 2) in + let counter3 = Lib.ByteBuffer.uint_to_be (counter +. u32 3) in + let nonce0 = LSeq.index (as_seq h0 nonce) 0 in + (* nonce <- counter0 *) + let n0_128 = + LSeq.update_sub (uints_to_bytes_be (vec_v nonce0)) 12 4 + (nat_to_bytes_be 4 (v counter0)) in + (* nonce <- counter1 *) + let n1_128 = + LSeq.update_sub (uints_to_bytes_be (vec_v nonce0)) 12 4 + (nat_to_bytes_be 4 (v counter1)) in + (* nonce <- counter2 *) + let n2_128 = + LSeq.update_sub (uints_to_bytes_be (vec_v nonce0)) 12 4 + (nat_to_bytes_be 4 (v counter2)) in + (* nonce <- counter3 *) + let n3_128 = + LSeq.update_sub (uints_to_bytes_be (vec_v nonce0)) 12 4 + (nat_to_bytes_be 4 (v counter3)) in + uints_to_bytes_be (vec_v (LSeq.index (as_seq h1 st) 0)) == n0_128 /\ + uints_to_bytes_be (vec_v (LSeq.index (as_seq h1 st) 1)) == n1_128 /\ + uints_to_bytes_be (vec_v (LSeq.index (as_seq h1 st) 2)) == n2_128 /\ + uints_to_bytes_be (vec_v (LSeq.index (as_seq h1 st) 3)) == n3_128))) let load_state st nonce counter = let counter0 = Lib.ByteBuffer.uint_to_be counter in @@ -118,7 +294,11 @@ let load_state st nonce counter = st.(size 0) <- cast U128 1 (vec_set #U32 #4 (cast #U128 #1 U32 4 nonce0) 3ul counter0); st.(size 1) <- cast U128 1 (vec_set #U32 #4 (cast #U128 #1 U32 4 nonce0) 3ul counter1); st.(size 2) <- cast U128 1 (vec_set #U32 #4 (cast #U128 #1 U32 4 nonce0) 3ul counter2); - st.(size 3) <- cast U128 1 (vec_set #U32 #4 (cast #U128 #1 U32 4 nonce0) 3ul counter3) + st.(size 3) <- cast U128 1 (vec_set #U32 #4 (cast #U128 #1 U32 4 nonce0) 3ul counter3); + vec_set_counter_lemma (uints_to_bytes_be (vec_v nonce0)) (v counter0); + vec_set_counter_lemma (uints_to_bytes_be (vec_v nonce0)) (v counter1); + vec_set_counter_lemma (uints_to_bytes_be (vec_v nonce0)) (v counter2); + vec_set_counter_lemma (uints_to_bytes_be (vec_v nonce0)) (v counter3) inline_for_extraction noextract val store_block0: @@ -126,7 +306,8 @@ val store_block0: -> st: state -> Stack unit (requires (fun h -> live h st /\ live h out)) - (ensures (fun h0 _ h1 -> modifies1 out h0 h1)) + (ensures (fun h0 _ h1 -> modifies1 out h0 h1 /\ + as_seq h1 out == uints_to_bytes_le (vec_v (LSeq.index (as_seq h0 st) 0)))) let store_block0 out st = vec_store_le out st.(size 0) diff --git a/lib/Lib.IntVector.fst b/lib/Lib.IntVector.fst index d01f89ad16..053b20b650 100644 --- a/lib/Lib.IntVector.fst +++ b/lib/Lib.IntVector.fst @@ -529,3 +529,5 @@ let vec_permute4_lemma #t v i1 i2 i3 i4 = () let vec_permute8 #t v i1 i2 i3 i4 i5 i6 i7 i8 = admit() let vec_permute16 #t = admit() let vec_permute32 #t = admit() + +let cast_lemma #t #w t' w' v = admit() diff --git a/lib/Lib.IntVector.fsti b/lib/Lib.IntVector.fsti index 665118bb8d..4170117559 100644 --- a/lib/Lib.IntVector.fsti +++ b/lib/Lib.IntVector.fsti @@ -476,3 +476,7 @@ val vec_permute32: #t:v_inttype -> v1:vec_t t 32 vv1.[v i21] vv1.[v i22] vv1.[v i23] vv1.[v i24] vv1.[v i25] vv1.[v i26] vv1.[v i27] vv1.[v i28] vv1.[v i29] vv1.[v i30] vv1.[v i31] vv1.[v i32]} + +val cast_lemma: #t:v_inttype -> #w:width -> t':v_inttype -> w':width{bits t * w == bits t' * w'} -> v:vec_t t w -> + Lemma (cast t' w' v == vec_from_bytes_be t' w' (vec_to_bytes_be v)) + [SMTPat (cast t' w' v)] diff --git a/specs/Spec.AES.fst b/specs/Spec.AES.fst index 0dca291821..e7857f81d8 100644 --- a/specs/Spec.AES.fst +++ b/specs/Spec.AES.fst @@ -329,8 +329,8 @@ let aes_decrypt_block (v:variant) (key:aes_xkey v) (input:block) : Tot block = let aes_ctr_key_block (v:variant) (k:aes_xkey v) (n:lbytes 12) (c:size_nat) : Tot block = let ctrby = nat_to_bytes_be 4 c in let input = create 16 (u8 0) in - let input = repeati #(lbytes 16) 12 (fun i b -> b.[i] <- n.[i]) input in - let input = repeati #(lbytes 16) 4 (fun i b -> b.[12+i] <- (Seq.index ctrby i)) input in + let input = update_sub input 0 12 n in + let input = update_sub input 12 4 ctrby in aes_encrypt_block v k input noeq type aes_ctr_state (v:variant) = {