Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add sis avx512 and fft avx512 for koalabear #622

Open
wants to merge 5 commits into
base: master
Choose a base branch
from

Conversation

gbotrel
Copy link
Collaborator

@gbotrel gbotrel commented Feb 2, 2025

Description

This PR adds optimized AVX512 impl for some FFT operations and SIS hash (512 degree/16bits bound) for koalabear.

Benchmark for FFT

BenchmarkFFTDITCosetReference-16     3759174       1195463       -68.20%
BenchmarkFFTDITCosetReference-16     3624691       1188837       -67.20%
BenchmarkFFTDIFReference-16          2873662       631355        -78.03%
BenchmarkFFTDIFReference-16          2859901       634403        -77.82%
BenchmarkFFTDIFReferenceSmall-16     13534         817           -93.97%
BenchmarkFFTDIFReferenceSmall-16     13611         850           -93.76%

Benchmark for SIS

BenchmarkSIS/ring-sis/inputs=1024/log2-bound=16/log2-degree=9-16      58031         3260          -94.38%
BenchmarkSIS/ring-sis/inputs=1024/log2-bound=16/log2-degree=9-16      57965         3261          -94.37%
BenchmarkSIS/ring-sis/inputs=2048/log2-bound=16/log2-degree=9-16      104782        5568          -94.69%
BenchmarkSIS/ring-sis/inputs=2048/log2-bound=16/log2-degree=9-16      104833        5567          -94.69%
BenchmarkSIS/ring-sis/inputs=4096/log2-bound=16/log2-degree=9-16      198348        10172         -94.87%
BenchmarkSIS/ring-sis/inputs=4096/log2-bound=16/log2-degree=9-16      198438        10184         -94.87%
BenchmarkSIS/ring-sis/inputs=8192/log2-bound=16/log2-degree=9-16      385710        19391         -94.97%
BenchmarkSIS/ring-sis/inputs=8192/log2-bound=16/log2-degree=9-16      385378        19400         -94.97%
BenchmarkSIS/ring-sis/inputs=16384/log2-bound=16/log2-degree=9-16     759139        37880         -95.01%
BenchmarkSIS/ring-sis/inputs=16384/log2-bound=16/log2-degree=9-16     758603        37856         -95.01%
BenchmarkSIS/ring-sis/inputs=32768/log2-bound=16/log2-degree=9-16     1505796       74779         -95.03%
BenchmarkSIS/ring-sis/inputs=32768/log2-bound=16/log2-degree=9-16     1506593       74821         -95.03%
BenchmarkSIS/ring-sis/inputs=65536/log2-bound=16/log2-degree=9-16     3005044       149647        -95.02%
BenchmarkSIS/ring-sis/inputs=65536/log2-bound=16/log2-degree=9-16     3008857       149504        -95.03%

benchmark                                                             old MB/s     new MB/s     speedup
BenchmarkSIS/ring-sis/inputs=1024/log2-bound=16/log2-degree=9-16      70.58        1256.40      17.80x
BenchmarkSIS/ring-sis/inputs=1024/log2-bound=16/log2-degree=9-16      70.66        1255.94      17.77x
BenchmarkSIS/ring-sis/inputs=2048/log2-bound=16/log2-degree=9-16      78.18        1471.32      18.82x
BenchmarkSIS/ring-sis/inputs=2048/log2-bound=16/log2-degree=9-16      78.14        1471.65      18.83x
BenchmarkSIS/ring-sis/inputs=4096/log2-bound=16/log2-degree=9-16      82.60        1610.63      19.50x
BenchmarkSIS/ring-sis/inputs=4096/log2-bound=16/log2-degree=9-16      82.56        1608.81      19.49x
BenchmarkSIS/ring-sis/inputs=8192/log2-bound=16/log2-degree=9-16      84.96        1689.85      19.89x
BenchmarkSIS/ring-sis/inputs=8192/log2-bound=16/log2-degree=9-16      85.03        1689.10      19.86x
BenchmarkSIS/ring-sis/inputs=16384/log2-bound=16/log2-degree=9-16     86.33        1730.10      20.04x
BenchmarkSIS/ring-sis/inputs=16384/log2-bound=16/log2-degree=9-16     86.39        1731.20      20.04x
BenchmarkSIS/ring-sis/inputs=32768/log2-bound=16/log2-degree=9-16     87.04        1752.80      20.14x
BenchmarkSIS/ring-sis/inputs=32768/log2-bound=16/log2-degree=9-16     87.00        1751.81      20.14x
BenchmarkSIS/ring-sis/inputs=65536/log2-bound=16/log2-degree=9-16     87.23        1751.75      20.08x
BenchmarkSIS/ring-sis/inputs=65536/log2-bound=16/log2-degree=9-16     87.12        1753.43      20.13x

