Skip to content

Commit

Permalink
Update Hacl.Impl.AES.CoreNI.fst
Browse files Browse the repository at this point in the history
  • Loading branch information
mamonet committed Feb 3, 2024
1 parent 489b3f5 commit b514576
Show file tree
Hide file tree
Showing 5 changed files with 197 additions and 10 deletions.
2 changes: 1 addition & 1 deletion code/aes/Hacl.Impl.AES.Core.fsti
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ val load_block0:
-> st: state m
-> b: lbuffer uint8 16ul ->
Stack unit
(requires (fun h -> live h st /\ live h b))
(requires (fun h -> live h st /\ live h b /\ disjoint st b))
(ensures (fun h0 _ h1 -> modifies1 st h0 h1))

inline_for_extraction noextract
Expand Down
195 changes: 188 additions & 7 deletions code/aes/Hacl.Impl.AES.CoreNI.fst
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
module Hacl.Impl.AES.CoreNI

open FStar.Mul
open FStar.HyperStack
open FStar.HyperStack.All
open Lib.IntTypes
open Lib.Buffer
open Lib.IntVector
open Lib.ByteSequence

module BSeq = Lib.ByteSequence
module ST = FStar.HyperStack.ST
module LSeq = Lib.Sequence

#set-options "--z3rlimit 30 --max_fuel 0 --max_ifuel 0"

inline_for_extraction noextract
let vec128 = vec_t U128 1
inline_for_extraction noextract
Expand Down Expand Up @@ -55,7 +58,7 @@ val load_block0:
(requires (fun h -> live h st /\ live h b /\ disjoint st b))
(ensures (fun h0 _ h1 -> modifies1 st h0 h1 /\
(let st' = as_seq h1 st in
vec_v (LSeq.index st' 0) == BSeq.uints_from_bytes_le (as_seq h0 b) /\
vec_v (LSeq.index st' 0) == uints_from_bytes_le (as_seq h0 b) /\
LSeq.index st' 1 == LSeq.index (as_seq h0 st) 1 /\
LSeq.index st' 2 == LSeq.index (as_seq h0 st) 2 /\
LSeq.index st' 3 == LSeq.index (as_seq h0 st) 3)))
Expand All @@ -70,7 +73,7 @@ val load_key1:
Stack unit
(requires (fun h -> live h k /\ live h b))
(ensures (fun h0 _ h1 -> modifies1 k h0 h1 /\
vec_v (LSeq.index (as_seq h1 k) 0) == BSeq.uints_from_bytes_le (as_seq h0 b)))
vec_v (LSeq.index (as_seq h1 k) 0) == uints_from_bytes_le (as_seq h0 b)))

let load_key1 k b = k.(size 0) <- vec_load_le U128 1 b

Expand All @@ -83,7 +86,7 @@ val load_nonce:
(requires (fun h -> live h n /\ live h b))
(ensures (fun h0 _ h1 -> modifies1 n h0 h1 /\
vec_v (LSeq.index (as_seq h1 n) 0) ==
BSeq.uints_from_bytes_le (LSeq.concat (as_seq h0 b) (LSeq.create 4 (u8 0)))))
uints_from_bytes_le (LSeq.concat (as_seq h0 b) (LSeq.create 4 (u8 0)))))

let load_nonce n b =
let h0 = ST.get() in
Expand All @@ -99,6 +102,154 @@ let load_nonce n b =
n.(size 0) <- vec_load_le U128 1 nb;
pop_frame()

val lemma_distr_pow (a b:int) (c d:nat) : Lemma ((a + b * pow2 c) * pow2 d = a * pow2 d + b * pow2 (c + d))
let lemma_distr_pow a b c d =
calc (==) {
(a + b * pow2 c) * pow2 d;
(==) { Math.Lemmas.distributivity_add_left a (b * pow2 c) (pow2 d) }
a * pow2 d + b * pow2 c * pow2 d;
(==) { Math.Lemmas.paren_mul_right b (pow2 c) (pow2 d); Math.Lemmas.pow2_plus c d }
a * pow2 d + b * pow2 (c + d);
}

