Skip to content

8349721: Add aarch64 intrinsics for ML-KEM #23663

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 14 commits into from

Conversation

ferakocz
Copy link
Contributor

@ferakocz ferakocz commented Feb 17, 2025

By using the aarch64 vector registers the speed of the computation of the ML-KEM algorithms (key generation, encapsulation, decapsulation) can be approximately doubled.


Progress

  • Change must be properly reviewed (1 review required, with at least 1 Reviewer)
  • Change must not contain extraneous whitespace
  • Commit message must refer to an issue

Issue

  • JDK-8349721: Add aarch64 intrinsics for ML-KEM (Enhancement - P3)

Reviewers

Reviewing

Using git

Checkout this PR locally:
$ git fetch https://git.openjdk.org/jdk.git pull/23663/head:pull/23663
$ git checkout pull/23663

Update a local copy of the PR:
$ git checkout pull/23663
$ git pull https://git.openjdk.org/jdk.git pull/23663/head

Using Skara CLI tools

Checkout this PR locally:
$ git pr checkout 23663

View PR using the GUI difftool:
$ git pr show -t 23663

Using diff file

Download this PR as a diff file:
https://git.openjdk.org/jdk/pull/23663.diff

Using Webrev

Link to Webrev Comment

@bridgekeeper
Copy link

bridgekeeper bot commented Feb 17, 2025

👋 Welcome back ferakocz! A progress list of the required criteria for merging this PR into master will be added to the body of your pull request. There are additional pull request commands available for use with this pull request.

@openjdk
Copy link

openjdk bot commented Feb 17, 2025

@ferakocz This change now passes all automated pre-integration checks.

ℹ️ This project also has non-automated pre-integration requirements. Please see the file CONTRIBUTING.md for details.

After integration, the commit message for the final commit will be:

8349721: Add aarch64 intrinsics for ML-KEM

Reviewed-by: adinn

You can use pull request commands such as /summary, /contributor and /issue to adjust it as needed.

At the time when this comment was updated there had been 425 new commits pushed to the master branch:

As there are no conflicts, your changes will automatically be rebased on top of these commits when integrating. If you prefer to avoid this automatic rebasing, please check the documentation for the /integrate command for further details.

As you do not have Committer status in this project an existing Committer must agree to sponsor your change. Possible candidates are the reviewers of this PR (@adinn) but any other Committer may sponsor as well.

➡️ To flag this PR as ready for integration with the above commit message, type /integrate in a new comment. (Afterwards, your sponsor types /sponsor in a new comment to perform the integration).

@openjdk
Copy link

openjdk bot commented Feb 17, 2025

@ferakocz The following labels will be automatically applied to this pull request:

  • graal
  • hotspot
  • security

When this pull request is ready to be reviewed, an "RFR" email will be sent to the corresponding mailing lists. If you would like to change these labels, use the /label pull request command.

@openjdk openjdk bot added the rfr Pull request is ready for review label Feb 17, 2025
@mlbridge
Copy link

mlbridge bot commented Feb 17, 2025

@mcpowers
Copy link
Contributor

ML-KEM benchmark results of this PR:

MLKEM.decapsulate  512 11.80 us/op
MLKEM.decapsulate  768 18.19 us/op 
MLKEM.decapsulate 1024 29.57 us/op
MLKEM.encapsulate  512  8.80 us/op 
MLKEM.encapsulate  768 13.49 us/op  
MLKEM.encapsulate 1024 22.53 us/op
MLKEM.keygen       512  7.49 us/op 
MLKEM.keygen       768 11.22 us/op  
MLKEM.keygen      1024 19.08 us/op 

ML-KEM no intrinsics

MLKEM.decapsulate  512 31.23 us/op  
MLKEM.decapsulate  768 50.09 us/op 
MLKEM.decapsulate 1024 75.92 us/op
MLKEM.encapsulate  512 22.72 us/op  
MLKEM.encapsulate  768 37.27 us/op 
MLKEM.encapsulate 1024 59.69 us/op
MLKEM.keygen       512 17.95 us/op  
MLKEM.keygen       768 30.95 us/op 
MLKEM.keygen      1024 49.04 us/op

address stubAddr;
const char *stubName;
assert(UseKyberIntrinsics, "need Kyber intrinsics support");
assert(callee()->signature()->size() == 3, "kyber12To16 has 3 parameters");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just as an aside this causes testing of a debug build to fail. The intrinsic has 4 parameters.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With this value reset to 4 the ML_DSA test passes for ML_KEM on a debug build.

@openjdk
Copy link

openjdk bot commented Mar 22, 2025

