Skip to content

Commit

Permalink
Verify key_expansion128
Browse files Browse the repository at this point in the history
  • Loading branch information
mamonet committed Feb 21, 2024
1 parent e9e91a2 commit bcd96f0
Show file tree
Hide file tree
Showing 10 changed files with 478 additions and 101 deletions.
1 change: 1 addition & 0 deletions code/aes/Hacl.AES.Common.fst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module Hacl.AES.Common
open Lib.IntTypes
open Lib.Sequence

unfold noextract
let u8_16_to_le (x:lseq uint8 16): lseq uint8 16 =
create16 x.[15] x.[14] x.[13] x.[12] x.[11] x.[10] x.[9]
x.[8] x.[7] x.[6] x.[5] x.[4] x.[3] x.[2] x.[1] x.[0]
126 changes: 126 additions & 0 deletions code/aes/Hacl.AES.CoreNI.Lemmas.fst
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,132 @@ let vec_shift_l_32_lemma p b =
(LSeq.update_sub p 0 12 (LSeq.sub x 4 12));
LSeq.eq_intro (shift_l_32_s x) (LSeq.update_sub p 0 12 (LSeq.sub x 4 12))

let vec_create4_3_lemma b =
let m0 = LSeq.update_sub b 8 4 (LSeq.sub b 12 4) in
let m = LSeq.update_sub m0 0 8 (LSeq.sub m0 8 8) in
let m' = uints_from_bytes_be #U32 #SEC #4 m in
unfold_nat_from_uint32_four m';
assert (nat_from_intseq_be m' ==
v (LSeq.index m' 3) + pow2 32 * (v (LSeq.index m' 2) +
pow2 32 * (v (LSeq.index m' 1) + pow2 32 * v (LSeq.index m' 0))));
index_uints_from_bytes_be #U32 #SEC #4 m 3;
index_uints_from_bytes_be #U32 #SEC #4 m 2;
index_uints_from_bytes_be #U32 #SEC #4 m 1;
index_uints_from_bytes_be #U32 #SEC #4 m 0;
lemma_reveal_uint_to_bytes_be #U32 #SEC (LSeq.sub m 12 4);
lemma_reveal_uint_to_bytes_be #U32 #SEC (LSeq.sub m 8 4);
lemma_reveal_uint_to_bytes_be #U32 #SEC (LSeq.sub m 4 4);
lemma_reveal_uint_to_bytes_be #U32 #SEC (LSeq.sub m 0 4);
assert (v (LSeq.index m' 3) == nat_from_bytes_be (LSeq.sub m 12 4));
assert (v (LSeq.index m' 2) == nat_from_bytes_be (LSeq.sub m 8 4));
assert (v (LSeq.index m' 1) == nat_from_bytes_be (LSeq.sub m 4 4));
assert (v (LSeq.index m' 0) == nat_from_bytes_be (LSeq.sub m 0 4));
LSeq.eq_intro (LSeq.sub m 12 4) (LSeq.sub b 12 4);
LSeq.eq_intro (LSeq.sub m 8 4) (LSeq.sub b 12 4);
LSeq.lemma_concat2 4 (LSeq.sub (LSeq.sub m 0 8) 0 4)
4 (LSeq.sub (LSeq.sub m 0 8) 4 4) (LSeq.sub m 0 8);
LSeq.eq_intro (LSeq.sub m 4 4) (LSeq.sub b 12 4);
LSeq.eq_intro (LSeq.sub m 0 4) (LSeq.sub b 12 4);
uints_from_bytes_be_nat_lemma #U32 #SEC #4 m;
assert (nat_from_bytes_be m ==
nat_from_bytes_be (LSeq.sub b 12 4) +
pow2 32 * (nat_from_bytes_be (LSeq.sub b 12 4) +
pow2 32 * (nat_from_bytes_be (LSeq.sub b 12 4) +
pow2 32 * nat_from_bytes_be (LSeq.sub b 12 4))));

let n = uints_from_bytes_be #U32 #SEC #4 b in
let i = LSeq.index n 3 in
let n' = LSeq.create4 i i i i in
unfold_nat_from_uint32_four n';
assert (nat_from_intseq_be n' ==
v (LSeq.index n' 3) + pow2 32 * (v (LSeq.index n' 2) +
pow2 32 * (v (LSeq.index n' 1) + pow2 32 * v (LSeq.index n' 0))));
assert (LSeq.index n' 3 == LSeq.index n 3);
assert (LSeq.index n' 2 == LSeq.index n 3);
assert (LSeq.index n' 1 == LSeq.index n 3);
assert (LSeq.index n' 0 == LSeq.index n 3);
index_uints_from_bytes_be #U32 #SEC #4 b 3;
lemma_reveal_uint_to_bytes_be #U32 #SEC (LSeq.sub b 12 4);
assert (nat_from_intseq_be n' ==
nat_from_bytes_be (LSeq.sub b 12 4) +
pow2 32 * (nat_from_bytes_be (LSeq.sub b 12 4) +
pow2 32 * (nat_from_bytes_be (LSeq.sub b 12 4) +
pow2 32 * nat_from_bytes_be (LSeq.sub b 12 4))));

let u = uints_to_bytes_be n' in
assert (nat_from_intseq_be n' == nat_from_bytes_be m);
lemma_nat_from_to_intseq_be_preserves_value #U32 #SEC 4 n';
assert (nat_to_intseq_be #U32 #SEC 4 (nat_from_intseq_be #U32 #SEC n') == n');
assert (u == uints_to_bytes_be #U32 #SEC #4 (nat_to_intseq_be #U32 #SEC 4 (nat_from_intseq_be #U32 #SEC n')));
assert (u == uints_to_bytes_be #U32 #SEC #4 (nat_to_intseq_be #U32 #SEC 4 (nat_from_bytes_be m)));
uints_to_bytes_be_nat_lemma #U32 #SEC 4 (nat_from_bytes_be m);
assert (u == nat_to_bytes_be 16 (nat_from_bytes_be m));
lemma_nat_from_to_bytes_be_preserves_value m 16;
assert (u == m)

let vec_create4_2_lemma b =
let m0 = LSeq.update_sub b 12 4 (LSeq.sub b 8 4) in
let m = LSeq.update_sub m0 0 8 (LSeq.sub m0 8 8) in
let m' = uints_from_bytes_be #U32 #SEC #4 m in
unfold_nat_from_uint32_four m';
assert (nat_from_intseq_be m' ==
v (LSeq.index m' 3) + pow2 32 * (v (LSeq.index m' 2) +
pow2 32 * (v (LSeq.index m' 1) + pow2 32 * v (LSeq.index m' 0))));
index_uints_from_bytes_be #U32 #SEC #4 m 3;
index_uints_from_bytes_be #U32 #SEC #4 m 2;
index_uints_from_bytes_be #U32 #SEC #4 m 1;
index_uints_from_bytes_be #U32 #SEC #4 m 0;
lemma_reveal_uint_to_bytes_be #U32 #SEC (LSeq.sub m 12 4);
lemma_reveal_uint_to_bytes_be #U32 #SEC (LSeq.sub m 8 4);
lemma_reveal_uint_to_bytes_be #U32 #SEC (LSeq.sub m 4 4);
lemma_reveal_uint_to_bytes_be #U32 #SEC (LSeq.sub m 0 4);
assert (v (LSeq.index m' 3) == nat_from_bytes_be (LSeq.sub m 12 4));
assert (v (LSeq.index m' 2) == nat_from_bytes_be (LSeq.sub m 8 4));
assert (v (LSeq.index m' 1) == nat_from_bytes_be (LSeq.sub m 4 4));
assert (v (LSeq.index m' 0) == nat_from_bytes_be (LSeq.sub m 0 4));
LSeq.eq_intro (LSeq.sub m 12 4) (LSeq.sub b 8 4);
LSeq.eq_intro (LSeq.sub m 8 4) (LSeq.sub b 8 4);
LSeq.lemma_concat2 4 (LSeq.sub (LSeq.sub m 0 8) 0 4)
4 (LSeq.sub (LSeq.sub m 0 8) 4 4) (LSeq.sub m 0 8);
LSeq.eq_intro (LSeq.sub m 4 4) (LSeq.sub b 8 4);
LSeq.eq_intro (LSeq.sub m 0 4) (LSeq.sub b 8 4);
uints_from_bytes_be_nat_lemma #U32 #SEC #4 m;
assert (nat_from_bytes_be m ==
nat_from_bytes_be (LSeq.sub b 8 4) +
pow2 32 * (nat_from_bytes_be (LSeq.sub b 8 4) +
pow2 32 * (nat_from_bytes_be (LSeq.sub b 8 4) +
pow2 32 * nat_from_bytes_be (LSeq.sub b 8 4))));

let n = uints_from_bytes_be #U32 #SEC #4 b in
let i = LSeq.index n 2 in
let n' = LSeq.create4 i i i i in
unfold_nat_from_uint32_four n';
assert (nat_from_intseq_be n' ==
v (LSeq.index n' 3) + pow2 32 * (v (LSeq.index n' 2) +
pow2 32 * (v (LSeq.index n' 1) + pow2 32 * v (LSeq.index n' 0))));
assert (LSeq.index n' 3 == LSeq.index n 2);
assert (LSeq.index n' 2 == LSeq.index n 2);
assert (LSeq.index n' 1 == LSeq.index n 2);
assert (LSeq.index n' 0 == LSeq.index n 2);
index_uints_from_bytes_be #U32 #SEC #4 b 2;
lemma_reveal_uint_to_bytes_be #U32 #SEC (LSeq.sub b 8 4);
assert (nat_from_intseq_be n' ==
nat_from_bytes_be (LSeq.sub b 8 4) +
pow2 32 * (nat_from_bytes_be (LSeq.sub b 8 4) +
pow2 32 * (nat_from_bytes_be (LSeq.sub b 8 4) +
pow2 32 * nat_from_bytes_be (LSeq.sub b 8 4))));

let u = uints_to_bytes_be n' in
assert (nat_from_intseq_be n' == nat_from_bytes_be m);
lemma_nat_from_to_intseq_be_preserves_value #U32 #SEC 4 n';
assert (nat_to_intseq_be #U32 #SEC 4 (nat_from_intseq_be #U32 #SEC n') == n');
assert (u == uints_to_bytes_be #U32 #SEC #4 (nat_to_intseq_be #U32 #SEC 4 (nat_from_intseq_be #U32 #SEC n')));
assert (u == uints_to_bytes_be #U32 #SEC #4 (nat_to_intseq_be #U32 #SEC 4 (nat_from_bytes_be m)));
uints_to_bytes_be_nat_lemma #U32 #SEC 4 (nat_from_bytes_be m);
assert (u == nat_to_bytes_be 16 (nat_from_bytes_be m));
lemma_nat_from_to_bytes_be_preserves_value m 16;
assert (u == m)

let uints_bytes_le_lemma b =
lemma_nat_from_to_intseq_le_preserves_value 1 (uints_from_bytes_le #U128 #SEC #1 b);
assert (nat_to_intseq_le 1 (nat_from_intseq_le (uints_from_bytes_le #U128 #SEC #1 b)) ==
Expand Down
16 changes: 16 additions & 0 deletions code/aes/Hacl.AES.CoreNI.Lemmas.fsti
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,22 @@ val vec_shift_l_32_lemma:
LSeq.update_sub p 0 12 (LSeq.sub (uints_to_bytes_be b) 4 12) ==
uints_to_bytes_be (LSeq.map (shift_left_i (size 32)) b))

val vec_create4_3_lemma:
b:LSeq.lseq uint8 16 ->
Lemma (
let i = LSeq.index (uints_from_bytes_be #U32 #SEC #4 b) 3 in
let b0 = LSeq.update_sub b 8 4 (LSeq.sub b 12 4) in
uints_to_bytes_be (LSeq.create4 i i i i) ==
LSeq.update_sub b0 0 8 (LSeq.sub b0 8 8))

val vec_create4_2_lemma:
b:LSeq.lseq uint8 16 ->
Lemma (
let i = LSeq.index (uints_from_bytes_be #U32 #SEC #4 b) 2 in
let b0 = LSeq.update_sub b 12 4 (LSeq.sub b 8 4) in
uints_to_bytes_be (LSeq.create4 i i i i) ==
LSeq.update_sub b0 0 8 (LSeq.sub b0 8 8))

val uints_bytes_le_lemma:
b:LSeq.lseq uint8 16 ->
Lemma (
Expand Down
6 changes: 6 additions & 0 deletions code/aes/Hacl.AES.Generic.Lemmas.fst
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,9 @@ let keys_to_bytes_lemma m a b i =
match m with
| MAES -> keys_to_bytes_ni_lemma m a b i
| M32 -> (assume (sub (keys_to_bytes m a b) (i*16) 16 == key_to_bytes m (sub b (i*(v (klen m))) (v (klen m)))))

let keyx_to_bytes_lemma m a b i =
admit()

let u8_16_to_le_lemma x =
eq_intro (Hacl.AES.Common.u8_16_to_le (Hacl.AES.Common.u8_16_to_le x)) x
11 changes: 11 additions & 0 deletions code/aes/Hacl.AES.Generic.Lemmas.fsti
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,14 @@ val keys_to_bytes_lemma:
-> b:lseq (stelem m) ((num_rounds a-1) * (v (klen m)))
-> i:nat{i < (num_rounds a-1)} ->
Lemma (sub (keys_to_bytes m a b) (i*16) 16 == key_to_bytes m (sub b (i * v (klen m)) (v (klen m))))

val keyx_to_bytes_lemma:
m:m_spec{m == MAES \/ m == M32}
-> a: variant
-> b:lseq (stelem m) ((num_rounds a+1) * (v (klen m)))
-> i:nat{i < (num_rounds a+1)} ->
Lemma (sub (keyx_to_bytes m a b) (i*16) 16 == key_to_bytes m (sub b (v (klen m) * i) (v (klen m))))

val u8_16_to_le_lemma:
x:lseq uint8 16 ->
Lemma (Hacl.AES.Common.u8_16_to_le (Hacl.AES.Common.u8_16_to_le x) == x)
16 changes: 4 additions & 12 deletions code/aes/Hacl.Impl.AES.Core.fst
Original file line number Diff line number Diff line change
Expand Up @@ -131,27 +131,19 @@ inline_for_extraction noextract
let aes_keygen_assist0 #m n p rcon =
match m with
| MAES -> Hacl.Impl.AES.CoreNI.aes_keygen_assist0 n p rcon
| M32 -> let h0 = ST.get() in Hacl.Impl.AES.CoreBitSlice.aes_keygen_assist0 n p rcon; let h1 = ST.get() in assume (let k = Spec.AES.aes_keygen_assist rcon (key_to_bytes m (as_seq h0 p)) in
let i = LSeq.index (uints_from_bytes_be #U32 #SEC #4 k) 3 in
key_to_bytes m (as_seq h1 n) == uints_to_bytes_be (LSeq.create4 i i i i))
| M32 -> let h0 = ST.get() in Hacl.Impl.AES.CoreBitSlice.aes_keygen_assist0 n p rcon; let h1 = ST.get() in assume (key_to_bytes m (as_seq h1 n) == Spec.AES.keygen_assist0 rcon (key_to_bytes m (as_seq h0 p)))

inline_for_extraction noextract
let aes_keygen_assist1 #m n p =
match m with
| MAES -> Hacl.Impl.AES.CoreNI.aes_keygen_assist1 n p
| M32 -> let h0 = ST.get() in Hacl.Impl.AES.CoreBitSlice.aes_keygen_assist1 n p; let h1 = ST.get() in assume (let k = Spec.AES.aes_keygen_assist (u8 0) (key_to_bytes m (as_seq h0 p)) in
let i = LSeq.index (uints_from_bytes_be #U32 #SEC #4 k) 2 in
key_to_bytes m (as_seq h1 n) == uints_to_bytes_be (LSeq.create4 i i i i))
| M32 -> let h0 = ST.get() in Hacl.Impl.AES.CoreBitSlice.aes_keygen_assist1 n p; let h1 = ST.get() in assume (key_to_bytes m (as_seq h1 n) == Spec.AES.keygen_assist1 (key_to_bytes m (as_seq h0 p)))

inline_for_extraction noextract
let key_expansion_step #m n p =
match m with
| MAES -> Hacl.Impl.AES.CoreNI.key_expansion_step n p
| M32 -> let h0 = ST.get() in Hacl.Impl.AES.CoreBitSlice.key_expansion_step n p; let h1 = ST.get() in assume (let p0 = LSeq.create 16 (u8 0) in
let p = key_to_bytes m (as_seq h0 p) in
| M32 -> let h0 = ST.get() in Hacl.Impl.AES.CoreBitSlice.key_expansion_step n p; let h1 = ST.get() in assume (let p = key_to_bytes m (as_seq h0 p) in
let n1 = key_to_bytes m (as_seq h0 n) in
let k = LSeq.map2 ( ^. ) p (LSeq.update_sub p0 0 12 (LSeq.sub p 4 12)) in
let k = LSeq.map2 ( ^. ) k (LSeq.update_sub p0 0 12 (LSeq.sub k 4 12)) in
let k = LSeq.map2 ( ^. ) k (LSeq.update_sub p0 0 12 (LSeq.sub k 4 12)) in
key_to_bytes m (as_seq h1 n) == LSeq.map2 ( ^. ) n1 k)
key_to_bytes m (as_seq h1 n) == Spec.AES.key_expansion_step p n1)

22 changes: 10 additions & 12 deletions code/aes/Hacl.Impl.AES.Core.fsti
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,12 @@ let keys_to_bytes (m:m_spec) (a:Spec.AES.variant) (b:LSeq.lseq (stelem m) ((Spec
| M32 -> LSeq.create ((Spec.AES.num_rounds a-1) * 16) (u8 0)
| MAES -> uints_to_bytes_be (LSeq.map (fun x -> LSeq.index (vec_v #U128 #1 x) 0) b)

unfold noextract
let keyx_to_bytes (m:m_spec) (a:Spec.AES.variant) (b:LSeq.lseq (stelem m) ((Spec.AES.num_rounds a+1) * v (klen m))) : LSeq.lseq uint8 ((Spec.AES.num_rounds a+1) * 16) =
match m with
| M32 -> LSeq.create ((Spec.AES.num_rounds a+1) * 16) (u8 0)
| MAES -> uints_to_bytes_be (LSeq.map (fun x -> LSeq.index (vec_v #U128 #1 x) 0) b)

unfold noextract
let state (m:m_spec) = lbuffer (stelem m) (stlen m)

Expand Down Expand Up @@ -266,9 +272,7 @@ val aes_keygen_assist0:
Stack unit
(requires (fun h -> live h ok /\ live h ik /\ disjoint ik ok))
(ensures (fun h0 _ h1 -> modifies1 ok h0 h1 /\
(let k = Spec.AES.aes_keygen_assist rcon (key_to_bytes m (as_seq h0 ik)) in
let i = LSeq.index (uints_from_bytes_be #U32 #SEC #4 k) 3 in
key_to_bytes m (as_seq h1 ok) == uints_to_bytes_be (LSeq.create4 i i i i))))
key_to_bytes m (as_seq h1 ok) == Spec.AES.keygen_assist0 rcon (key_to_bytes m (as_seq h0 ik))))

inline_for_extraction noextract
val aes_keygen_assist1:
Expand All @@ -278,9 +282,7 @@ val aes_keygen_assist1:
Stack unit
(requires (fun h -> live h ok /\ live h ik /\ disjoint ik ok))
(ensures (fun h0 _ h1 -> modifies1 ok h0 h1 /\
(let k = Spec.AES.aes_keygen_assist (u8 0) (key_to_bytes m (as_seq h0 ik)) in
let i = LSeq.index (uints_from_bytes_be #U32 #SEC #4 k) 2 in
key_to_bytes m (as_seq h1 ok) == uints_to_bytes_be (LSeq.create4 i i i i))))
key_to_bytes m (as_seq h1 ok) == Spec.AES.keygen_assist1 (key_to_bytes m (as_seq h0 ik))))

inline_for_extraction noextract
val key_expansion_step:
Expand All @@ -290,10 +292,6 @@ val key_expansion_step:
Stack unit
(requires (fun h -> live h prev /\ live h next))
(ensures (fun h0 _ h1 -> modifies1 next h0 h1 /\
(let p0 = LSeq.create 16 (u8 0) in
let p = key_to_bytes m (as_seq h0 prev) in
(let p = key_to_bytes m (as_seq h0 prev) in
let n = key_to_bytes m (as_seq h0 next) in
let k = LSeq.map2 ( ^. ) p (LSeq.update_sub p0 0 12 (LSeq.sub p 4 12)) in
let k = LSeq.map2 ( ^. ) k (LSeq.update_sub p0 0 12 (LSeq.sub k 4 12)) in
let k = LSeq.map2 ( ^. ) k (LSeq.update_sub p0 0 12 (LSeq.sub k 4 12)) in
key_to_bytes m (as_seq h1 next) == LSeq.map2 ( ^. ) n k)))
key_to_bytes m (as_seq h1 next) == Spec.AES.key_expansion_step p n)))
28 changes: 12 additions & 16 deletions code/aes/Hacl.Impl.AES.CoreNI.fst
Original file line number Diff line number Diff line change
Expand Up @@ -331,9 +331,8 @@ val aes_keygen_assist0:
Stack unit
(requires (fun h -> live h ok /\ live h ik))
(ensures (fun h0 _ h1 -> modifies1 ok h0 h1 /\
(let k = Spec.AES.aes_keygen_assist rcon (uints_to_bytes_be (vec_v (LSeq.index (as_seq h0 ik) 0))) in
let i = LSeq.index (uints_from_bytes_be #U32 #SEC #4 k) 3 in
uints_to_bytes_be (vec_v (LSeq.index (as_seq h1 ok) 0)) == uints_to_bytes_be (LSeq.create4 i i i i))))
uints_to_bytes_be (vec_v (LSeq.index (as_seq h1 ok) 0)) ==
Spec.AES.keygen_assist0 rcon (uints_to_bytes_be (vec_v (LSeq.index (as_seq h0 ik) 0)))))

let aes_keygen_assist0 ok ik rcon =
let h0 = ST.get() in
Expand All @@ -343,7 +342,8 @@ let aes_keygen_assist0 ok ik rcon =
vec_u8_16 (Spec.AES.aes_keygen_assist rcon (uints_to_bytes_be (vec_v (LSeq.index (as_seq h0 ik) 0))));
let k = Spec.AES.aes_keygen_assist rcon (uints_to_bytes_be (vec_v (LSeq.index (as_seq h0 ik) 0))) in
let i = LSeq.index (uints_from_bytes_be #U32 #SEC #4 k) 3 in
uints_bytes_be_lemma (uints_to_bytes_be (LSeq.create4 i i i i))
uints_bytes_be_lemma (uints_to_bytes_be (LSeq.create4 i i i i));
vec_create4_3_lemma k

inline_for_extraction noextract
val aes_keygen_assist1:
Expand All @@ -352,9 +352,8 @@ val aes_keygen_assist1:
Stack unit
(requires (fun h -> live h ok /\ live h ik))
(ensures (fun h0 _ h1 -> modifies1 ok h0 h1 /\
(let k = Spec.AES.aes_keygen_assist (u8 0) (uints_to_bytes_be (vec_v (LSeq.index (as_seq h0 ik) 0))) in
let i = LSeq.index (uints_from_bytes_be #U32 #SEC #4 k) 2 in
uints_to_bytes_be (vec_v (LSeq.index (as_seq h1 ok) 0)) == uints_to_bytes_be (LSeq.create4 i i i i))))
uints_to_bytes_be (vec_v (LSeq.index (as_seq h1 ok) 0)) ==
Spec.AES.keygen_assist1 (uints_to_bytes_be (vec_v (LSeq.index (as_seq h0 ik) 0)))))

let aes_keygen_assist1 ok ik =
let h0 = ST.get() in
Expand All @@ -364,7 +363,8 @@ let aes_keygen_assist1 ok ik =
vec_u8_16 (Spec.AES.aes_keygen_assist (u8 0) (uints_to_bytes_be (vec_v (LSeq.index (as_seq h0 ik) 0))));
let k = Spec.AES.aes_keygen_assist (u8 0) (uints_to_bytes_be (vec_v (LSeq.index (as_seq h0 ik) 0))) in
let i = LSeq.index (uints_from_bytes_be #U32 #SEC #4 k) 2 in
uints_bytes_be_lemma (uints_to_bytes_be (LSeq.create4 i i i i))
uints_bytes_be_lemma (uints_to_bytes_be (LSeq.create4 i i i i));
vec_create4_2_lemma k


inline_for_extraction noextract
Expand All @@ -374,21 +374,17 @@ val key_expansion_step:
Stack unit
(requires (fun h -> live h prev /\ live h next))
(ensures (fun h0 _ h1 -> modifies1 next h0 h1 /\
(let p0 = LSeq.create 16 (u8 0) in
let p = uints_to_bytes_be (vec_v (LSeq.index (as_seq h0 prev) 0)) in
(let p = uints_to_bytes_be (vec_v (LSeq.index (as_seq h0 prev) 0)) in
let n = uints_to_bytes_be (vec_v (LSeq.index (as_seq h0 next) 0)) in
let k = LSeq.map2 ( ^. ) p (LSeq.update_sub p0 0 12 (LSeq.sub p 4 12)) in
let k = LSeq.map2 ( ^. ) k (LSeq.update_sub p0 0 12 (LSeq.sub k 4 12)) in
let k = LSeq.map2 ( ^. ) k (LSeq.update_sub p0 0 12 (LSeq.sub k 4 12)) in
uints_to_bytes_be (vec_v (LSeq.index (as_seq h1 next) 0)) == LSeq.map2 ( ^. ) n k)))
uints_to_bytes_be (vec_v (LSeq.index (as_seq h1 next) 0)) == Spec.AES.key_expansion_step p n)))

let key_expansion_step next prev =
let h0 = ST.get() in
let key = prev.(size 0) in
let key = vec_xor key (vec_shift_left key (size 32)) in
let key = vec_xor key (vec_shift_left key (size 32)) in
let key = vec_xor key (vec_shift_left key (size 32)) in
next.(size 0) <- vec_xor next.(size 0) key;
next.(size 0) <- vec_xor key next.(size 0);
let p0 = LSeq.create 16 (u8 0) in
let p = vec_v (LSeq.index (as_seq h0 prev) 0) in
let n = vec_v (LSeq.index (as_seq h0 next) 0) in
Expand All @@ -401,4 +397,4 @@ let key_expansion_step next prev =
vec_shift_l_32_lemma p0 k;
vec_xor_lemma k (LSeq.map (shift_left_i (size 32)) k);
let k = LSeq.map2 ( ^. ) k (LSeq.map (shift_left_i (size 32)) k) in
vec_xor_lemma n k
vec_xor_lemma k n
Loading

0 comments on commit bcd96f0

Please sign in to comment.