val lemma_distr_pow_pow (a:int) (b:nat) (c:int) (d e:nat) :
Lemma ((a * pow2 b + c * pow2 d) * pow2 e = a * pow2 (b + e) + c * pow2 (d + e))
let lemma_distr_pow_pow a b c d e =
calc (==) {
(a * pow2 b + c * pow2 d) * pow2 e;
(==) { lemma_distr_pow (a * pow2 b) c d e }
a * pow2 b * pow2 e + c * pow2 (d + e);
(==) { Math.Lemmas.paren_mul_right a (pow2 b) (pow2 e); Math.Lemmas.pow2_plus b e }
a * pow2 (b + e) + c * pow2 (d + e);
}

val lemma_as_nat32_horner (r0 r1 r2 r3:int) :
Lemma (r0 + r1 * pow2 32 + r2 * pow2 64 + r3 * pow2 96 ==
((r3 * pow2 32 + r2) * pow2 32 + r1) * pow2 32 + r0)

let lemma_as_nat32_horner r0 r1 r2 r3 =
calc (==) {
r0 + pow2 32 * (r1 + pow2 32 * (r2 + pow2 32 * r3));
(==) { Math.Lemmas.swap_mul (pow2 32) (r1 + pow2 32 * (r2 + pow2 32 * r3)) }
r0 + (r1 + pow2 32 * (r2 + pow2 32 * r3)) * pow2 32;
(==) { Math.Lemmas.swap_mul (pow2 32) (r2 + pow2 32 * r3) }
r0 + (r1 + (r2 + pow2 32 * r3) * pow2 32) * pow2 32;
(==) { lemma_distr_pow r1 (r2 + pow2 32 * r3) 32 32 }
r0 + r1 * pow2 32 + (r2 + pow2 32 * r3) * pow2 64;
(==) { Math.Lemmas.swap_mul (pow2 32) r3 }
r0 + r1 * pow2 32 + (r2 + r3 * pow2 32) * pow2 64;
(==) { lemma_distr_pow r2 r3 32 64 }
r0 + r1 * pow2 32 + r2 * pow2 64 + r3 * pow2 96;
}

val unfold_nat_from_uint32_four: b:LSeq.lseq uint32 4 ->
Lemma (nat_from_intseq_be b ==
v (LSeq.index b 3) + v (LSeq.index b 2) * pow2 32 +
v (LSeq.index b 1) * pow2 64 + v (LSeq.index b 0) * pow2 96)

let unfold_nat_from_uint32_four b =
let b0 = v (LSeq.index b 0) in
let b1 = v (LSeq.index b 1) in
let b2 = v (LSeq.index b 2) in
let b3 = v (LSeq.index b 3) in

let res = nat_from_intseq_be b in
nat_from_intseq_be_slice_lemma b 3;
nat_from_intseq_be_lemma0 (Seq.slice b 3 4);
assert (res == b3 + pow2 32 * (nat_from_intseq_be (Seq.slice b 0 3)));

nat_from_intseq_be_slice_lemma #U32 #SEC #3 (Seq.slice b 0 3) 2;
nat_from_intseq_be_lemma0 (Seq.slice b 2 3);
assert (nat_from_intseq_be (Seq.slice b 0 3) == b2 + pow2 32 * (nat_from_intseq_be (Seq.slice b 0 2)));

nat_from_intseq_be_slice_lemma #U32 #SEC #2 (Seq.slice b 0 2) 1;
nat_from_intseq_be_lemma0 (Seq.slice b 1 2);
assert (nat_from_intseq_be (Seq.slice b 0 2) == b1 + pow2 32 * (nat_from_intseq_be (Seq.slice b 0 1)));

nat_from_intseq_be_lemma0 (Seq.slice b 0 1);
assert (res == b3 + pow2 32 * (b2 + pow2 32 * (b1 + pow2 32 * b0)));
lemma_as_nat32_horner b3 b2 b1 b0