@ferakocz this pull request can not be integrated into master due to one or more merge conflicts. To resolve these merge conflicts and update this pull request you can run the following commands in the local repository for your personal fork:

git checkout mlkem-aarch64-intrinsics
git fetch https://git.openjdk.org/jdk.git master
git merge FETCH_HEAD
# resolve conflicts and follow the instructions given by git merge
git commit -m "Merge master"
git push

@openjdk openjdk bot added the merge-conflict Pull request has merge conflict with target branch label Mar 22, 2025
@openjdk openjdk bot removed the merge-conflict Pull request has merge conflict with target branch label Mar 23, 2025
Copy link
Contributor

@adinn adinn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ferakocz Thanks for another very good piece of work which appears to me to be functioning correctly and performantly.

The PR suffers from the same problems as the original ML_DSA one i.e. The mapping of data to registers and the overall structure of the generated code and its relation to the related Java code/the original algorithms will be hard for a maintainer to identify. I have reworked your patch to use vector sequences in this draft PR in very much the same way as was done for the ML_DSA PR. This has significantly abstracted and clarified the register mappings that are in use in each kyber generator and has also made the higher level structure of the generated code much easier to follow.

Note that my rework of the generation routines was applied to your original PR after rebasing it on master. Before updating the kyber routines I also generalized a few of the VSeq methods that benefit from being shared by both kyber and dilithium, most notably the montmul routines, and I added a few extra helpers.

The reworked version passes the ML_KEM functional test and gives similar performance improvements for the ML_KEM micro benchmark. The generated code does differ in a few places from what your original patch generates but only superficially - most notable is that a few loads/stores that rely on continued post-increments in the original instead use a constant offset or an add/load pair in the reworked code. This makes a very minor difference to code size and does not seem to affect performance.

I would like you to rework your PR to incorporate these changes because I believe it will make a big difference to maintainability. n.b. it may be easier to integrate my changes by diffing your branch and mine and applying the resulting change set rather than trying to merge the changes. Please let me know if you have problems with the integration and need help.

I still have some further review comments and would also like to see more commenting to explain what the code is doing. However, I think it will be easier to do that after this rework has been integrated into your PR.

// load 96 (6 x 16B) byte values
vs_ld3_post(vin, __ T16B, condensed);

// expand groups of input bytes in vin to shorts in va and vb
Copy link
Contributor

@adinn adinn Apr 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I's like to expand on the data layouts here so that maintenance engineers don't have to work it out every time they look at it. So, I would like to replace this comment as follows

// The front half of sequence vin (vin[0], vin[1] and vin[2])
// holds 48 (16x3) contiguous bytes from memory striped
// horizontally across each of the 16 byte lanes. Equivalently,
// that is 16 pairs of 12-bit integers. Likewise the back half
// holds the next 48 bytes in the same arrangement.

// Each vector in the front half can also be viewed as a vertical
// strip across the 16 pairs of 12 bit integers. Each byte in
// vin[0] stores the low 8 bits of the first int in a pair. Each
// byte in vin[1] stores the high 4 bits of the first int and the
// low 4 bits of the second int. Each byte in vin[2] stores the
// high 8 bits of the second int. Likewise the vectors in second
// half.

// Converting the data to 16-bit shorts requires first of all
// expanding each of the 6 x 16B vectors into 6 corresponding
// pairs of 8H vectors. Mask, shift and add operations on the
// resulting vector pairs can be used to combine 4 and 8 bit
// parts of related 8H vector elements.
//
// The middle vectors (vin[2] and vin[5]) are actually expanded
// twice, one copy manipulated to provide the lower 4 bits
// belonging to the first short in a pair and another copy
// manipulated to provide the higher 4 bits belonging to the
// second short in a pair. This is why the the vector sequences va
// and vb used to hold the expanded 8H elements are of length 8.

// Expand vin[0] into va[0:1], and vin[1] into va[2:3] and va[4:5]

__ ushll(vb[4], __ T8H, vin[4], __ T8B, 0);
__ ushll2(vb[5], __ T8H, vin[4], __ T16B, 0);

// offset duplicated elements in va and vb by 8
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To make this clearer it should say

// shift lo byte of copy 1 of the middle stripe into the high byte

__ shl(vb[2], __ T8H, vb[2], 8);
__ shl(vb[3], __ T8H, vb[3], 8);

// expand remaining input bytes in vin to shorts in va and vb
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To make this clearer it should say

// Expand vin[2] into va[6:7] and vin[5] into vb[6:7] but this
// time pre-shifted by 4 to ensure top bits of input 12-bit int
// are in bit positions [4..11].

