Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,23 @@ template <typename Torus>
bool supports_distributed_shared_memory_on_multibit_programmable_bootstrap(
uint32_t polynomial_size, uint32_t max_shared_memory);

template <typename Torus>
template <typename InputTorus, typename Torus>
bool has_support_to_cuda_programmable_bootstrap_tbc_multi_bit(
uint32_t num_samples, uint32_t glwe_dimension, uint32_t polynomial_size,
uint32_t level_count, uint32_t max_shared_memory);

#if CUDA_ARCH >= 900
template <typename Torus>
template <typename InputTorus, typename Torus>
uint64_t scratch_cuda_tbc_multi_bit_programmable_bootstrap(
void *stream, uint32_t gpu_index, pbs_buffer<Torus, MULTI_BIT> **buffer,
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count,
uint32_t input_lwe_ciphertext_count, bool allocate_gpu_memory);

template <typename Torus>
template <typename InputTorus, typename Torus>
void cuda_tbc_multi_bit_programmable_bootstrap_lwe_ciphertext_vector(
void *stream, uint32_t gpu_index, Torus *lwe_array_out,
Torus const *lwe_output_indexes, Torus const *lut_vector,
Torus const *lut_vector_indexes, Torus const *lwe_array_in,
Torus const *lut_vector_indexes, InputTorus const *lwe_array_in,
Torus const *lwe_input_indexes, Torus const *bootstrapping_key,
pbs_buffer<Torus, MULTI_BIT> *pbs_buffer, uint32_t lwe_dimension,
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t grouping_factor,
Expand All @@ -48,7 +48,7 @@ void cuda_cg_multi_bit_programmable_bootstrap_lwe_ciphertext_vector(
uint32_t base_log, uint32_t level_count, uint32_t num_samples,
uint32_t num_many_lut, uint32_t lut_stride);

template <typename Torus>
template <typename InputTorus, typename Torus>
uint64_t scratch_cuda_multi_bit_programmable_bootstrap(
void *stream, uint32_t gpu_index, pbs_buffer<Torus, MULTI_BIT> **pbs_buffer,
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count,
Expand Down Expand Up @@ -96,7 +96,7 @@ template <typename Torus>
uint64_t get_buffer_size_full_sm_tbc_multibit_programmable_bootstrap(
uint32_t polynomial_size);

template <typename Torus, class params>
template <typename InputTorus, typename Torus, class params>
uint64_t get_lwe_chunk_size(uint32_t gpu_index, uint32_t max_num_pbs,
uint32_t polynomial_size, uint32_t glwe_dimension,
uint32_t level_count, uint64_t full_sm_keybundle);
Expand Down
18 changes: 9 additions & 9 deletions backends/tfhe-cuda-backend/cuda/include/pbs/pbs_utilities.h
Original file line number Diff line number Diff line change
Expand Up @@ -428,41 +428,41 @@ uint64_t get_buffer_size_programmable_bootstrap_cg(
return buffer_size + buffer_size % sizeof(double2);
}

template <typename Torus>
template <typename InputTorus, typename Torus>
bool has_support_to_cuda_programmable_bootstrap_cg(uint32_t glwe_dimension,
uint32_t polynomial_size,
uint32_t level_count,
uint32_t num_samples,
uint32_t max_shared_memory);

template <typename Torus>
template <typename InputTorus, typename Torus>
void cuda_programmable_bootstrap_cg_lwe_ciphertext_vector(
void *stream, uint32_t gpu_index, Torus *lwe_array_out,
Torus const *lwe_output_indexes, Torus const *lut_vector,
Torus const *lut_vector_indexes, Torus const *lwe_array_in,
Torus const *lut_vector_indexes, InputTorus const *lwe_array_in,
Torus const *lwe_input_indexes, double2 const *bootstrapping_key,
pbs_buffer<Torus, CLASSICAL> *buffer, uint32_t lwe_dimension,
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log,
uint32_t level_count, uint32_t num_samples, uint32_t num_many_lut,
uint32_t lut_stride);

template <typename Torus>
template <typename InputTorus, typename Torus>
void cuda_programmable_bootstrap_lwe_ciphertext_vector(
void *stream, uint32_t gpu_index, Torus *lwe_array_out,
Torus const *lwe_output_indexes, Torus const *lut_vector,
Torus const *lut_vector_indexes, Torus const *lwe_array_in,
Torus const *lut_vector_indexes, InputTorus const *lwe_array_in,
Torus const *lwe_input_indexes, double2 const *bootstrapping_key,
pbs_buffer<Torus, CLASSICAL> *buffer, uint32_t lwe_dimension,
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log,
uint32_t level_count, uint32_t num_samples, uint32_t num_many_lut,
uint32_t lut_stride);

#if (CUDA_ARCH >= 900)
template <typename Torus>
template <typename InputTorus, typename Torus>
void cuda_programmable_bootstrap_tbc_lwe_ciphertext_vector(
void *stream, uint32_t gpu_index, Torus *lwe_array_out,
Torus const *lwe_output_indexes, Torus const *lut_vector,
Torus const *lut_vector_indexes, Torus const *lwe_array_in,
Torus const *lut_vector_indexes, InputTorus const *lwe_array_in,
Torus const *lwe_input_indexes, double2 const *bootstrapping_key,
pbs_buffer<Torus, CLASSICAL> *buffer, uint32_t lwe_dimension,
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t base_log,
Expand All @@ -477,14 +477,14 @@ uint64_t scratch_cuda_programmable_bootstrap_tbc(
bool allocate_gpu_memory, PBS_MS_REDUCTION_T noise_reduction_type);
#endif

template <typename Torus>
template <typename InputTorus, typename Torus>
uint64_t scratch_cuda_programmable_bootstrap_cg(
void *stream, uint32_t gpu_index, pbs_buffer<Torus, CLASSICAL> **pbs_buffer,
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
uint32_t level_count, uint32_t input_lwe_ciphertext_count,
bool allocate_gpu_memory, PBS_MS_REDUCTION_T noise_reduction_type);

template <typename Torus>
template <typename InputTorus, typename Torus>
uint64_t scratch_cuda_programmable_bootstrap(
void *stream, uint32_t gpu_index, pbs_buffer<Torus, CLASSICAL> **buffer,
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,19 @@ uint64_t scratch_cuda_programmable_bootstrap_64(
uint32_t input_lwe_ciphertext_count, bool allocate_gpu_memory,
PBS_MS_REDUCTION_T noise_reduction_type);

uint64_t scratch_cuda_programmable_bootstrap_32_64(
void *stream, uint32_t gpu_index, int8_t **buffer, uint32_t lwe_dimension,
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count,
uint32_t input_lwe_ciphertext_count, bool allocate_gpu_memory,
PBS_MS_REDUCTION_T noise_reduction_type);

uint64_t scratch_cuda_programmable_bootstrap_128(
void *stream, uint32_t gpu_index, int8_t **buffer, uint32_t lwe_dimension,
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count,
uint32_t input_lwe_ciphertext_count, bool allocate_gpu_memory,
PBS_MS_REDUCTION_T noise_reduction_type);

void cuda_programmable_bootstrap_lwe_ciphertext_vector_32(
void cuda_programmable_bootstrap_lwe_ciphertext_vector_64_64(
void *stream, uint32_t gpu_index, void *lwe_array_out,
void const *lwe_output_indexes, void const *lut_vector,
void const *lut_vector_indexes, void const *lwe_array_in,
Expand All @@ -84,7 +90,7 @@ void cuda_programmable_bootstrap_lwe_ciphertext_vector_32(
uint32_t polynomial_size, uint32_t base_log, uint32_t level_count,
uint32_t num_samples, uint32_t num_many_lut, uint32_t lut_stride);

void cuda_programmable_bootstrap_lwe_ciphertext_vector_64(
void cuda_programmable_bootstrap_lwe_ciphertext_vector_32_64(
void *stream, uint32_t gpu_index, void *lwe_array_out,
void const *lwe_output_indexes, void const *lut_vector,
void const *lut_vector_indexes, void const *lwe_array_in,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
#include "pbs_enums.h"
#include "stdint.h"

extern "C" {

bool has_support_to_cuda_programmable_bootstrap_cg_multi_bit(
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count,
uint32_t num_samples, uint32_t max_shared_memory);

extern "C" {

void cuda_convert_lwe_multi_bit_programmable_bootstrap_key_64(
void *stream, uint32_t gpu_index, void *dest, void const *src,
uint32_t input_lwe_dim, uint32_t glwe_dim, uint32_t level_count,
Expand All @@ -20,6 +20,11 @@ void cuda_convert_lwe_multi_bit_programmable_bootstrap_key_128(
uint32_t input_lwe_dim, uint32_t glwe_dim, uint32_t level_count,
uint32_t polynomial_size, uint32_t grouping_factor);

uint64_t scratch_cuda_multi_bit_programmable_bootstrap_32_64(
void *stream, uint32_t gpu_index, int8_t **pbs_buffer,
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count,
uint32_t input_lwe_ciphertext_count, bool allocate_gpu_memory);

uint64_t scratch_cuda_multi_bit_programmable_bootstrap_64(
void *stream, uint32_t gpu_index, int8_t **pbs_buffer,
uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t level_count,
Expand All @@ -35,6 +40,16 @@ void cuda_multi_bit_programmable_bootstrap_lwe_ciphertext_vector_64(
uint32_t level_count, uint32_t num_samples, uint32_t num_many_lut,
uint32_t lut_stride);

void cuda_multi_bit_programmable_bootstrap_lwe_ciphertext_vector_32_64(
void *stream, uint32_t gpu_index, void *lwe_array_out,
void const *lwe_output_indexes, void const *lut_vector,
void const *lut_vector_indexes, void const *lwe_array_in,
void const *lwe_input_indexes, void const *bootstrapping_key,
int8_t *buffer, uint32_t lwe_dimension, uint32_t glwe_dimension,
uint32_t polynomial_size, uint32_t grouping_factor, uint32_t base_log,
uint32_t level_count, uint32_t num_samples, uint32_t num_many_lut,
uint32_t lut_stride);

void cleanup_cuda_multi_bit_programmable_bootstrap(void *stream,
uint32_t gpu_index,
int8_t **pbs_buffer);
Expand Down
1 change: 0 additions & 1 deletion backends/tfhe-cuda-backend/cuda/src/crypto/keyswitch.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,6 @@ __host__ void host_gemm_keyswitch_lwe_ciphertext_vector(
dim3 grid_negate(CEIL_DIV(lwe_dimension_out + 1, BLOCK_SIZE_DECOMP),
CEIL_DIV(num_samples, BLOCK_SIZE_DECOMP));
dim3 threads_negate(BLOCK_SIZE_DECOMP, BLOCK_SIZE_DECOMP);

// Negate all outputs in the output LWEs. This is the final step in the GEMM
// keyswitch computed as: -(-b + sum(a_i A_KSK))
keyswitch_negate_with_output_indices<Torus, KSTorus>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -263,42 +263,7 @@ void execute_pbs_async(CudaStreams streams,
uint32_t num_many_lut, uint32_t lut_stride) {

if constexpr (std::is_same_v<OutputTorus, uint32_t>) {
// 32 bits
switch (pbs_type) {
case MULTI_BIT:
PANIC("Error: 32-bit multibit PBS is not supported.\n")
case CLASSICAL:
for (uint i = 0; i < streams.count(); i++) {
int num_inputs_on_gpu = get_num_inputs_on_gpu(
input_lwe_ciphertext_count, i, streams.count());

int gpu_offset =
get_gpu_offset(input_lwe_ciphertext_count, i, streams.count());
auto d_lut_vector_indexes =
lut_indexes_vec[i] + (ptrdiff_t)(gpu_offset);

// Use the macro to get the correct elements for the current iteration
// Handles the case when the input/output are scattered through
// different gpus and when it is not
auto current_lwe_array_out = get_variant_element(lwe_array_out, i);
auto current_lwe_output_indexes =
get_variant_element(lwe_output_indexes, i);
auto current_lwe_array_in = get_variant_element(lwe_array_in, i);
auto current_lwe_input_indexes =
get_variant_element(lwe_input_indexes, i);

cuda_programmable_bootstrap_lwe_ciphertext_vector_32(
streams.stream(i), streams.gpu_index(i), current_lwe_array_out,
current_lwe_output_indexes, lut_vec[i], d_lut_vector_indexes,
current_lwe_array_in, current_lwe_input_indexes,
bootstrapping_keys[i], pbs_buffer[i], lwe_dimension, glwe_dimension,
polynomial_size, base_log, level_count, num_inputs_on_gpu,
num_many_lut, lut_stride);
}
break;
default:
PANIC("Error: unsupported cuda PBS type.")
}
PANIC("Error: unsupported 32b CUDA PBS type.")
} else if constexpr (std::is_same_v<OutputTorus, uint64_t>) {
// 64 bits
switch (pbs_type) {
Expand Down Expand Up @@ -353,7 +318,7 @@ void execute_pbs_async(CudaStreams streams,
auto d_lut_vector_indexes =
lut_indexes_vec[i] + (ptrdiff_t)(gpu_offset);

cuda_programmable_bootstrap_lwe_ciphertext_vector_64(
cuda_programmable_bootstrap_lwe_ciphertext_vector_64_64(
streams.stream(i), streams.gpu_index(i), current_lwe_array_out,
current_lwe_output_indexes, lut_vec[i], d_lut_vector_indexes,
current_lwe_array_in, current_lwe_input_indexes,
Expand Down
Loading
Loading