8349721: Add aarch64 intrinsics for ML-KEM #23663
Conversation
👋 Welcome back ferakocz! A progress list of the required criteria for merging this PR will be added to the body of your pull request.
@ferakocz This change now passes all automated pre-integration checks. ℹ️ This project also has non-automated pre-integration requirements. Please see the file CONTRIBUTING.md for details. After integration, the commit message for the final commit will be:
You can use pull request commands such as /summary, /contributor and /issue to adjust it as needed. At the time when this comment was updated there had been 425 new commits pushed to the master branch.
As there are no conflicts, your changes will automatically be rebased on top of these commits when integrating. If you prefer to avoid this automatic rebasing, please check the documentation for the /integrate command for further details. As you do not have Committer status in this project an existing Committer must agree to sponsor your change. Possible candidates are the reviewers of this PR (@adinn) but any other Committer may sponsor as well. ➡️ To flag this PR as ready for integration with the above commit message, type /integrate in a new comment.
Webrevs
ML-KEM benchmark results of this PR:
ML-KEM no intrinsics
address stubAddr;
const char *stubName;
assert(UseKyberIntrinsics, "need Kyber intrinsics support");
assert(callee()->signature()->size() == 3, "kyber12To16 has 3 parameters");
Just as an aside, this causes testing of a debug build to fail: the intrinsic has 4 parameters.
With this value reset to 4 the ML_DSA test passes for ML_KEM on a debug build.
@ferakocz this pull request can not be integrated due to one or more merge conflicts. To resolve them, run the following commands in your local branch:
git checkout mlkem-aarch64-intrinsics
git fetch https://git.openjdk.org/jdk.git master
git merge FETCH_HEAD
# resolve conflicts and follow the instructions given by git merge
git commit -m "Merge master"
git push
@ferakocz Thanks for another very good piece of work which appears to me to be functioning correctly and performantly.
The PR suffers from the same problems as the original ML_DSA one, i.e. the mapping of data to registers, the overall structure of the generated code, and its relation to the related Java code and the original algorithms will be hard for a maintainer to identify. I have reworked your patch to use vector sequences in this draft PR, in very much the same way as was done for the ML_DSA PR. This has significantly abstracted and clarified the register mappings in use in each kyber generator and has also made the higher-level structure of the generated code much easier to follow.
Note that my rework of the generation routines was applied to your original PR after rebasing it on master. Before updating the kyber routines I also generalized a few of the VSeq methods that benefit from being shared by both kyber and dilithium, most notably the montmul routines, and I added a few extra helpers.
The reworked version passes the ML_KEM functional test and gives similar performance improvements for the ML_KEM micro benchmark. The generated code does differ in a few places from what your original patch generates but only superficially - most notable is that a few loads/stores that rely on continued post-increments in the original instead use a constant offset or an add/load pair in the reworked code. This makes a very minor difference to code size and does not seem to affect performance.
I would like you to rework your PR to incorporate these changes because I believe it will make a big difference to maintainability. n.b. it may be easier to integrate my changes by diffing your branch and mine and applying the resulting change set rather than trying to merge the changes. Please let me know if you have problems with the integration and need help.
I still have some further review comments and would also like to see more commenting to explain what the code is doing. However, I think it will be easier to do that after this rework has been integrated into your PR.
// load 96 (6 x 16B) byte values
vs_ld3_post(vin, __ T16B, condensed);
// expand groups of input bytes in vin to shorts in va and vb |
I'd like to expand on the data layouts here so that maintenance engineers don't have to work it out every time they look at it. So, I would like to replace this comment as follows:
// The front half of sequence vin (vin[0], vin[1] and vin[2])
// holds 48 (16x3) contiguous bytes from memory striped
// horizontally across each of the 16 byte lanes. Equivalently,
// that is 16 pairs of 12-bit integers. Likewise the back half
// holds the next 48 bytes in the same arrangement.
// Each vector in the front half can also be viewed as a vertical
// strip across the 16 pairs of 12 bit integers. Each byte in
// vin[0] stores the low 8 bits of the first int in a pair. Each
// byte in vin[1] stores the high 4 bits of the first int and the
// low 4 bits of the second int. Each byte in vin[2] stores the
// high 8 bits of the second int. Likewise the vectors in second
// half.
// Converting the data to 16-bit shorts requires first of all
// expanding each of the 6 x 16B vectors into 6 corresponding
// pairs of 8H vectors. Mask, shift and add operations on the
// resulting vector pairs can be used to combine 4 and 8 bit
// parts of related 8H vector elements.
//
// The middle vectors (vin[2] and vin[5]) are actually expanded
// twice, one copy manipulated to provide the lower 4 bits
// belonging to the first short in a pair and another copy
// manipulated to provide the higher 4 bits belonging to the
// second short in a pair. This is why the vector sequences va
// and vb used to hold the expanded 8H elements are of length 8.
// Expand vin[0] into va[0:1], and vin[1] into va[2:3] and va[4:5]
__ ushll(vb[4], __ T8H, vin[4], __ T8B, 0);
__ ushll2(vb[5], __ T8H, vin[4], __ T16B, 0);
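As an illustration of the layout described in the proposed comment, here is a hypothetical scalar Java sketch of the same 12-bit-to-16-bit expansion. The names are illustrative, not the JDK's; the intrinsic processes 16 pairs at a time with vector masks, shifts and adds, while this does one pair per loop iteration.

```java
public class Kyber12To16Sketch {
    // Scalar equivalent of the expansion: every 3 input bytes hold a
    // little-endian pair of 12-bit integers.
    static short[] decode12(byte[] in) {
        short[] out = new short[in.length / 3 * 2];
        for (int i = 0, j = 0; i < in.length; i += 3, j += 2) {
            int b0 = in[i] & 0xff;     // low 8 bits of the 1st int (the vin[0] role)
            int b1 = in[i + 1] & 0xff; // hi 4 of 1st int + lo 4 of 2nd (the vin[1] role)
            int b2 = in[i + 2] & 0xff; // high 8 bits of the 2nd int (the vin[2] role)
            out[j]     = (short) (b0 | ((b1 & 0x0f) << 8));
            out[j + 1] = (short) ((b1 >> 4) | (b2 << 4));
        }
        return out;
    }

    public static void main(String[] args) {
        // the pair (0xabc, 0x123) packs to bytes 0xbc, 0x3a, 0x12
        short[] r = decode12(new byte[]{(byte) 0xbc, (byte) 0x3a, (byte) 0x12});
        System.out.println(Integer.toHexString(r[0]) + " " + Integer.toHexString(r[1]));
    }
}
```

The mask `b1 & 0x0f` and the shift `b1 >> 4` are the scalar counterparts of the two differently-manipulated copies of the middle stripe described above.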
// offset duplicated elements in va and vb by 8 |
To make this clearer it should say
// shift lo byte of copy 1 of the middle stripe into the high byte
__ shl(vb[2], __ T8H, vb[2], 8);
__ shl(vb[3], __ T8H, vb[3], 8);
// expand remaining input bytes in vin to shorts in va and vb |
To make this clearer it should say
// Expand vin[2] into va[6:7] and vin[5] into vb[6:7] but this
// time pre-shifted by 4 to ensure top bits of input 12-bit int
// are in bit positions [4..11].
__ ushll(vb[6], __ T8H, vin[5], __ T8B, 4);
__ ushll2(vb[7], __ T8H, vin[5], __ T16B, 4);
// split the duplicated 8 bit values into two distinct 4 bit |
To make this clearer it should say
// mask hi 4 bits of the 1st 12-bit int in a pair from copy1 and
// shift lo 4 bits of the 2nd 12-bit int in a pair to the bottom of
// copy2
__ ushr(vb[4], __ T8H, vb[4], 4);
__ ushr(vb[5], __ T8H, vb[5], 4);
// sum resulting short values into the front halves of va and |
This should be replaced to clarify details of the ordering for summing and grouping
// sum hi 4 bits and lo 8 bits of the 1st 12-bit int in each pair and
// hi 8 bits plus lo 4 bits of the 2nd 12-bit int in each pair
// n.b. the ordering ensures: i) inputs are consumed before they
// are overwritten ii) the order of 16-bit results across successive
// pairs of vectors in va and then vb reflects the order of the
// corresponding 12-bit inputs
// montmul the first and second pair of values loaded into vs4
// in order and then with one pair reversed storing the two
// results in vs1
// montmul the first and second pair of values loaded into vs4
// in order and then with one pair reversed storing the two
// results in vs1
// compute 4 montmul cross-products for pairs (a2,a3) and (b2,b3)
// i.e. montmul the first and second halves of vs4 in order and
// then with one sequence reversed storing the two results in vs1
//
// vs1[0] <- montmul(a2, b2)
// vs1[1] <- montmul(a3, b3)
// vs1[2] <- montmul(a2, b3)
// vs1[3] <- montmul(a3, b2)
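For readers unfamiliar with the montmul operation these comments refer to, here is a hedged scalar Java sketch of Kyber-style Montgomery multiplication (q = 3329, R = 2^16). Constant and method names are illustrative; this is not the JDK code, whose vector version computes 8 or 16 of these per instruction sequence.

```java
public class MontmulSketch {
    static final int Q = 3329;      // the ML-KEM modulus
    static final int QINV = 62209;  // Q^-1 mod 2^16

    // Montgomery reduction: returns r with r == a * 2^-16 (mod Q), |r| < Q
    static short montgomeryReduce(int a) {
        short t = (short) (a * QINV);       // a * Q^-1 mod 2^16, as a signed low half
        return (short) ((a - t * Q) >> 16); // exact division by 2^16
    }

    // montmul(a, b) == a * b * 2^-16 (mod Q)
    static short montmul(short a, short b) {
        return montgomeryReduce((int) a * b);
    }

    public static void main(String[] args) {
        short r = montmul((short) 5, (short) 7);
        // check the defining congruence: r * 2^16 == 5 * 7 (mod Q)
        System.out.println(Math.floorMod(((long) r << 16) - 35, Q) == 0);
    }
}
```

The extra factor of 2^-16 is why inputs are kept in Montgomery form (pre-multiplied by 2^16 mod Q) and why the code below montmuls by the constant montRSquareModQ to convert back.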
// for each pair of results pick the second value in the first
// pair to create a sequence that we montmul by the zetas
// i.e. we want sequence <vs3[1], vs1[1]> |
// for each pair of results pick the second value in the first
// pair to create a sequence that we montmul by the zetas
// i.e. we want sequence <vs3[1], vs1[1]>
// montmul result 2 of each cross-product i.e. (a1*b1, a3*b3) by a zeta.
// We can schedule two montmuls at a time if we use a suitable vector
// sequence <vs3[1], vs1[1]>. |
// i.e. we want sequence <vs3[1], vs1[1]>
int delta = vs1[1]->encoding() - vs3[1]->encoding();
VSeq<2> vs5(vs3[1], delta);
kyber_montmul16(vs5, vz, vs5, vs_front(vs2), vq); |
kyber_montmul16(vs5, vz, vs5, vs_front(vs2), vq);
// vs3[1] <- montmul(montmul(a1, b1), z0)
// vs1[1] <- montmul(montmul(a3, b3), z1)
kyber_montmul16(vs5, vz, vs5, vs_front(vs2), vq); |
int delta = vs1[1]->encoding() - vs3[1]->encoding();
VSeq<2> vs5(vs3[1], delta);
kyber_montmul16(vs5, vz, vs5, vs_front(vs2), vq);
// add results in pairs storing in vs3 |
// add results in pairs storing in vs3
// add results in pairs storing in vs3
// vs3[0] <- montmul(a0, b0) + montmul(montmul(a1, b1), z0);
// vs3[1] <- montmul(a0, b1) + montmul(a1, b0);
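The pair-summing pattern in these comments is the ML-KEM base multiplication (a0 + a1·X)(b0 + b1·X) mod (X² − ζ). A hedged scalar sketch in plain modular arithmetic (no Montgomery form, names illustrative rather than the JDK's):

```java
public class BasemulSketch {
    static final int Q = 3329; // the ML-KEM modulus

    // (a0 + a1*X) * (b0 + b1*X) mod (X^2 - zeta), coefficients mod Q.
    // X^2 folds to the constant zeta, so the degree-0 result gains
    // the a1*b1*zeta cross-term the comments above describe.
    static int[] basemul(int a0, int a1, int b0, int b1, int zeta) {
        int c0 = (a0 * b0 + (a1 * b1 % Q) * zeta) % Q;
        int c1 = (a0 * b1 + a1 * b0) % Q;
        return new int[]{c0, c1};
    }

    public static void main(String[] args) {
        int[] c = basemul(1, 2, 3, 4, 17);
        System.out.println(c[0] + " " + c[1]); // (3 + 8*17) % Q and (4 + 6) % Q
    }
}
```

The intermediate reduction `a1 * b1 % Q` keeps the product within `int` range before the multiply by zeta; the vector code instead stays in Montgomery form throughout, which is where the extra montmul-by-zeta step comes from.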
kyber_montmul16(vs5, vz, vs5, vs_front(vs2), vq);
// add results in pairs storing in vs3
vs_addv(vs_front(vs3), __ T8H, vs_even(vs3), vs_odd(vs3));
vs_addv(vs_back(vs3), __ T8H, vs_even(vs1), vs_odd(vs1)); |
vs_addv(vs_back(vs3), __ T8H, vs_even(vs1), vs_odd(vs1));
// vs3[2] <- montmul(a2, b2) + montmul(montmul(a3, b3), z1);
// vs3[3] <- montmul(a2, b3) + montmul(a3, b2);
vs_addv(vs_back(vs3), __ T8H, vs_even(vs1), vs_odd(vs1)); |
// add results in pairs storing in vs3
vs_addv(vs_front(vs3), __ T8H, vs_even(vs3), vs_odd(vs3));
vs_addv(vs_back(vs3), __ T8H, vs_even(vs1), vs_odd(vs1));
// montmul result by constant vc and store result in vs1 |
// montmul result by constant vc and store result in vs1
// vs1 <- montmul(vs3, montRSquareModQ) |
// store the four results as two interleaved pairs of
// quadwords |
// store the four results as two interleaved pairs of
// quadwords
// store back the two pairs of result vectors de-interleaved as 8H elements
// i.e. storing each pair of shorts striped across a register pair adjacent
// in memory
Hi Ferenc, Sorry, but I still had a few comments to add to the KyberNTTMult routine to clarify exactly how the load, compute and store operations relate to the original Java source. That's the only remaining code that I felt needed further clarification for maintainers. So, after you work through them I can approve the PR.
No problem, it was easy to make the changes. Thanks again!
@ferakocz I reran test jtreg:test/jdk/sun/security/provider/acvp/Launcher.java and hit a Java assertion:
The offending code is this:
I believe the logic is reversed here i.e. it should be:
Does that sound right?
Aarrrrgh, yes. I forgot to negate that condition when I went from throwing an exception to an assert, and I also thought, incorrectly, that -ea would enable my assertions when I tested :-(
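The bug described here is easy to reproduce in miniature: when a `throw` guard is converted to an `assert`, the condition must be negated, because `throw` fires on the failure condition while `assert` states the success condition. A hypothetical Java illustration (not the actual test code):

```java
public class AssertNegationSketch {
    // original style: throw when the condition is VIOLATED
    static void checkByThrow(boolean ok) {
        if (!ok) throw new IllegalStateException("check failed");
    }

    // converted style: assert states the condition that must HOLD,
    // so the thrown condition has to be negated during conversion
    static void checkByAssert(boolean ok) {
        assert ok : "check failed"; // only enforced when assertions are enabled
    }

    public static void main(String[] args) {
        checkByThrow(true);  // passes silently
        checkByAssert(true); // passes silently
        boolean threw = false;
        try {
            checkByThrow(false);
        } catch (IllegalStateException e) {
            threw = true;
        }
        System.out.println(threw);
    }
}
```

As for the second pitfall: plain `-ea` enables assertions only in application classes; assertions in classes loaded by the bootstrap loader (such as JDK library code) need `-esa` or a targeted `-ea:<package>...` flag, which is likely why they appeared enabled but were not.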
@ferakocz I reran the test and the perf test and all appears to be good. Nice work!
/sponsor
/integrate
/sponsor
Going to push as commit 465c8e6.
Your commit was automatically rebased without conflicts.
@@ -703,6 +714,7 @@ void VM_Version::initialize_cpu_information(void) {
  get_compatible_board(_cpu_desc + desc_len, CPU_DETAILED_DESC_BUF_SIZE - desc_len);
  desc_len = (int)strlen(_cpu_desc);
  snprintf(_cpu_desc + desc_len, CPU_DETAILED_DESC_BUF_SIZE - desc_len, " %s", _features_string);
fprintf(stderr, "_features_string = \"%s\"", _features_string); |
Was this line added by mistake? Looks like a leftover.
@iwanowww Thanks for the heads up!
By using the aarch64 vector registers, the speed of the ML-KEM computations (key generation, encapsulation, decapsulation) can be approximately doubled.
Reviewing
Using
git
Checkout this PR locally:
$ git fetch https://git.openjdk.org/jdk.git pull/23663/head:pull/23663
$ git checkout pull/23663
Update a local copy of the PR:
$ git checkout pull/23663
$ git pull https://git.openjdk.org/jdk.git pull/23663/head
Using Skara CLI tools
Checkout this PR locally:
$ git pr checkout 23663
View PR using the GUI difftool:
$ git pr show -t 23663
Using diff file
Download this PR as a diff file:
https://git.openjdk.org/jdk/pull/23663.diff