__ ushll(vb[6], __ T8H, vin[5], __ T8B, 4);
__ ushll2(vb[7], __ T8H, vin[5], __ T16B, 4);

// split the duplicated 8 bit values into two distinct 4 bit
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To make this clearer it should say

// mask hi 4 bits of the 1st 12-bit int in a pair from copy1 and
// shift lo 4 bits of the 2nd 12-bit int in a pair to the bottom of
// copy2

__ ushr(vb[4], __ T8H, vb[4], 4);
__ ushr(vb[5], __ T8H, vb[5], 4);

// sum resulting short values into the front halves of va and
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be replaced to clarify details of the ordering for summing and grouping

// sum hi 4 bits and lo 8 bits of the 1st 12-bit int in each pair and
// hi 8 bits plus lo 4 bits of the 2nd 12-bit int in each pair

// n.b. the ordering ensures: i) inputs are consumed before they
// are overwritten ii) the order of 16-bit results across successive
// pairs of vectors in va and then vb reflects the order of the
// corresponding 12-bit inputs

Comment on lines 5672 to 5674
// montmul the first and second pair of values loaded into vs4
// in order and then with one pair reversed storing the two
// results in vs1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// montmul the first and second pair of values loaded into vs4
// in order and then with one pair reversed storing the two
// results in vs1
// compute 4 montmul cross-products for pairs (a2,a3) and (b2,b3)
// i.e. montmul the first and second halves of vs4 in order and
// then with one sequence reversed storing the two results in vs1
//
// vs1[0] <- montmul(a2, b2)
// vs1[1] <- montmul(a3, b3)
// vs1[2] <- montmul(a2, b3)
// vs1[3] <- montmul(a3, b2)

Comment on lines 5678 to 5680
// for each pair of results pick the second value in the first
// pair to create a sequence that we montmul by the zetas
// i.e. we want sequence <vs3[1], vs1[1]>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// for each pair of results pick the second value in the first
// pair to create a sequence that we montmul by the zetas
// i.e. we want sequence <vs3[1], vs1[1]>
// montmul result 2 of each cross-product i.e. (a1*b1, a3*b3) by a zeta.
// We can schedule two montmuls at a time if we use a suitable vector
// sequence <vs3[1], vs1[1]>.

// i.e. we want sequence <vs3[1], vs1[1]>
int delta = vs1[1]->encoding() - vs3[1]->encoding();
VSeq<2> vs5(vs3[1], delta);
kyber_montmul16(vs5, vz, vs5, vs_front(vs2), vq);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
kyber_montmul16(vs5, vz, vs5, vs_front(vs2), vq);
// vs3[1] <- montmul(montmul(a1, b1), z0)
// vs1[1] <- montmul(montmul(a3, b3), z1)
kyber_montmul16(vs5, vz, vs5, vs_front(vs2), vq);

int delta = vs1[1]->encoding() - vs3[1]->encoding();
VSeq<2> vs5(vs3[1], delta);
kyber_montmul16(vs5, vz, vs5, vs_front(vs2), vq);
// add results in pairs storing in vs3
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// add results in pairs storing in vs3
// add results in pairs storing in vs3
// vs3[0] <- montmul(a0, b0) + montmul(montmul(a1, b1), z0);
// vs3[1] <- montmul(a0, b1) + montmul(a1, b0);

kyber_montmul16(vs5, vz, vs5, vs_front(vs2), vq);
// add results in pairs storing in vs3
vs_addv(vs_front(vs3), __ T8H, vs_even(vs3), vs_odd(vs3));
vs_addv(vs_back(vs3), __ T8H, vs_even(vs1), vs_odd(vs1));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
vs_addv(vs_back(vs3), __ T8H, vs_even(vs1), vs_odd(vs1));
// vs3[2] <- montmul(a2, b2) + montmul(montmul(a3, b3), z1);
// vs3[3] <- montmul(a2, b3) + montmul(a3, b2);
vs_addv(vs_back(vs3), __ T8H, vs_even(vs1), vs_odd(vs1));

// add results in pairs storing in vs3
vs_addv(vs_front(vs3), __ T8H, vs_even(vs3), vs_odd(vs3));
vs_addv(vs_back(vs3), __ T8H, vs_even(vs1), vs_odd(vs1));
// montmul result by constant vc and store result in vs1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// montmul result by constant vc and store result in vs1
// vs1 <- montmul(vs3, montRSquareModQ)

