Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,16 @@ void cleanup_cuda_integer_compress_radix_ciphertext_128(CudaStreamsFFI streams,

void cleanup_cuda_integer_decompress_radix_ciphertext_128(
CudaStreamsFFI streams, int8_t **mem_ptr_void);

void cuda_integer_extract_glwe_128(
CudaStreamsFFI streams, void *glwe_array_out,
CudaPackedGlweCiphertextListFFI const *glwe_list,
uint32_t const glwe_index);

void cuda_integer_extract_glwe_64(
CudaStreamsFFI streams, void *glwe_array_out,
CudaPackedGlweCiphertextListFFI const *glwe_list,
uint32_t const glwe_index);
}

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,24 @@ void cleanup_cuda_integer_decompress_radix_ciphertext_128(
delete mem_ptr;
*mem_ptr_void = nullptr;
}

void cuda_integer_extract_glwe_128(
CudaStreamsFFI streams, void *glwe_array_out,
CudaPackedGlweCiphertextListFFI const *glwe_list,
uint32_t const glwe_index) {

CudaStreams _streams = CudaStreams(streams);
host_extract<__uint128_t>(_streams.stream(0), _streams.gpu_index(0),
(__uint128_t *)glwe_array_out, glwe_list,
glwe_index);
}

void cuda_integer_extract_glwe_64(
CudaStreamsFFI streams, void *glwe_array_out,
CudaPackedGlweCiphertextListFFI const *glwe_list,
uint32_t const glwe_index) {

CudaStreams _streams = CudaStreams(streams);
host_extract<__uint64_t>(_streams.stream(0), _streams.gpu_index(0),
(__uint64_t *)glwe_array_out, glwe_list, glwe_index);
}
16 changes: 16 additions & 0 deletions backends/tfhe-cuda-backend/src/bindings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2349,6 +2349,22 @@ unsafe extern "C" {
mem_ptr_void: *mut *mut i8,
);
}
unsafe extern "C" {
pub fn cuda_integer_extract_glwe_128(
streams: CudaStreamsFFI,
glwe_array_out: *mut ffi::c_void,
glwe_list: *const CudaPackedGlweCiphertextListFFI,
glwe_index: u32,
);
}
unsafe extern "C" {
pub fn cuda_integer_extract_glwe_64(
streams: CudaStreamsFFI,
glwe_array_out: *mut ffi::c_void,
glwe_list: *const CudaPackedGlweCiphertextListFFI,
glwe_index: u32,
);
}
unsafe extern "C" {
pub fn scratch_cuda_rerand_64(
streams: CudaStreamsFFI,
Expand Down
28 changes: 27 additions & 1 deletion tfhe/src/integer/gpu/list_compression/server_keys.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::core_crypto::gpu::entities::lwe_packing_keyswitch_key::CudaLwePackingKeyswitchKey;
use crate::core_crypto::gpu::glwe_ciphertext_list::CudaGlweCiphertextList;
use crate::core_crypto::gpu::lwe_ciphertext_list::CudaLweCiphertextList;
use crate::core_crypto::gpu::vec::CudaVec;
use crate::core_crypto::gpu::CudaStreams;
Expand All @@ -16,7 +17,8 @@ use crate::integer::gpu::ciphertext::CudaRadixCiphertext;
use crate::integer::gpu::server_key::CudaBootstrappingKey;
use crate::integer::gpu::{
cuda_backend_compress, cuda_backend_decompress, cuda_backend_get_compression_size_on_gpu,
cuda_backend_get_decompression_size_on_gpu, cuda_memcpy_async_gpu_to_gpu, PBSType,
cuda_backend_get_decompression_size_on_gpu, cuda_memcpy_async_gpu_to_gpu, extract_glwe_async,
PBSType,
};
use crate::prelude::CastInto;
use crate::shortint::ciphertext::{
Expand Down Expand Up @@ -197,6 +199,30 @@ impl<T: UnsignedInteger> CudaPackedGlweCiphertextList<T> {
meta: self.meta,
}
}
pub fn extract_glwe(
&self,
glwe_index: usize,
streams: &CudaStreams,
) -> CudaGlweCiphertextList<T> {
let meta = self
.meta
.as_ref()
.expect("CudaPackedGlweCiphertextList meta must be set to extract GLWE");

let mut output_cuda_glwe_list = CudaGlweCiphertextList::new(
meta.glwe_dimension,
meta.polynomial_size,
GlweCiphertextCount(1),
meta.ciphertext_modulus,
streams,
);

unsafe {
extract_glwe_async(streams, &mut output_cuda_glwe_list, self, glwe_index as u32);
}
streams.synchronize();
output_cuda_glwe_list
}
}

impl<T: UnsignedInteger> Clone for CudaPackedGlweCiphertextList<T> {
Expand Down
42 changes: 42 additions & 0 deletions tfhe/src/integer/gpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ pub mod server_key;
#[cfg(feature = "zk-pok")]
pub mod zk;

use crate::core_crypto::gpu::glwe_ciphertext_list::CudaGlweCiphertextList;
use crate::core_crypto::gpu::lwe_bootstrap_key::CudaModulusSwitchNoiseReductionConfiguration;
use crate::core_crypto::gpu::lwe_ciphertext_list::CudaLweCiphertextList;
use crate::core_crypto::gpu::lwe_compact_ciphertext_list::CudaLweCompactCiphertextList;
Expand Down Expand Up @@ -10423,3 +10424,44 @@ pub unsafe fn unchecked_small_scalar_mul_integer_async(
carry_modulus.0 as u32,
);
}
#[allow(clippy::too_many_arguments)]
/// # Safety
///
/// - [CudaStreams::synchronize] __must__ be called after this function as soon as synchronization
/// is required
pub unsafe fn extract_glwe_async<T: UnsignedInteger>(
streams: &CudaStreams,
glwe_array_out: &mut CudaGlweCiphertextList<T>,
glwe_list: &CudaPackedGlweCiphertextList<T>,
glwe_index: u32,
) {
assert_eq!(
streams.gpu_indexes[0],
glwe_array_out.0.d_vec.gpu_index(0),
"GPU error: all data should reside on the same GPU."
);
assert_eq!(
streams.gpu_indexes[0],
glwe_list.data.gpu_index(0),
"GPU error: all data should reside on the same GPU."
);
let packed_glwe_list_ffi = prepare_cuda_packed_glwe_ct_ffi(glwe_list);

if T::BITS == 128 {
cuda_integer_extract_glwe_128(
streams.ffi(),
glwe_array_out.0.d_vec.as_mut_c_ptr(0),
&raw const packed_glwe_list_ffi,
glwe_index,
);
} else if T::BITS == 64 {
cuda_integer_extract_glwe_64(
streams.ffi(),
glwe_array_out.0.d_vec.as_mut_c_ptr(0),
&raw const packed_glwe_list_ffi,
glwe_index,
);
} else {
panic!("Unsupported integer size for CUDA GLWE extraction");
}
}
Loading
Loading