val vec_set_counter_lemma:
b:LSeq.lseq uint8 16
-> c:size_nat ->
Lemma (
let counter = uint #U32 #SEC c in
let n = uints_from_bytes_be #U32 #SEC #4 b in
LSeq.update_sub b 12 4 (nat_to_bytes_be 4 c) ==
uints_to_bytes_be (uints_from_bytes_be #U128 #SEC #1 (uints_to_bytes_be (LSeq.upd n 3 counter))))

let vec_set_counter_lemma b c =
let counter = uint #U32 #SEC c in
let n = uints_from_bytes_be #U32 #SEC #4 b in
let n' = LSeq.upd n 3 counter in
unfold_nat_from_uint32_four n';
assert (nat_from_intseq_be n' ==
v (LSeq.index n' 3) + v (LSeq.index n' 2) * pow2 32 +
v (LSeq.index n' 1) * pow2 64 + v (LSeq.index n' 0) * pow2 96);
assert (v (LSeq.index n' 3) == c);
assert (LSeq.index n' 2 == LSeq.index n 2);
assert (LSeq.index n' 1 == LSeq.index n 1);
assert (LSeq.index n' 0 == LSeq.index n 0);
index_uints_from_bytes_be #U32 #SEC #4 b 2;
index_uints_from_bytes_be #U32 #SEC #4 b 1;
index_uints_from_bytes_be #U32 #SEC #4 b 0;
lemma_reveal_uint_to_bytes_be #U32 #SEC (LSeq.sub b 8 4);
lemma_reveal_uint_to_bytes_be #U32 #SEC (LSeq.sub b 4 4);
lemma_reveal_uint_to_bytes_be #U32 #SEC (LSeq.sub b 0 4);
assert (nat_from_intseq_be n' ==
c + nat_from_bytes_be (LSeq.sub b 8 4) * pow2 32 +
nat_from_bytes_be (LSeq.sub b 4 4) * pow2 64 + nat_from_bytes_be (LSeq.sub b 0 4) * pow2 96);

let m = LSeq.update_sub b 12 4 (nat_to_bytes_be 4 c) in
let m' = uints_from_bytes_be #U32 #SEC #4 m in
unfold_nat_from_uint32_four m';
assert (nat_from_intseq_be m' ==
v (LSeq.index m' 3) + v (LSeq.index m' 2) * pow2 32 +
v (LSeq.index m' 1) * pow2 64 + v (LSeq.index m' 0) * pow2 96);
index_uints_from_bytes_be #U32 #SEC #4 m 3;
assert (LSeq.index m' 3 == uint_from_bytes_be (LSeq.sub m 12 4));
lemma_reveal_uint_to_bytes_be #U32 #SEC (LSeq.sub m 12 4);
assert (nat_from_bytes_be (LSeq.sub m 12 4) == uint_v (uint_from_bytes_be #U32 #SEC (LSeq.sub m 12 4)));
assert (LSeq.sub m 12 4 == nat_to_bytes_be 4 c);
lemma_nat_to_from_bytes_be_preserves_value (nat_to_bytes_be #SEC 16 c) 16 c;
assert (v (LSeq.index m' 3) == c);
index_uints_from_bytes_be #U32 #SEC #4 m 2;
index_uints_from_bytes_be #U32 #SEC #4 m 1;
index_uints_from_bytes_be #U32 #SEC #4 m 0;
lemma_reveal_uint_to_bytes_be #U32 #SEC (LSeq.sub m 8 4);
lemma_reveal_uint_to_bytes_be #U32 #SEC (LSeq.sub m 4 4);
lemma_reveal_uint_to_bytes_be #U32 #SEC (LSeq.sub m 0 4);
LSeq.eq_intro (LSeq.sub b 8 4) (LSeq.sub m 8 4);
LSeq.eq_intro (LSeq.sub b 4 4) (LSeq.sub m 4 4);
LSeq.eq_intro (LSeq.sub b 0 4) (LSeq.sub m 0 4);
assert (v (LSeq.index m' 2) == nat_from_bytes_be (LSeq.sub m 8 4));
assert (v (LSeq.index m' 1) == nat_from_bytes_be (LSeq.sub m 4 4));
assert (v (LSeq.index m' 0) == nat_from_bytes_be (LSeq.sub m 0 4));
uints_from_bytes_be_nat_lemma #U32 #SEC #4 m;
assert (nat_from_bytes_be m ==
c + nat_from_bytes_be (LSeq.sub b 8 4) * pow2 32 +
nat_from_bytes_be (LSeq.sub b 4 4) * pow2 64 + nat_from_bytes_be (LSeq.sub b 0 4) * pow2 96);

let u = uints_to_bytes_be (uints_from_bytes_be #U128 #SEC #1 (uints_to_bytes_be n')) in
lemma_nat_from_to_intseq_be_preserves_value #U128 #SEC 1 (uints_from_bytes_be #U128 #SEC #1 (uints_to_bytes_be n'));
assert (u == uints_to_bytes_be (nat_to_intseq_be 1 (nat_from_intseq_be (uints_from_bytes_be #U128 #SEC #1 (uints_to_bytes_be n')))));
uints_from_bytes_be_nat_lemma #U128 #SEC #1 (uints_to_bytes_be n');
assert (u == uints_to_bytes_be (nat_to_intseq_be 1 (nat_from_bytes_be (uints_to_bytes_be n'))));
uints_to_bytes_be_nat_lemma #U128 #SEC 1 (nat_from_bytes_be (uints_to_bytes_be n'));
assert (u == nat_to_bytes_be 16 (nat_from_bytes_be (uints_to_bytes_be n')));
lemma_nat_from_to_bytes_be_preserves_value (uints_to_bytes_be n') 16;
assert (u == uints_to_bytes_be n');

assert (nat_from_intseq_be n' == nat_from_bytes_be m);
lemma_nat_from_to_intseq_be_preserves_value #U32 #SEC 4 n';
assert (nat_to_intseq_be #U32 #SEC 4 (nat_from_intseq_be #U32 #SEC n') == n');
assert (u == uints_to_bytes_be #U32 #SEC #4 (nat_to_intseq_be #U32 #SEC 4 (nat_from_intseq_be #U32 #SEC n')));
assert (u == uints_to_bytes_be #U32 #SEC #4 (nat_to_intseq_be #U32 #SEC 4 (nat_from_bytes_be m)));
uints_to_bytes_be_nat_lemma #U32 #SEC 4 (nat_from_bytes_be m);
assert (u == nat_to_bytes_be 16 (nat_from_bytes_be m));
lemma_nat_from_to_bytes_be_preserves_value m 16;
assert (u == m)

inline_for_extraction noextract
val load_state:
Expand All @@ -107,7 +258,32 @@ val load_state:
-> counter: uint32 ->
Stack unit
(requires (fun h -> live h st /\ live h nonce))
(ensures (fun h0 _ h1 -> modifies1 st h0 h1))
(ensures (fun h0 _ h1 -> modifies1 st h0 h1 /\
(let counter0 = Lib.ByteBuffer.uint_to_be counter in
let counter1 = Lib.ByteBuffer.uint_to_be (counter +. u32 1) in
let counter2 = Lib.ByteBuffer.uint_to_be (counter +. u32 2) in
let counter3 = Lib.ByteBuffer.uint_to_be (counter +. u32 3) in
let nonce0 = LSeq.index (as_seq h0 nonce) 0 in
(* nonce <- counter0 *)
let n0_128 =
LSeq.update_sub (uints_to_bytes_be (vec_v nonce0)) 12 4
(nat_to_bytes_be 4 (v counter0)) in
(* nonce <- counter1 *)
let n1_128 =
LSeq.update_sub (uints_to_bytes_be (vec_v nonce0)) 12 4
(nat_to_bytes_be 4 (v counter1)) in
(* nonce <- counter2 *)
let n2_128 =
LSeq.update_sub (uints_to_bytes_be (vec_v nonce0)) 12 4
(nat_to_bytes_be 4 (v counter2)) in
(* nonce <- counter3 *)
let n3_128 =
LSeq.update_sub (uints_to_bytes_be (vec_v nonce0)) 12 4
(nat_to_bytes_be 4 (v counter3)) in
uints_to_bytes_be (vec_v (LSeq.index (as_seq h1 st) 0)) == n0_128 /\
uints_to_bytes_be (vec_v (LSeq.index (as_seq h1 st) 1)) == n1_128 /\
uints_to_bytes_be (vec_v (LSeq.index (as_seq h1 st) 2)) == n2_128 /\
uints_to_bytes_be (vec_v (LSeq.index (as_seq h1 st) 3)) == n3_128)))

let load_state st nonce counter =
let counter0 = Lib.ByteBuffer.uint_to_be counter in
Expand All @@ -118,15 +294,20 @@ let load_state st nonce counter =
st.(size 0) <- cast U128 1 (vec_set #U32 #4 (cast #U128 #1 U32 4 nonce0) 3ul counter0);
st.(size 1) <- cast U128 1 (vec_set #U32 #4 (cast #U128 #1 U32 4 nonce0) 3ul counter1);
st.(size 2) <- cast U128 1 (vec_set #U32 #4 (cast #U128 #1 U32 4 nonce0) 3ul counter2);
st.(size 3) <- cast U128 1 (vec_set #U32 #4 (cast #U128 #1 U32 4 nonce0) 3ul counter3)
st.(size 3) <- cast U128 1 (vec_set #U32 #4 (cast #U128 #1 U32 4 nonce0) 3ul counter3);
vec_set_counter_lemma (uints_to_bytes_be (vec_v nonce0)) (v counter0);
vec_set_counter_lemma (uints_to_bytes_be (vec_v nonce0)) (v counter1);
vec_set_counter_lemma (uints_to_bytes_be (vec_v nonce0)) (v counter2);
vec_set_counter_lemma (uints_to_bytes_be (vec_v nonce0)) (v counter3)

inline_for_extraction noextract
val store_block0:
out: lbuffer uint8 16ul
-> st: state ->
Stack unit
(requires (fun h -> live h st /\ live h out))
(ensures (fun h0 _ h1 -> modifies1 out h0 h1))
(ensures (fun h0 _ h1 -> modifies1 out h0 h1 /\
as_seq h1 out == uints_to_bytes_le (vec_v (LSeq.index (as_seq h0 st) 0))))

let store_block0 out st =
vec_store_le out st.(size 0)
Expand Down
2 changes: 2 additions & 0 deletions lib/Lib.IntVector.fst
Original file line number Diff line number Diff line change
Expand Up @@ -529,3 +529,5 @@ let vec_permute4_lemma #t v i1 i2 i3 i4 = ()
let vec_permute8 #t v i1 i2 i3 i4 i5 i6 i7 i8 = admit()
let vec_permute16 #t = admit()
let vec_permute32 #t = admit()

let cast_lemma #t #w t' w' v = admit()
4 changes: 4 additions & 0 deletions lib/Lib.IntVector.fsti
Original file line number Diff line number Diff line change
Expand Up @@ -476,3 +476,7 @@ val vec_permute32: #t:v_inttype -> v1:vec_t t 32
vv1.[v i21] vv1.[v i22] vv1.[v i23] vv1.[v i24]
vv1.[v i25] vv1.[v i26] vv1.[v i27] vv1.[v i28]
vv1.[v i29] vv1.[v i30] vv1.[v i31] vv1.[v i32]}

val cast_lemma: #t:v_inttype -> #w:width -> t':v_inttype -> w':width{bits t * w == bits t' * w'} -> v:vec_t t w ->
Lemma (cast t' w' v == vec_from_bytes_be t' w' (vec_to_bytes_be v))
[SMTPat (cast t' w' v)]
4 changes: 2 additions & 2 deletions specs/Spec.AES.fst
Original file line number Diff line number Diff line change
Expand Up @@ -329,8 +329,8 @@ let aes_decrypt_block (v:variant) (key:aes_xkey v) (input:block) : Tot block =
let aes_ctr_key_block (v:variant) (k:aes_xkey v) (n:lbytes 12) (c:size_nat) : Tot block =
let ctrby = nat_to_bytes_be 4 c in
let input = create 16 (u8 0) in
let input = repeati #(lbytes 16) 12 (fun i b -> b.[i] <- n.[i]) input in
let input = repeati #(lbytes 16) 4 (fun i b -> b.[12+i] <- (Seq.index ctrby i)) input in
let input = update_sub input 0 12 n in
let input = update_sub input 12 4 ctrby in
aes_encrypt_block v k input

noeq type aes_ctr_state (v:variant) = {
Expand Down

0 comments on commit b514576

Please sign in to comment.