Comment on lines 5689 to 5690
// store the four results as two interleaved pairs of
// quadwords
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// store the four results as two interleaved pairs of
// quadwords
// store back the two pairs of result vectors de-interleaved as 8H elements
// i.e. storing each pairs of shorts striped across a register pair adjacent
// in memory

@adinn
Copy link
Contributor

adinn commented Apr 15, 2025

@ferakocz

Hi Ferenc,

Sorry, but I still had a few comments to add to the KyberNTTMult routine to clarify exactly how the load, compute and store operations relate to the original Java source. That's the only remaining code that I felt needed further clarification for maintainers. So, after you work through them I can approve the PR.

@ferakocz
Copy link
Contributor Author

@ferakocz

Hi Ferenc,

Sorry, but I still had a few comments to add to the KyberNTTMult routine to clarify exactly how the load, compute and store operations relate to the original Java source. That's the only remaining code that I felt needed further clarification for maintainers. So, after you work through them I can approve the PR.

No problem , it was easy to make the changes. Thanks again!

@adinn
Copy link
Contributor

adinn commented Apr 15, 2025

@ferakocz I reran test jtreg:test/jdk/sun/security/provider/acvp/Launcher.java and hit a Java assertion:

>> ML-KEM-512 encapsulation
1 STDERR:
java.lang.AssertionError
	at java.base/com.sun.crypto.provider.ML_KEM.twelve2Sixteen(ML_KEM.java:1371)
	at java.base/com.sun.crypto.provider.ML_KEM.decodePoly(ML_KEM.java:1408)
	at java.base/com.sun.crypto.provider.ML_KEM.decodeVector(ML_KEM.java:1337)
	at java.base/com.sun.crypto.provider.ML_KEM.kPkeEncrypt(ML_KEM.java:712)
	at java.base/com.sun.crypto.provider.ML_KEM.encapsulate(ML_KEM.java:555)
	at java.base/com.sun.crypto.provider.ML_KEM_Impls$K.implEncapsulate(ML_KEM_Impls.java:134)
	at java.base/sun.security.provider.NamedKEM$KeyConsumerImpl.engineEncapsulate(NamedKEM.java:124)
	at java.base/javax.crypto.KEM$Encapsulator.encapsulate(KEM.java:265)
	at java.base/javax.crypto.KEM$Encapsulator.encapsulate(KEM.java:225)
	at ML_KEM_Test.encapDecapTest(ML_KEM_Test.java:98)
	at ML_KEM_Test.run(ML_KEM_Test.java:41)
	at Launcher.run(Launcher.java:160)
	at Launcher.main(Launcher.java:122)
	at java.base/jdk.internal.reflect.DirectMethodHandleAccessor.invoke(DirectMethodHandleAccessor.java:104)
	at java.base/java.lang.reflect.Method.invoke(Method.java:565)
	at com.sun.javatest.regtest.agent.MainActionHelper$AgentVMRunnable.run(MainActionHelper.java:335)
	at java.base/java.lang.Thread.run(Thread.java:1447)

JavaTest Message: Test threw exception: java.lang.AssertionError
JavaTest Message: shutting down test

The offending code is this:

private void twelve2Sixteen(byte[] condensed, int index,
                            short[] parsed, int parsedLength) {
    int i = parsedLength / 64;
    int remainder = parsedLength - i * 64;
    if (remainder != 0) {
        i++;
    }
    assert (((remainder != 0) && (remainder != 48)) || <== assert here
        index + i * 96 > condensed.length);

I believe the logic is reversed here i.e. it should be:

    assert ((remainder == 0) || (remainder == 48)) &&
        index + i * 96 <= condensed.length);

Does that sound right?

@ferakocz
Copy link
Contributor Author

@ferakocz I reran test jtreg:test/jdk/sun/security/provider/acvp/Launcher.java and hit a Java assertion:

