Skip to content

Introduce a new feature to use dynamic dispatch to select between ADX and portable implementations at runtime #174

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions bindings/go/assembly.S
Original file line number Diff line number Diff line change
@@ -5,19 +5,19 @@
# define blst_sha256_block_data_order blst_sha256_block_ssse3
# endif
# include "elf/sha256-x86_64.s"
# if defined(__ADX__) || defined(__BLST_PORTABLE__)
# if defined(__ADX__) || defined(__BLST_PORTABLE__) || defined(__BLST_DYNAMIC__)
# include "elf/ctx_inverse_mod_384-x86_64.s"
# endif
# if !defined(__ADX__) || defined(__BLST_PORTABLE__)
# if !defined(__ADX__) || defined(__BLST_PORTABLE__) || defined(__BLST_DYNAMIC__)
# include "elf/ctq_inverse_mod_384-x86_64.s"
# endif
# include "elf/add_mod_384-x86_64.s"
# include "elf/add_mod_384x384-x86_64.s"
# if defined(__ADX__) || defined(__BLST_PORTABLE__)
# if defined(__ADX__) || defined(__BLST_PORTABLE__) || defined(__BLST_DYNAMIC__)
# include "elf/mulx_mont_384-x86_64.s"
# include "elf/mulx_mont_256-x86_64.s"
# endif
# if !defined(__ADX__) || defined(__BLST_PORTABLE__)
# if !defined(__ADX__) || defined(__BLST_PORTABLE__) || defined(__BLST_DYNAMIC__)
# include "elf/mulq_mont_384-x86_64.s"
# include "elf/mulq_mont_256-x86_64.s"
# endif
@@ -27,19 +27,19 @@
# include "elf/ct_is_square_mod_384-x86_64.s"
# elif defined(_WIN64) || defined(__CYGWIN__)
# include "coff/sha256-x86_64.s"
# if defined(__ADX__) || defined(__BLST_PORTABLE__)
# if defined(__ADX__) || defined(__BLST_PORTABLE__) || defined(__BLST_DYNAMIC__)
# include "coff/ctx_inverse_mod_384-x86_64.s"
# endif
# if !defined(__ADX__) || defined(__BLST_PORTABLE__)
# if !defined(__ADX__) || defined(__BLST_PORTABLE__) || defined(__BLST_DYNAMIC__)
# include "coff/ctq_inverse_mod_384-x86_64.s"
# endif
# include "coff/add_mod_384-x86_64.s"
# include "coff/add_mod_384x384-x86_64.s"
# if defined(__ADX__) || defined(__BLST_PORTABLE__)
# if defined(__ADX__) || defined(__BLST_PORTABLE__) || defined(__BLST_DYNAMIC__)
# include "coff/mulx_mont_384-x86_64.s"
# include "coff/mulx_mont_256-x86_64.s"
# endif
# if !defined(__ADX__) || defined(__BLST_PORTABLE__)
# if !defined(__ADX__) || defined(__BLST_PORTABLE__) || defined(__BLST_DYNAMIC__)
# include "coff/mulq_mont_384-x86_64.s"
# include "coff/mulq_mont_256-x86_64.s"
# endif
@@ -49,19 +49,19 @@
# include "coff/ct_is_square_mod_384-x86_64.s"
# elif defined(__APPLE__)
# include "mach-o/sha256-x86_64.s"
# if defined(__ADX__) || defined(__BLST_PORTABLE__)
# if defined(__ADX__) || defined(__BLST_PORTABLE__) || defined(__BLST_DYNAMIC__)
# include "mach-o/ctx_inverse_mod_384-x86_64.s"
# endif
# if !defined(__ADX__) || defined(__BLST_PORTABLE__)
# if !defined(__ADX__) || defined(__BLST_PORTABLE__) || defined(__BLST_DYNAMIC__)
# include "mach-o/ctq_inverse_mod_384-x86_64.s"
# endif
# include "mach-o/add_mod_384-x86_64.s"
# include "mach-o/add_mod_384x384-x86_64.s"
# if defined(__ADX__) || defined(__BLST_PORTABLE__)
# if defined(__ADX__) || defined(__BLST_PORTABLE__) || defined(__BLST_DYNAMIC__)
# include "mach-o/mulx_mont_384-x86_64.s"
# include "mach-o/mulx_mont_256-x86_64.s"
# endif
# if !defined(__ADX__) || defined(__BLST_PORTABLE__)
# if !defined(__ADX__) || defined(__BLST_PORTABLE__) || defined(__BLST_DYNAMIC__)
# include "mach-o/mulq_mont_384-x86_64.s"
# include "mach-o/mulq_mont_256-x86_64.s"
# endif
3 changes: 3 additions & 0 deletions bindings/rust/Cargo.toml
Original file line number Diff line number Diff line change
@@ -33,6 +33,9 @@ portable = []
# Enable ADX even if the host CPU doesn't support it.
# Binary can be executed on Broadwell+ and Ryzen+ systems.
force-adx = []
# Compile with dynamic dispatch, detecting CPU features at runtime.
# Binary can be executed on all x86_64 systems; ignored on other targets.
dynamic = []
# Suppress multi-threading.
# Engaged on wasm32 target architecture automatically.
no-threads = []
30 changes: 24 additions & 6 deletions bindings/rust/build.rs
Original file line number Diff line number Diff line change
@@ -124,20 +124,29 @@ fn main() {
} else {
cc.define("__BLST_NO_ASM__", None);
}
match (cfg!(feature = "portable"), cfg!(feature = "force-adx")) {
(true, false) => {
match (cfg!(feature = "portable"), cfg!(feature = "force-adx"), cfg!(feature = "dynamic")) {
(true, false, false) => {
println!("Compiling in portable mode without ISA extensions");
cc.define("__BLST_PORTABLE__", None);
}
(false, true) => {
(false, true, false) => {
if target_arch.eq("x86_64") {
println!("Enabling ADX support via `force-adx` feature");
cc.define("__ADX__", None);
} else {
println!("`force-adx` is ignored for non-x86_64 targets");
}
}
(false, false) => {
(false, false, true) => {
if target_arch.eq("x86_64") {
println!("Enabling dynamic dispatch support via `dynamic` feature");
cc.define("__ADX__", None);
cc.define("__BLST_DYNAMIC__", None);
} else {
println!("`dynamic` is ignored for non-x86_64 targets");
}
}
(false, false, false) => {
if target_arch.eq("x86_64") {
// If target-cpu is specified on the rustc command line,
// then obey the resulting target-features.
@@ -170,8 +179,17 @@ fn main() {
}
}
}
(true, true) => panic!(
"Cannot compile with both `portable` and `force-adx` features"
(true, true, false) => panic!(
"Cannot compile with both `portable` and `force-adx`"
),
(true, false, true) => panic!(
"Cannot compile with both `portable` and `dynamic`"
),
(false, true, true) => panic!(
"Cannot compile with both `force-adx` and `dynamic`"
),
(true, true, true) => panic!(
"Cannot compile with `portable`, `force-adx` and `dynamic`"
),
}
if env::var("CARGO_CFG_TARGET_ENV").unwrap().eq("msvc") {
24 changes: 12 additions & 12 deletions build/assembly.S
Original file line number Diff line number Diff line change
@@ -5,19 +5,19 @@
# define blst_sha256_block_data_order blst_sha256_block_ssse3
# endif
# include "elf/sha256-x86_64.s"
# if defined(__ADX__) || defined(__BLST_PORTABLE__)
# if defined(__ADX__) || defined(__BLST_PORTABLE__) || defined(__BLST_DYNAMIC__)
# include "elf/ctx_inverse_mod_384-x86_64.s"
# endif
# if !defined(__ADX__) || defined(__BLST_PORTABLE__)
# if !defined(__ADX__) || defined(__BLST_PORTABLE__) || defined(__BLST_DYNAMIC__)
# include "elf/ctq_inverse_mod_384-x86_64.s"
# endif
# include "elf/add_mod_384-x86_64.s"
# include "elf/add_mod_384x384-x86_64.s"
# if defined(__ADX__) || defined(__BLST_PORTABLE__)
# if defined(__ADX__) || defined(__BLST_PORTABLE__) || defined(__BLST_DYNAMIC__)
# include "elf/mulx_mont_384-x86_64.s"
# include "elf/mulx_mont_256-x86_64.s"
# endif
# if !defined(__ADX__) || defined(__BLST_PORTABLE__)
# if !defined(__ADX__) || defined(__BLST_PORTABLE__) || defined(__BLST_DYNAMIC__)
# include "elf/mulq_mont_384-x86_64.s"
# include "elf/mulq_mont_256-x86_64.s"
# endif
@@ -27,19 +27,19 @@
# include "elf/ct_is_square_mod_384-x86_64.s"
# elif defined(_WIN64) || defined(__CYGWIN__)
# include "coff/sha256-x86_64.s"
# if defined(__ADX__) || defined(__BLST_PORTABLE__)
# if defined(__ADX__) || defined(__BLST_PORTABLE__) || defined(__BLST_DYNAMIC__)
# include "coff/ctx_inverse_mod_384-x86_64.s"
# endif
# if !defined(__ADX__) || defined(__BLST_PORTABLE__)
# if !defined(__ADX__) || defined(__BLST_PORTABLE__) || defined(__BLST_DYNAMIC__)
# include "coff/ctq_inverse_mod_384-x86_64.s"
# endif
# include "coff/add_mod_384-x86_64.s"
# include "coff/add_mod_384x384-x86_64.s"
# if defined(__ADX__) || defined(__BLST_PORTABLE__)
# if defined(__ADX__) || defined(__BLST_PORTABLE__) || defined(__BLST_DYNAMIC__)
# include "coff/mulx_mont_384-x86_64.s"
# include "coff/mulx_mont_256-x86_64.s"
# endif
# if !defined(__ADX__) || defined(__BLST_PORTABLE__)
# if !defined(__ADX__) || defined(__BLST_PORTABLE__) || defined(__BLST_DYNAMIC__)
# include "coff/mulq_mont_384-x86_64.s"
# include "coff/mulq_mont_256-x86_64.s"
# endif
@@ -49,19 +49,19 @@
# include "coff/ct_is_square_mod_384-x86_64.s"
# elif defined(__APPLE__)
# include "mach-o/sha256-x86_64.s"
# if defined(__ADX__) || defined(__BLST_PORTABLE__)
# if defined(__ADX__) || defined(__BLST_PORTABLE__) || defined(__BLST_DYNAMIC__)
# include "mach-o/ctx_inverse_mod_384-x86_64.s"
# endif
# if !defined(__ADX__) || defined(__BLST_PORTABLE__)
# if !defined(__ADX__) || defined(__BLST_PORTABLE__) || defined(__BLST_DYNAMIC__)
# include "mach-o/ctq_inverse_mod_384-x86_64.s"
# endif
# include "mach-o/add_mod_384-x86_64.s"
# include "mach-o/add_mod_384x384-x86_64.s"
# if defined(__ADX__) || defined(__BLST_PORTABLE__)
# if defined(__ADX__) || defined(__BLST_PORTABLE__) || defined(__BLST_DYNAMIC__)
# include "mach-o/mulx_mont_384-x86_64.s"
# include "mach-o/mulx_mont_256-x86_64.s"
# endif
# if !defined(__ADX__) || defined(__BLST_PORTABLE__)
# if !defined(__ADX__) || defined(__BLST_PORTABLE__) || defined(__BLST_DYNAMIC__)
# include "mach-o/mulq_mont_384-x86_64.s"
# include "mach-o/mulq_mont_256-x86_64.s"
# endif
15 changes: 15 additions & 0 deletions src/cpuid.c
Original file line number Diff line number Diff line change
@@ -4,8 +4,17 @@
* SPDX-License-Identifier: Apache-2.0
*/

#include "cpuid.h"

int __blst_platform_cap = 0;

/*
 * Non-zero once __blst_cpuid() has run; set by __blst_cpuid_run_once().
 */
int __blst_platform_cap_initialized = 0;

/*
 * Run-once guard meant for the top of __blst_cpuid(): returns 0 from the
 * enclosing function when the flag is already set, otherwise sets the flag
 * and falls through to the detection code.
 *
 * Wrapped in do { } while (0) so that the invocation together with its
 * trailing ';' forms exactly one statement and stays correct under an
 * unbraced if/else.
 *
 * NOTE(review): the flag is a plain int with no synchronization - confirm
 * that concurrent first calls cannot happen.
 */
#define __blst_cpuid_run_once() \
    do { \
        if (__blst_platform_cap_initialized) \
            return 0; \
        __blst_platform_cap_initialized = 1; \
    } while (0)

#if defined(__x86_64__) || defined(__x86_64) || defined(_M_X64)

# if defined(__GNUC__) || defined(__clang__) || defined(__SUNPRO_C)
@@ -30,6 +39,8 @@ __attribute__((constructor))
# endif
int __blst_cpuid(void)
{
__blst_cpuid_run_once();

int info[4], cap = 0;

__cpuidex(info, 0, 0);
@@ -59,6 +70,8 @@ extern unsigned long getauxval(unsigned long type) __attribute__ ((weak));
__attribute__((constructor))
int __blst_cpuid(void)
{
__blst_cpuid_run_once();

int cap = 0;

if (getauxval) {
@@ -74,6 +87,8 @@ int __blst_cpuid(void)
__attribute__((constructor))
int __blst_cpuid()
{
__blst_cpuid_run_once();

__blst_platform_cap = 1; /* SHA256 */
return 0;
}
3 changes: 3 additions & 0 deletions src/cpuid.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#ifndef __BLST_CPUID_H__
#define __BLST_CPUID_H__

/*
 * Platform capability bits filled in by __blst_cpuid().
 * Declared `extern` here; the single definition lives in cpuid.c, so
 * including this header from several translation units does not create
 * duplicate definitions of the variable.
 */
extern int __blst_platform_cap;

/*
 * Detects the platform capabilities and stores them in
 * __blst_platform_cap. Returns 0.
 */
int __blst_cpuid(void);

#endif
94 changes: 94 additions & 0 deletions src/ifunc.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
#include "cpuid.h"

/**
 * The resolver defined below will run early during load, when some features
 * like the stack protector may not be fully set up yet, so it is built
 * without stack-protector instrumentation where the compiler allows it.
 *
 * The dedicated `no_stack_protector` function attribute is preferred whenever
 * the compiler reports it through __has_attribute. GCC's `optimize` attribute
 * is kept only as a fallback for real GCC: clang also defines __GNUC__ but
 * does not implement `optimize`, so there it would be ignored with a warning.
 * On any other compiler the macro expands to nothing.
 */
#if defined(__has_attribute)
# if __has_attribute(__no_stack_protector__)
#  define no_stack_protector __attribute__((__no_stack_protector__))
# endif
#endif
#if !defined(no_stack_protector)
# if defined(__GNUC__) && !defined(__clang__)
#  define no_stack_protector \
          __attribute__((__optimize__("-fno-stack-protector")))
# else
#  define no_stack_protector
# endif
#endif

/*
 * Optional tracing of resolver decisions. Without __BLST_DYNAMIC_DEBUG__ the
 * macro expands to nothing; with it, the message is printed through puts().
 */
#ifndef __BLST_DYNAMIC_DEBUG__
# define ifunc_resolver_debug(msg)
#else
# include <stdio.h>
# define ifunc_resolver_debug(msg) puts(msg)
#endif

/**
 * Emits `resolve_<fn>`, a static helper that selects the implementation to
 * use for `fn` at runtime. The helper calls __blst_cpuid() and then returns
 * `portable_fn` when bit 0 of __blst_platform_cap is clear, `optimized_fn`
 * when it is set (the bit that signals ADX availability, per the original
 * author's description of this macro).
 *
 * The `<fn>_func_t` typedef must be in scope where the macro is expanded.
 */
#define ifunc_resolver(fn, portable_fn, optimized_fn) \
    no_stack_protector \
    static fn##_func_t *resolve_##fn(void) { \
        __blst_cpuid(); \
        if ((__blst_platform_cap & 1) == 0) { \
            ifunc_resolver_debug("portable: " #fn " -> " #portable_fn); \
            return portable_fn; \
        } \
        ifunc_resolver_debug("optimized: " #fn " -> " #optimized_fn); \
        return optimized_fn; \
    }

/**
 * Defines an "indirect function" (ifunc) which is dynamically resolved when
 * blst is loaded depending on whether the ADX instruction set is supported or
 * not.
 *
 * Example:
 *
 *     ifunc(dynamic_fn, portable_fn, optimized_fn,
 *           int, short x, long y, char z);
 *
 * This example would (roughly) generate the following declarations:
 *
 *     int dynamic_fn(short x, long y, char z);
 *     int portable_fn(short x, long y, char z);
 *     int optimized_fn(short x, long y, char z);
 *
 * The special symbol `dynamic_fn` will be assigned to either `portable_fn` or
 * `optimized_fn` at load time, and can be called at low cost at runtime.
 *
 * Every variant also emits the `<fn>_func_t` typedef and the static
 * `resolve_<fn>` helper produced by ifunc_resolver(). In the two
 * function-pointer variants `fn` is a pointer variable rather than a
 * function, and it is only assigned once `resolve_and_store_<fn>` has run.
 *
 * Compilers matching none of the variants are rejected with #error instead of
 * leaving `ifunc` undefined.
 */
#if defined(__GNUC__) && defined(__ELF__)
/* On GCC/clang using the ELF standard; use `__attribute__((ifunc))` */
# define ifunc(fn, portable_fn, optimized_fn, return_type, ...) \
    typedef return_type (fn##_func_t)(__VA_ARGS__); \
    return_type fn(__VA_ARGS__); \
    return_type portable_fn(__VA_ARGS__); \
    return_type optimized_fn(__VA_ARGS__); \
    ifunc_resolver(fn, portable_fn, optimized_fn) \
    return_type fn(__VA_ARGS__) __attribute__((ifunc("resolve_" #fn)));
#elif defined(__GNUC__)
/* On GCC/clang with a generic loader; use function pointers and
 * `__attribute__((constructor))` */
# define ifunc(fn, portable_fn, optimized_fn, return_type, ...) \
    typedef return_type (fn##_func_t)(__VA_ARGS__); \
    return_type (*fn)(__VA_ARGS__); \
    return_type portable_fn(__VA_ARGS__); \
    return_type optimized_fn(__VA_ARGS__); \
    ifunc_resolver(fn, portable_fn, optimized_fn) \
    __attribute__((constructor)) \
    no_stack_protector \
    static void resolve_and_store_##fn(void) { \
        fn = resolve_##fn(); \
    }
#elif defined(_MSC_VER)
/* On MSVC; use function pointers and add an entry to the `.CRT$XCU` section */
# pragma section(".CRT$XCU",read)
# define ifunc(fn, portable_fn, optimized_fn, return_type, ...) \
    typedef return_type (fn##_func_t)(__VA_ARGS__); \
    return_type (*fn)(__VA_ARGS__); \
    return_type portable_fn(__VA_ARGS__); \
    return_type optimized_fn(__VA_ARGS__); \
    ifunc_resolver(fn, portable_fn, optimized_fn) \
    static void resolve_and_store_##fn(void) { \
        fn = resolve_##fn(); \
    } \
    __declspec(allocate(".CRT$XCU")) static void \
        (*__resolve_and_store_##fn)(void) = resolve_and_store_##fn;
#else
# error "ifunc.h: no indirect-function mechanism is implemented for this compiler"
#endif
144 changes: 103 additions & 41 deletions src/vect.h
Original file line number Diff line number Diff line change
@@ -66,11 +66,105 @@ typedef byte pow256[256/8];
*/
typedef limb_t bool_t;

/*
 * Configuration checks for runtime dispatch: __BLST_DYNAMIC__ may not be
 * combined with __BLST_PORTABLE__ or __BLST_NO_ASM__, and it needs __ADX__
 * to be defined as well.
 */
#ifdef __BLST_DYNAMIC__
# ifdef __BLST_PORTABLE__
#  error "__BLST_DYNAMIC__ and __BLST_PORTABLE__ cannot be specified at the same time"
# endif
# ifdef __BLST_NO_ASM__
#  error "__BLST_DYNAMIC__ and __BLST_NO_ASM__ cannot be specified at the same time"
# endif
# ifndef __ADX__
#  error "__BLST_DYNAMIC__ requires __ADX__"
# endif
#endif

/*
 * declare_optimizable_func(dyn_fn, portable_fn, optimized_fn,
 *                          return_type, args...)
 *
 * Declares the routine(s) needed for one assembly subroutine, depending on
 * the build configuration:
 *
 *  - __BLST_DYNAMIC__: everything ifunc() emits for `dyn_fn`, `portable_fn`
 *    and `optimized_fn`;
 *  - __ADX__ without __BLST_NO_ASM__ and __BLST_PORTABLE__: only
 *    `optimized_fn`;
 *  - any other configuration: only `portable_fn`.
 *
 * The expansions deliberately carry no trailing ';' - every invocation
 * supplies its own, which avoids an empty declaration at file scope.
 */
#if defined(__BLST_DYNAMIC__)
# include "ifunc.h"
# define declare_optimizable_func(dyn_fn, portable_fn, optimized_fn, return_type, ...) \
    ifunc(dyn_fn, portable_fn, optimized_fn, return_type, __VA_ARGS__)
#elif defined(__ADX__) && !defined(__BLST_NO_ASM__) && !defined(__BLST_PORTABLE__)
# define declare_optimizable_func(dyn_fn, portable_fn, optimized_fn, return_type, ...) \
    return_type optimized_fn(__VA_ARGS__)
#else
# define declare_optimizable_func(dyn_fn, portable_fn, optimized_fn, return_type, ...) \
    return_type portable_fn(__VA_ARGS__)
#endif

/*
* Assembly subroutines...
*/
#if defined(__ADX__) /* e.g. -march=broadwell */ && !defined(__BLST_PORTABLE__)\
&& !defined(__BLST_NO_ASM__)
declare_optimizable_func(dyn_mul_mont_sparse_256, mul_mont_sparse_256, mulx_mont_sparse_256,
void, vec256 ret, const vec256 a, const vec256 b, const vec256 p, limb_t n0);
declare_optimizable_func(dyn_sqr_mont_sparse_256, sqr_mont_sparse_256, sqrx_mont_sparse_256,
void, vec256 ret, const vec256 a, const vec256 p, limb_t n0);
declare_optimizable_func(dyn_from_mont_256, from_mont_256, fromx_mont_256,
void, vec256 ret, const vec256 a, const vec256 p, limb_t n0);
declare_optimizable_func(dyn_redc_mont_256, redc_mont_256, redcx_mont_256,
void, vec256 ret, const vec512 a, const vec256 p, limb_t n0);

declare_optimizable_func(dyn_mul_mont_384, mul_mont_384, mulx_mont_384,
void, vec384 ret, const vec384 a, const vec384 b, const vec384 p, limb_t n0);
declare_optimizable_func(dyn_sqr_mont_384, sqr_mont_384, sqrx_mont_384,
void, vec384 ret, const vec384 a, const vec384 p, limb_t n0);
declare_optimizable_func(dyn_sqr_n_mul_mont_384, sqr_n_mul_mont_384, sqrx_n_mul_mont_384,
void, vec384 ret, const vec384 a, size_t count, const vec384 p, limb_t n0, const vec384 b);
declare_optimizable_func(dyn_sqr_n_mul_mont_383, sqr_n_mul_mont_383, sqrx_n_mul_mont_383,
void, vec384 ret, const vec384 a, size_t count, const vec384 p, limb_t n0, const vec384 b);

declare_optimizable_func(dyn_mul_384, mul_384, mulx_384,
void, vec768 ret, const vec384 a, const vec384 b);
declare_optimizable_func(dyn_sqr_384, sqr_384, sqrx_384,
void, vec768 ret, const vec384 a);
declare_optimizable_func(dyn_redc_mont_384, redc_mont_384, redcx_mont_384,
void, vec384 ret, const vec768 a, const vec384 p, limb_t n0);
declare_optimizable_func(dyn_from_mont_384, from_mont_384, fromx_mont_384,
void, vec384 ret, const vec384 a, const vec384 p, limb_t n0);
declare_optimizable_func(dyn_sgn0_pty_mont_384, sgn0_pty_mont_384, sgn0x_pty_mont_384,
limb_t, const vec384 a, const vec384 p, limb_t n0);
declare_optimizable_func(dyn_sgn0_pty_mont_384x, sgn0_pty_mont_384x, sgn0x_pty_mont_384x,
limb_t, const vec384x a, const vec384 p, limb_t n0);

declare_optimizable_func(dyn_ct_inverse_mod_383, ct_inverse_mod_383, ctx_inverse_mod_383,
void, vec768 ret, const vec384 inp, const vec384 mod, const vec384 modx);

declare_optimizable_func(dyn_mul_mont_384x, mul_mont_384x, mulx_mont_384x,
void, vec384x ret, const vec384x a, const vec384x b, const vec384 p, limb_t n0);
declare_optimizable_func(dyn_sqr_mont_384x, sqr_mont_384x, sqrx_mont_384x,
void, vec384x ret, const vec384x a, const vec384 p, limb_t n0);
declare_optimizable_func(dyn_sqr_mont_382x, sqr_mont_382x, sqrx_mont_382x,
void, vec384x ret, const vec384x a, const vec384 p, limb_t n0);
declare_optimizable_func(dyn_mul_382x, mul_382x, mulx_382x,
void, vec768 ret[2], const vec384x a, const vec384x b, const vec384 p);
declare_optimizable_func(dyn_sqr_382x, sqr_382x, sqrx_382x,
void, vec768 ret[2], const vec384x a, const vec384 p);

#if defined(__BLST_DYNAMIC__)
/* Use indirect functions */
# define mul_mont_sparse_256 dyn_mul_mont_sparse_256
# define sqr_mont_sparse_256 dyn_sqr_mont_sparse_256
# define from_mont_256 dyn_from_mont_256
# define redc_mont_256 dyn_redc_mont_256
# define mul_mont_384 dyn_mul_mont_384
# define sqr_mont_384 dyn_sqr_mont_384
# define sqr_n_mul_mont_384 dyn_sqr_n_mul_mont_384
# define sqr_n_mul_mont_383 dyn_sqr_n_mul_mont_383
# define mul_384 dyn_mul_384
# define sqr_384 dyn_sqr_384
# define redc_mont_384 dyn_redc_mont_384
# define from_mont_384 dyn_from_mont_384
# define sgn0_pty_mont_384 dyn_sgn0_pty_mont_384
# define sgn0_pty_mont_384x dyn_sgn0_pty_mont_384x
# define ct_inverse_mod_383 dyn_ct_inverse_mod_383
# define mul_mont_384x dyn_mul_mont_384x
# define sqr_mont_384x dyn_sqr_mont_384x
# define sqr_mont_382x dyn_sqr_mont_382x
# define mul_382x dyn_mul_382x
# define sqr_382x dyn_sqr_382x
#elif defined(__ADX__) && !defined(__BLST_NO_ASM__) && !defined(__BLST_PORTABLE__)
/* Use optimized functions */
# define mul_mont_sparse_256 mulx_mont_sparse_256
# define sqr_mont_sparse_256 sqrx_mont_sparse_256
# define from_mont_256 fromx_mont_256
@@ -86,16 +180,15 @@ typedef limb_t bool_t;
# define sgn0_pty_mont_384 sgn0x_pty_mont_384
# define sgn0_pty_mont_384x sgn0x_pty_mont_384x
# define ct_inverse_mod_383 ctx_inverse_mod_383
#elif defined(__BLST_NO_ASM__)
# define ct_inverse_mod_383 ct_inverse_mod_384
# define mul_mont_384x mulx_mont_384x
# define sqr_mont_384x sqrx_mont_384x
# define sqr_mont_382x sqrx_mont_382x
# define mul_382x mulx_382x
# define sqr_382x sqrx_382x
#else
/* Use portable functions */
#endif

void mul_mont_sparse_256(vec256 ret, const vec256 a, const vec256 b,
const vec256 p, limb_t n0);
void sqr_mont_sparse_256(vec256 ret, const vec256 a, const vec256 p, limb_t n0);
void redc_mont_256(vec256 ret, const vec512 a, const vec256 p, limb_t n0);
void from_mont_256(vec256 ret, const vec256 a, const vec256 p, limb_t n0);

void add_mod_256(vec256 ret, const vec256 a, const vec256 b, const vec256 p);
void sub_mod_256(vec256 ret, const vec256 a, const vec256 b, const vec256 p);
void mul_by_3_mod_256(vec256 ret, const vec256 a, const vec256 p);
@@ -112,20 +205,6 @@ limb_t sub_n_check_mod_256(pow256 ret, const pow256 a, const pow256 b,

void vec_prefetch(const void *ptr, size_t len);

void mul_mont_384(vec384 ret, const vec384 a, const vec384 b,
const vec384 p, limb_t n0);
void sqr_mont_384(vec384 ret, const vec384 a, const vec384 p, limb_t n0);
void sqr_n_mul_mont_384(vec384 ret, const vec384 a, size_t count,
const vec384 p, limb_t n0, const vec384 b);
void sqr_n_mul_mont_383(vec384 ret, const vec384 a, size_t count,
const vec384 p, limb_t n0, const vec384 b);

void mul_384(vec768 ret, const vec384 a, const vec384 b);
void sqr_384(vec768 ret, const vec384 a);
void redc_mont_384(vec384 ret, const vec768 a, const vec384 p, limb_t n0);
void from_mont_384(vec384 ret, const vec384 a, const vec384 p, limb_t n0);
limb_t sgn0_pty_mont_384(const vec384 a, const vec384 p, limb_t n0);
limb_t sgn0_pty_mont_384x(const vec384x a, const vec384 p, limb_t n0);
limb_t sgn0_pty_mod_384(const vec384 a, const vec384 p);
limb_t sgn0_pty_mod_384x(const vec384x a, const vec384 p);

@@ -137,27 +216,10 @@ void cneg_mod_384(vec384 ret, const vec384 a, bool_t flag, const vec384 p);
void lshift_mod_384(vec384 ret, const vec384 a, size_t count, const vec384 p);
void rshift_mod_384(vec384 ret, const vec384 a, size_t count, const vec384 p);
void div_by_2_mod_384(vec384 ret, const vec384 a, const vec384 p);
void ct_inverse_mod_383(vec768 ret, const vec384 inp, const vec384 mod,
const vec384 modx);
void ct_inverse_mod_256(vec512 ret, const vec256 inp, const vec256 mod,
const vec256 modx);
bool_t ct_is_square_mod_384(const vec384 inp, const vec384 mod);

#if defined(__ADX__) /* e.g. -march=broadwell */ && !defined(__BLST_PORTABLE__)
# define mul_mont_384x mulx_mont_384x
# define sqr_mont_384x sqrx_mont_384x
# define sqr_mont_382x sqrx_mont_382x
# define mul_382x mulx_382x
# define sqr_382x sqrx_382x
#endif

void mul_mont_384x(vec384x ret, const vec384x a, const vec384x b,
const vec384 p, limb_t n0);
void sqr_mont_384x(vec384x ret, const vec384x a, const vec384 p, limb_t n0);
void sqr_mont_382x(vec384x ret, const vec384x a, const vec384 p, limb_t n0);
void mul_382x(vec768 ret[2], const vec384x a, const vec384x b, const vec384 p);
void sqr_382x(vec768 ret[2], const vec384x a, const vec384 p);

void add_mod_384x(vec384x ret, const vec384x a, const vec384x b,
const vec384 p);
void sub_mod_384x(vec384x ret, const vec384x a, const vec384x b,