@@ -150,8 +150,7 @@ class DynamicVamanaIndexImpl {
150150 }
151151 };
152152
153- auto threadpool =
154- svs::threads::OMPThreadPool (std::min (n, size_t (omp_get_max_threads ())));
153+ auto threadpool = default_threadpool ();
155154
156155 svs::threads::parallel_for (
157156 threadpool, svs::threads::StaticPartition{n}, search_closure
@@ -237,8 +236,7 @@ class DynamicVamanaIndexImpl {
237236 }
238237 };
239238
240- auto threadpool =
241- svs::threads::OMPThreadPool (std::min (n, size_t (omp_get_max_threads ())));
239+ auto threadpool = default_threadpool ();
242240
243241 svs::threads::parallel_for (
244242 threadpool, svs::threads::StaticPartition{n}, range_search_closure
@@ -367,23 +365,28 @@ class DynamicVamanaIndexImpl {
367365 );
368366 }
369367
370- template <typename Tag>
371- svs::DynamicVamana* init_impl_t (
368+ template <typename Tag, typename ... StorageArgs >
369+ static svs::DynamicVamana* build_impl (
372370 Tag&& tag,
373371 MetricType metric,
372+ const index::vamana::VamanaBuildParameters& parameters,
374373 const svs::data::ConstSimpleDataView<float >& data,
375- std::span<const size_t > labels
374+ std::span<const size_t > labels,
375+ StorageArgs&&... storage_args
376376 ) {
377- auto threadpool =
378- svs::threads::ThreadPoolHandle (svs::threads::OMPThreadPool (omp_get_max_threads ()
379- ));
377+ auto threadpool = default_threadpool ();
380378
381- auto storage = make_storage (std::forward<Tag>(tag), data, threadpool);
379+ auto storage = make_storage (
380+ std::forward<Tag>(tag),
381+ data,
382+ threadpool,
383+ std::forward<StorageArgs>(storage_args)...
384+ );
382385
383386 svs::DistanceDispatcher distance_dispatcher (to_svs_distance (metric));
384387 return distance_dispatcher ([&](auto && distance) {
385388 return new svs::DynamicVamana (svs::DynamicVamana::build<float >(
386- this -> vamana_build_parameters () ,
389+ parameters ,
387390 std::move (storage),
388391 std::move (labels),
389392 std::forward<decltype (distance)>(distance),
@@ -398,19 +401,24 @@ class DynamicVamanaIndexImpl {
398401 get_storage_kind (),
399402 [this ](
400403 auto && tag,
401- MetricType metric,
402404 data::ConstSimpleDataView<float > data,
403405 std::span<const size_t > labels
404406 ) {
405407 using Tag = std::decay_t <decltype (tag)>;
406- return init_impl_t (std::forward<Tag>(tag), metric, data, labels);
408+ return build_impl (
409+ std::forward<Tag>(tag),
410+ this ->metric_type_ ,
411+ this ->vamana_build_parameters (),
412+ data,
413+ labels
414+ );
407415 },
408- metric_type_,
409416 data,
410417 labels
411418 ));
412419 }
413420
421+ // Constructor used during loading
414422 DynamicVamanaIndexImpl (
415423 std::unique_ptr<svs::DynamicVamana>&& impl,
416424 MetricType metric,
@@ -432,21 +440,20 @@ class DynamicVamanaIndexImpl {
432440 impl_->get_full_search_history ()};
433441 }
434442
435- template <typename StorageTag>
436- static svs::DynamicVamana* deserialize_impl_t (
437- StorageTag&& SVS_UNUSED(tag), std::istream& stream, MetricType metric
438- ) {
439- auto threadpool =
440- svs::threads::ThreadPoolHandle (svs::threads::OMPThreadPool (omp_get_max_threads ()
441- ));
443+ template <storage::StorageTag Tag>
444+ static svs::DynamicVamana*
445+ load_impl_t (Tag&& SVS_UNUSED(tag), std::istream& stream, MetricType metric) {
446+ auto threadpool = default_threadpool ();
442447
443448 svs::DistanceDispatcher distance_dispatcher (to_svs_distance (metric));
444449 return distance_dispatcher ([&](auto && distance) {
445- return new svs::DynamicVamana (svs::DynamicVamana::assemble<
446- float ,
447- storage::StorageType_t<StorageTag::value>>(
448- stream, std::forward<decltype (distance)>(distance), std::move (threadpool)
449- ));
450+ return new svs::DynamicVamana (
451+ svs::DynamicVamana::assemble<float , storage::StorageType_t<Tag>>(
452+ stream,
453+ std::forward<decltype (distance)>(distance),
454+ std::move (threadpool)
455+ )
456+ );
450457 });
451458 }
452459
@@ -458,7 +465,7 @@ class DynamicVamanaIndexImpl {
458465 [&](auto && tag, std::istream& stream, MetricType metric) {
459466 using Tag = std::decay_t <decltype (tag)>;
460467 std::unique_ptr<svs::DynamicVamana> impl{
461- deserialize_impl_t (std::forward<Tag>(tag), stream, metric)};
468+ load_impl_t (std::forward<Tag>(tag), stream, metric)};
462469
463470 return new DynamicVamanaIndexImpl (std::move (impl), metric, storage_kind);
464471 },
@@ -478,9 +485,19 @@ class DynamicVamanaIndexImpl {
478485 size_t ntotal_soft_deleted{0 };
479486};
480487
481- struct DynamicVamanaIndexLeanVecImplTrained : public DynamicVamanaIndexImpl {
482- using DynamicVamanaIndexImpl::DynamicVamanaIndexImpl;
483- DynamicVamanaIndexLeanVecImplTrained (
488+ struct DynamicVamanaIndexLeanVecImpl : public DynamicVamanaIndexImpl {
489+ DynamicVamanaIndexLeanVecImpl (
490+ std::unique_ptr<svs::DynamicVamana>&& impl,
491+ MetricType metric,
492+ StorageKind storage_kind
493+ )
494+ : DynamicVamanaIndexImpl{std::move (impl), metric, storage_kind}
495+ , leanvec_dims_{0 }
496+ , leanvec_matrices_{std::nullopt } {
497+ check_storage_kind (storage_kind);
498+ }
499+
500+ DynamicVamanaIndexLeanVecImpl (
484501 size_t dim,
485502 MetricType metric,
486503 StorageKind storage_kind,
@@ -489,82 +506,12 @@ struct DynamicVamanaIndexLeanVecImplTrained : public DynamicVamanaIndexImpl {
489506 const VamanaIndex::SearchParams& default_search_params = {10 , 10 }
490507 )
491508 : DynamicVamanaIndexImpl{dim, metric, storage_kind, params, default_search_params}
492- , training_data_{training_data} {}
493-
494- template <typename Tag>
495- svs::DynamicVamana* init_impl_t (
496- Tag&& tag,
497- MetricType metric,
498- const svs::data::ConstSimpleDataView<float >& data,
499- std::span<const size_t > labels,
500- std::optional<LeanVecMatricesType> matrices
501- ) {
502- auto threadpool =
503- svs::threads::ThreadPoolHandle (svs::threads::OMPThreadPool (omp_get_max_threads ()
504- ));
505-
506- auto storage =
507- make_storage (std::forward<Tag>(tag), data, threadpool, 0 , std::move (matrices));
508-
509- svs::DistanceDispatcher distance_dispatcher (to_svs_distance (metric));
510- return distance_dispatcher ([&](auto && distance) {
511- return new svs::DynamicVamana (svs::DynamicVamana::build<float >(
512- this ->vamana_build_parameters (),
513- std::move (storage),
514- std::move (labels),
515- std::forward<decltype (distance)>(distance),
516- std::move (threadpool)
517- ));
518- });
509+ , leanvec_dims_{training_data.get_leanvec_dims ()}
510+ , leanvec_matrices_{training_data.get_leanvec_matrices ()} {
511+ check_storage_kind (storage_kind);
519512 }
520513
521- template <typename F, typename ... Args>
522- static auto dispatch_storage_kind (StorageKind kind, F&& f, Args&&... args) {
523- using SK = StorageKind;
524- using namespace svs ::runtime::storage;
525- switch (kind) {
526- case SK::LeanVec4x4:
527- return f (LeanVec4x4Tag{}, std::forward<Args>(args)...);
528- case SK::LeanVec4x8:
529- return f (LeanVec4x8Tag{}, std::forward<Args>(args)...);
530- case SK::LeanVec8x8:
531- return f (LeanVec8x8Tag{}, std::forward<Args>(args)...);
532- default :
533- throw ANNEXCEPTION (" not supported SVS leanvec storage kind" );
534- }
535- }
536-
537- void init_impl (data::ConstSimpleDataView<float > data, std::span<const size_t > labels)
538- override {
539- impl_.reset (DynamicVamanaIndexLeanVecImplTrained::dispatch_storage_kind (
540- this ->storage_kind_ ,
541- [this ](
542- auto && tag,
543- MetricType metric,
544- data::ConstSimpleDataView<float > data,
545- std::span<const size_t > labels
546- ) {
547- using Tag = std::decay_t <decltype (tag)>;
548- return DynamicVamanaIndexLeanVecImplTrained::init_impl_t (
549- std::forward<Tag>(tag),
550- metric,
551- data,
552- labels,
553- training_data_.get_leanvec_matrices ()
554- );
555- },
556- metric_type_,
557- data,
558- labels
559- ));
560- }
561-
562- LeanVecTrainingDataImpl training_data_;
563- };
564-
565- struct DynamicVamanaIndexLeanVecImplDims : public DynamicVamanaIndexImpl {
566- using DynamicVamanaIndexImpl::DynamicVamanaIndexImpl;
567- DynamicVamanaIndexLeanVecImplDims (
514+ DynamicVamanaIndexLeanVecImpl (
568515 size_t dim,
569516 MetricType metric,
570517 StorageKind storage_kind,
@@ -573,72 +520,70 @@ struct DynamicVamanaIndexLeanVecImplDims : public DynamicVamanaIndexImpl {
573520 const VamanaIndex::SearchParams& default_search_params = {10 , 10 }
574521 )
575522 : DynamicVamanaIndexImpl{dim, metric, storage_kind, params, default_search_params}
576- , leanvec_dims_{leanvec_dims} {}
577-
578- template <typename Tag>
579- svs::DynamicVamana* init_impl_t (
580- Tag&& tag,
581- MetricType metric,
582- const svs::data::ConstSimpleDataView<float >& data,
583- std::span<const size_t > labels,
584- size_t leanvec_dims
585- ) {
586- auto threadpool =
587- svs::threads::ThreadPoolHandle (svs::threads::OMPThreadPool (omp_get_max_threads ()
588- ));
589-
590- auto storage = make_storage (std::forward<Tag>(tag), data, threadpool, leanvec_dims);
591-
592- svs::DistanceDispatcher distance_dispatcher (to_svs_distance (metric));
593- return distance_dispatcher ([&](auto && distance) {
594- return new svs::DynamicVamana (svs::DynamicVamana::build<float >(
595- this ->vamana_build_parameters (),
596- std::move (storage),
597- std::move (labels),
598- std::forward<decltype (distance)>(distance),
599- std::move (threadpool)
600- ));
601- });
523+ , leanvec_dims_{leanvec_dims}
524+ , leanvec_matrices_{std::nullopt } {
525+ check_storage_kind (storage_kind);
602526 }
603527
604528 template <typename F, typename ... Args>
605- static auto dispatch_storage_kind (StorageKind kind, F&& f, Args&&... args) {
606- using SK = StorageKind;
607- using namespace svs ::runtime::storage;
529+ static auto dispatch_leanvec_storage_kind (StorageKind kind, F&& f, Args&&... args) {
608530 switch (kind) {
609- case SK ::LeanVec4x4:
610- return f (LeanVec4x4Tag{}, std::forward<Args>(args)...);
611- case SK ::LeanVec4x8:
612- return f (LeanVec4x8Tag{}, std::forward<Args>(args)...);
613- case SK ::LeanVec8x8:
614- return f (LeanVec8x8Tag{}, std::forward<Args>(args)...);
531+ case StorageKind ::LeanVec4x4:
532+ return f (storage:: LeanVec4x4Tag{}, std::forward<Args>(args)...);
533+ case StorageKind ::LeanVec4x8:
534+ return f (storage:: LeanVec4x8Tag{}, std::forward<Args>(args)...);
535+ case StorageKind ::LeanVec8x8:
536+ return f (storage:: LeanVec8x8Tag{}, std::forward<Args>(args)...);
615537 default :
616- throw ANNEXCEPTION (" not supported SVS leanvec storage kind" );
538+ throw StatusException{
539+ ErrorCode::INVALID_ARGUMENT, " SVS LeanVec storage kind required" };
617540 }
618541 }
619542
620543 void init_impl (data::ConstSimpleDataView<float > data, std::span<const size_t > labels)
621544 override {
622- impl_.reset (DynamicVamanaIndexLeanVecImplDims::dispatch_storage_kind (
545+ assert (storage::is_leanvec_storage (this ->storage_kind_ ));
546+ impl_.reset (dispatch_leanvec_storage_kind (
623547 this ->storage_kind_ ,
624548 [this ](
625549 auto && tag,
626- MetricType metric,
627550 data::ConstSimpleDataView<float > data,
628551 std::span<const size_t > labels
629552 ) {
630553 using Tag = std::decay_t <decltype (tag)>;
631- return DynamicVamanaIndexLeanVecImplDims::init_impl_t (
632- std::forward<Tag>(tag), metric, data, labels, leanvec_dims_
554+ return DynamicVamanaIndexImpl::build_impl (
555+ std::forward<Tag>(tag),
556+ this ->metric_type_ ,
557+ this ->vamana_build_parameters (),
558+ data,
559+ labels,
560+ this ->leanvec_dims_ ,
561+ this ->leanvec_matrices_
633562 );
634563 },
635- metric_type_,
636564 data,
637565 labels
638566 ));
639567 }
640568
569+ protected:
641570 size_t leanvec_dims_;
571+ std::optional<LeanVecMatricesType> leanvec_matrices_;
572+
573+ StorageKind check_storage_kind (StorageKind kind) {
574+ if (!storage::is_leanvec_storage (kind)) {
575+ throw StatusException (
576+ ErrorCode::INVALID_ARGUMENT, " SVS LeanVec storage kind required"
577+ );
578+ }
579+ if (!svs::detail::lvq_leanvec_enabled ()) {
580+ throw StatusException (
581+ ErrorCode::NOT_IMPLEMENTED,
582+ " LeanVec storage kind requested but not supported by CPU"
583+ );
584+ }
585+ return kind;
586+ }
642587};
643588
644589} // namespace runtime
0 commit comments