>> ML-KEM-512 encapsulation
1 STDERR:
java.lang.AssertionError
	at java.base/com.sun.crypto.provider.ML_KEM.twelve2Sixteen(ML_KEM.java:1371)
	at java.base/com.sun.crypto.provider.ML_KEM.decodePoly(ML_KEM.java:1408)
	at java.base/com.sun.crypto.provider.ML_KEM.decodeVector(ML_KEM.java:1337)
	at java.base/com.sun.crypto.provider.ML_KEM.kPkeEncrypt(ML_KEM.java:712)
	at java.base/com.sun.crypto.provider.ML_KEM.encapsulate(ML_KEM.java:555)
	at java.base/com.sun.crypto.provider.ML_KEM_Impls$K.implEncapsulate(ML_KEM_Impls.java:134)
	at java.base/sun.security.provider.NamedKEM$KeyConsumerImpl.engineEncapsulate(NamedKEM.java:124)
	at java.base/javax.crypto.KEM$Encapsulator.encapsulate(KEM.java:265)
	at java.base/javax.crypto.KEM$Encapsulator.encapsulate(KEM.java:225)
	at ML_KEM_Test.encapDecapTest(ML_KEM_Test.java:98)
	at ML_KEM_Test.run(ML_KEM_Test.java:41)
	at Launcher.run(Launcher.java:160)
	at Launcher.main(Launcher.java:122)
	at java.base/jdk.internal.reflect.DirectMethodHandleAccessor.invoke(DirectMethodHandleAccessor.java:104)
	at java.base/java.lang.reflect.Method.invoke(Method.java:565)
	at com.sun.javatest.regtest.agent.MainActionHelper$AgentVMRunnable.run(MainActionHelper.java:335)
	at java.base/java.lang.Thread.run(Thread.java:1447)

JavaTest Message: Test threw exception: java.lang.AssertionError
JavaTest Message: shutting down test

The offending code is this:

private void twelve2Sixteen(byte[] condensed, int index,
                            short[] parsed, int parsedLength) {
    int i = parsedLength / 64;
    int remainder = parsedLength - i * 64;
    if (remainder != 0) {
        i++;
    }
    assert (((remainder != 0) && (remainder != 48)) || <== assert here
        index + i * 96 > condensed.length);

I believe the logic is reversed here i.e. it should be:

    assert ((remainder == 0) || (remainder == 48)) &&
        index + i * 96 <= condensed.length);

Does that sound right?

Aarrrrgh, yes. I forgot to negate that condition when I went from throwing an exception to assert, and I also thought, incorrectly, that -ea would enable my assertions when I tested :-( .
Thanks a lot for catching it!

Copy link
Contributor

@adinn adinn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ferakocz I reran the test and the perf test and all appears to be good. Nice work!

@openjdk openjdk bot added the ready Pull request is ready to be integrated label Apr 16, 2025
@ferakocz
Copy link
Contributor Author

@ferakocz I reran the test and the perf test and all appears to be good. Nice work!

@adinn Thanks, Andrew, for all the help, it really looks nicer than it looked before your review! Would you /sponsor the integration?

@adinn
Copy link
Contributor

adinn commented Apr 16, 2025

/sponsor

@openjdk
Copy link

openjdk bot commented Apr 16, 2025

@adinn The change author (@ferakocz) must issue an integrate command before the integration can be sponsored.

@ferakocz
Copy link
Contributor Author

/integrate

@adinn
Copy link
Contributor

adinn commented Apr 16, 2025

/sponsor

@openjdk openjdk bot added the sponsor Pull request is ready to be sponsored label Apr 16, 2025
@openjdk
Copy link

openjdk bot commented Apr 16, 2025

@ferakocz
Your change (at version 3c3bca6) is now ready to be sponsored by a Committer.

@openjdk
Copy link

openjdk bot commented Apr 16, 2025

Going to push as commit 465c8e6.
Since your change was applied there have been 425 commits pushed to the master branch:

Your commit was automatically rebased without conflicts.

@openjdk openjdk bot added the integrated Pull request has been integrated label Apr 16, 2025
@openjdk openjdk bot closed this Apr 16, 2025
@openjdk openjdk bot removed ready Pull request is ready to be integrated rfr Pull request is ready for review sponsor Pull request is ready to be sponsored labels Apr 16, 2025
@openjdk
Copy link

openjdk bot commented Apr 16, 2025

@adinn @ferakocz Pushed as commit 465c8e6.

💡 You may see a message that your pull request was closed with unmerged commits. This can be safely ignored.

@@ -703,6 +714,7 @@ void VM_Version::initialize_cpu_information(void) {
get_compatible_board(_cpu_desc + desc_len, CPU_DETAILED_DESC_BUF_SIZE - desc_len);
desc_len = (int)strlen(_cpu_desc);
snprintf(_cpu_desc + desc_len, CPU_DETAILED_DESC_BUF_SIZE - desc_len, " %s", _features_string);
fprintf(stderr, "_features_string = \"%s\"", _features_string);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was this line added by mistake? Looks like a leftover.

Copy link
Contributor Author

@ferakocz ferakocz Apr 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@iwanowww Oooops, yes. Addressed in #24717 together with another forgotten (though not strictly necessary) change. Thanks for catching it!
Could you review that (really short) PR?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ferakocz That link has the right label but leads to the wrong place. It should point here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@iwanowww Thanks for the heads up!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Development

Successfully merging this pull request may close these issues.

4 participants