Refer to the comments in element_vec_F31.go, in particular for SIS:

// this is a specialized unrolled SIS hash for degree = 512 and log2(bound) = 16
	// essentially, the "pure go" algorithm to hash(v), v being a vector of n elements, is as follows:
	// 1. process v in chunks of 256elements
	// 2. for each chunk of 256 elements, do the following:
	//		- load 256 elements from v
	//		- convert from montgomery to regular form
	//		- split the limbs into (k) 512 values; i.e. separate the uint32 into 2 uint16 (this is the "split" part)
	//		- do a FFT(k, fft.DIF, fft.OnCoset())
	//		- multiply by the Ag[i] (i being the chunk index corresponding to the Ag index)
	//		- accumulate the result in res
	//
	// the AVX512 version follows the same logic, but is unrolled and uses AVX512 instructions (duh.)
	// it minimize memory access, unrolls code when possible, leverage ILP (instruction level parallelism)
	// avoid doing reductions mod q when possible, and a flurry of other tricks.
	//
	// the algorithm is as follows:
	// 1. load 256 values from the chunk (uint32);
	// 		- note that we have 2 pointers; 1 at the beginning "x" and 1 at the middle "xm"
	// 		- this enables us to do the first stage of the FFT directly, and save in registers the first half
	// 		for the next stage of the FFT. the second half is stored on the stack.
	// 2. convert from montgomery to regular form
	// 3. split the limbs from the 256 values into 512 values
	// 4. multiply by cosets
	// 5. perform the FFT first stage (512)
	//		- that is butterfly and multiply by twiddles
	// 6. at the end of this first unrolled loop, we have the first 256 values in registers
	// and the second 256 values on the stack.
	// we still need to do the FFT on these 2 halves, then multiply by rag, and accumulate in res.
	//
	// The result is "shuffled", and before calling the FFT inverse, caller need to call sisUnshuffle.
	// This shuffling is due to the fact that we want to process elements in blocks of 16 (use a full AVX512 DWORD register)
	// And the FFT after a certain stage works with a stride of 8, then 4, then 2, then 1.
	// So we need to shuffle the elements; we could unshuffle them here, but since we accumulate our result in res
	// and this can be called multiple times, we prefer to amortize and do the unshuffle only once at the end.

@gbotrel gbotrel requested a review from ivokub February 2, 2025 18:36
@ivokub
Copy link
Collaborator

ivokub commented Feb 4, 2025

Suggested edit:

diff --git a/internal/generator/gkr/template/gkr.test.vectors.gen.go.tmpl b/internal/generator/gkr/template/gkr.test.vectors.gen.go.tmpl
index 71f0d4835..7bf9cea03 100644
--- a/internal/generator/gkr/template/gkr.test.vectors.gen.go.tmpl
+++ b/internal/generator/gkr/template/gkr.test.vectors.gen.go.tmpl
@@ -1,19 +1,19 @@
 import (
 	"encoding/json"
 	"fmt"
+	"hash"
+	"os"
+	"path/filepath"
+	"reflect"
+
+	"github.com/consensys/bavard"
 	fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir"
-    
 	"github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational"
 	"github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/gkr"
 	"github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/polynomial"
 	"github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/sumcheck"
 	"github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/test_vector_utils"
-	"os"
-	"path/filepath"
-	"reflect"
-	"hash"
 
-    "github.com/consensys/bavard"
 )
 
 func main() {
diff --git a/internal/generator/gkr/test_vectors/main.go b/internal/generator/gkr/test_vectors/main.go
index 37e62d4d5..96ed2b453 100644
--- a/internal/generator/gkr/test_vectors/main.go
+++ b/internal/generator/gkr/test_vectors/main.go
@@ -8,19 +8,18 @@ package main
 import (
 	"encoding/json"
 	"fmt"
-	fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir"
-
-	"github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational"
-	"github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/gkr"
-	"github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/polynomial"
-	"github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/sumcheck"
-	"github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/test_vector_utils"
 	"hash"
 	"os"
 	"path/filepath"
 	"reflect"
 
 	"github.com/consensys/bavard"
+	fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir"
+	"github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational"
+	"github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/gkr"
+	"github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/polynomial"
+	"github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/sumcheck"
+	"github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/test_vector_utils"
 )
 
 func main() {
diff --git a/internal/generator/main.go b/internal/generator/main.go
index 6038a7129..cf7ed49a6 100644
--- a/internal/generator/main.go
+++ b/internal/generator/main.go
@@ -7,8 +7,6 @@ import (
 	"path/filepath"
 	"sync"
 
-	"github.com/consensys/gnark-crypto/internal/generator/mpcsetup"
-
 	"github.com/consensys/bavard"
 	"github.com/consensys/gnark-crypto/field/generator"
 	fieldConfig "github.com/consensys/gnark-crypto/field/generator/config"
@@ -21,11 +19,11 @@ import (
 	"github.com/consensys/gnark-crypto/internal/generator/edwards/eddsa"
 	"github.com/consensys/gnark-crypto/internal/generator/fflonk"
 	fri "github.com/consensys/gnark-crypto/internal/generator/fri/template"
-
 	"github.com/consensys/gnark-crypto/internal/generator/gkr"
 	"github.com/consensys/gnark-crypto/internal/generator/hash_to_field"
 	"github.com/consensys/gnark-crypto/internal/generator/iop"
 	"github.com/consensys/gnark-crypto/internal/generator/kzg"
+	"github.com/consensys/gnark-crypto/internal/generator/mpcsetup"
 	"github.com/consensys/gnark-crypto/internal/generator/pairing"
 	"github.com/consensys/gnark-crypto/internal/generator/pedersen"
 	"github.com/consensys/gnark-crypto/internal/generator/permutation"
diff --git a/internal/generator/test_vector_utils/generate.go b/internal/generator/test_vector_utils/generate.go
index f91f30069..216e8307d 100644
--- a/internal/generator/test_vector_utils/generate.go
+++ b/internal/generator/test_vector_utils/generate.go
@@ -5,7 +5,6 @@ import (
 
 	"github.com/consensys/bavard"
 	"github.com/consensys/gnark-crypto/internal/generator/config"
-
 	"github.com/consensys/gnark-crypto/internal/generator/gkr"
 	"github.com/consensys/gnark-crypto/internal/generator/polynomial"
 	"github.com/consensys/gnark-crypto/internal/generator/sumcheck"
diff --git a/internal/generator/tower/generate.go b/internal/generator/tower/generate.go
index 0bdbefd73..734e86aa7 100644
--- a/internal/generator/tower/generate.go
+++ b/internal/generator/tower/generate.go
@@ -7,7 +7,6 @@ import (
 
 	"github.com/consensys/bavard"
 	"github.com/consensys/gnark-crypto/internal/generator/config"
-
 	"github.com/consensys/gnark-crypto/internal/generator/tower/asm/amd64"
 )
 

Copy link
Collaborator

@ivokub ivokub left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me, didn't see anything concerning on Go side. Didn't check assembly correctness though, but seems to be compared against non-assembly implementation and also fuzzed, so I think it is good.

Only a few suggested edits about import ordering (as there were already changes)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants