Skip to content

Commit ae0f97d

Browse files
committed
Refactored VamanaIndex implementation code
1 parent fdcffe2 commit ae0f97d

File tree

5 files changed

+190
-241
lines changed

5 files changed

+190
-241
lines changed

bindings/cpp/include/vamana_index.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
namespace svs {
2323
namespace runtime {
2424

25-
// Abstract interface for Vamana-based indexes.
25+
// Abstract interface for Vamana-based indices.
2626
// NOTE VamanaIndex is not implemented directly, only DynamicVamanaIndex is implemented.
2727
struct SVS_RUNTIME_API VamanaIndex {
2828
virtual ~VamanaIndex() = 0;

bindings/cpp/src/dynamic_vamana_index.cpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -128,10 +128,8 @@ struct DynamicVamanaIndexManagerBase : public DynamicVamanaIndex {
128128
};
129129

130130
using DynamicVamanaIndexManager = DynamicVamanaIndexManagerBase<DynamicVamanaIndexImpl>;
131-
using DynamicVamanaIndexLeanVecImplTrainedManager =
132-
DynamicVamanaIndexManagerBase<DynamicVamanaIndexLeanVecImplTrained>;
133-
using DynamicVamanaIndexLeanVecImplDimsManager =
134-
DynamicVamanaIndexManagerBase<DynamicVamanaIndexLeanVecImplDims>;
131+
using DynamicVamanaIndexLeanVecImplManager =
132+
DynamicVamanaIndexManagerBase<DynamicVamanaIndexLeanVecImpl>;
135133

136134
} // namespace
137135

@@ -187,10 +185,10 @@ Status DynamicVamanaIndexLeanVec::build(
187185
) noexcept {
188186
*index = nullptr;
189187
SVS_RUNTIME_TRY_BEGIN
190-
auto impl = std::make_unique<DynamicVamanaIndexLeanVecImplDims>(
188+
auto impl = std::make_unique<DynamicVamanaIndexLeanVecImpl>(
191189
dim, metric, storage_kind, leanvec_dims, params, default_search_params
192190
);
193-
*index = new DynamicVamanaIndexLeanVecImplDimsManager{std::move(impl)};
191+
*index = new DynamicVamanaIndexLeanVecImplManager{std::move(impl)};
194192
return Status_Ok;
195193
SVS_RUNTIME_TRY_END
196194
}
@@ -209,10 +207,10 @@ Status DynamicVamanaIndexLeanVec::build(
209207
SVS_RUNTIME_TRY_BEGIN
210208
auto training_data_impl =
211209
static_cast<const LeanVecTrainingDataManager*>(training_data)->impl_;
212-
auto impl = std::make_unique<DynamicVamanaIndexLeanVecImplTrained>(
210+
auto impl = std::make_unique<DynamicVamanaIndexLeanVecImpl>(
213211
dim, metric, storage_kind, training_data_impl, params, default_search_params
214212
);
215-
*index = new DynamicVamanaIndexLeanVecImplTrainedManager{std::move(impl)};
213+
*index = new DynamicVamanaIndexLeanVecImplManager{std::move(impl)};
216214
return Status_Ok;
217215
SVS_RUNTIME_TRY_END
218216
}

bindings/cpp/src/dynamic_vamana_index_impl.h

Lines changed: 91 additions & 146 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,7 @@ class DynamicVamanaIndexImpl {
150150
}
151151
};
152152

153-
auto threadpool =
154-
svs::threads::OMPThreadPool(std::min(n, size_t(omp_get_max_threads())));
153+
auto threadpool = default_threadpool();
155154

156155
svs::threads::parallel_for(
157156
threadpool, svs::threads::StaticPartition{n}, search_closure
@@ -237,8 +236,7 @@ class DynamicVamanaIndexImpl {
237236
}
238237
};
239238

240-
auto threadpool =
241-
svs::threads::OMPThreadPool(std::min(n, size_t(omp_get_max_threads())));
239+
auto threadpool = default_threadpool();
242240

243241
svs::threads::parallel_for(
244242
threadpool, svs::threads::StaticPartition{n}, range_search_closure
@@ -367,23 +365,28 @@ class DynamicVamanaIndexImpl {
367365
);
368366
}
369367

370-
template <typename Tag>
371-
svs::DynamicVamana* init_impl_t(
368+
template <typename Tag, typename... StorageArgs>
369+
static svs::DynamicVamana* build_impl(
372370
Tag&& tag,
373371
MetricType metric,
372+
const index::vamana::VamanaBuildParameters& parameters,
374373
const svs::data::ConstSimpleDataView<float>& data,
375-
std::span<const size_t> labels
374+
std::span<const size_t> labels,
375+
StorageArgs&&... storage_args
376376
) {
377-
auto threadpool =
378-
svs::threads::ThreadPoolHandle(svs::threads::OMPThreadPool(omp_get_max_threads()
379-
));
377+
auto threadpool = default_threadpool();
380378

381-
auto storage = make_storage(std::forward<Tag>(tag), data, threadpool);
379+
auto storage = make_storage(
380+
std::forward<Tag>(tag),
381+
data,
382+
threadpool,
383+
std::forward<StorageArgs>(storage_args)...
384+
);
382385

383386
svs::DistanceDispatcher distance_dispatcher(to_svs_distance(metric));
384387
return distance_dispatcher([&](auto&& distance) {
385388
return new svs::DynamicVamana(svs::DynamicVamana::build<float>(
386-
this->vamana_build_parameters(),
389+
parameters,
387390
std::move(storage),
388391
std::move(labels),
389392
std::forward<decltype(distance)>(distance),
@@ -398,19 +401,24 @@ class DynamicVamanaIndexImpl {
398401
get_storage_kind(),
399402
[this](
400403
auto&& tag,
401-
MetricType metric,
402404
data::ConstSimpleDataView<float> data,
403405
std::span<const size_t> labels
404406
) {
405407
using Tag = std::decay_t<decltype(tag)>;
406-
return init_impl_t(std::forward<Tag>(tag), metric, data, labels);
408+
return build_impl(
409+
std::forward<Tag>(tag),
410+
this->metric_type_,
411+
this->vamana_build_parameters(),
412+
data,
413+
labels
414+
);
407415
},
408-
metric_type_,
409416
data,
410417
labels
411418
));
412419
}
413420

421+
// Constructor used during loading
414422
DynamicVamanaIndexImpl(
415423
std::unique_ptr<svs::DynamicVamana>&& impl,
416424
MetricType metric,
@@ -432,21 +440,20 @@ class DynamicVamanaIndexImpl {
432440
impl_->get_full_search_history()};
433441
}
434442

435-
template <typename StorageTag>
436-
static svs::DynamicVamana* deserialize_impl_t(
437-
StorageTag&& SVS_UNUSED(tag), std::istream& stream, MetricType metric
438-
) {
439-
auto threadpool =
440-
svs::threads::ThreadPoolHandle(svs::threads::OMPThreadPool(omp_get_max_threads()
441-
));
443+
template <storage::StorageTag Tag>
444+
static svs::DynamicVamana*
445+
load_impl_t(Tag&& SVS_UNUSED(tag), std::istream& stream, MetricType metric) {
446+
auto threadpool = default_threadpool();
442447

443448
svs::DistanceDispatcher distance_dispatcher(to_svs_distance(metric));
444449
return distance_dispatcher([&](auto&& distance) {
445-
return new svs::DynamicVamana(svs::DynamicVamana::assemble<
446-
float,
447-
storage::StorageType_t<StorageTag::value>>(
448-
stream, std::forward<decltype(distance)>(distance), std::move(threadpool)
449-
));
450+
return new svs::DynamicVamana(
451+
svs::DynamicVamana::assemble<float, storage::StorageType_t<Tag>>(
452+
stream,
453+
std::forward<decltype(distance)>(distance),
454+
std::move(threadpool)
455+
)
456+
);
450457
});
451458
}
452459

@@ -458,7 +465,7 @@ class DynamicVamanaIndexImpl {
458465
[&](auto&& tag, std::istream& stream, MetricType metric) {
459466
using Tag = std::decay_t<decltype(tag)>;
460467
std::unique_ptr<svs::DynamicVamana> impl{
461-
deserialize_impl_t(std::forward<Tag>(tag), stream, metric)};
468+
load_impl_t(std::forward<Tag>(tag), stream, metric)};
462469

463470
return new DynamicVamanaIndexImpl(std::move(impl), metric, storage_kind);
464471
},
@@ -478,9 +485,19 @@ class DynamicVamanaIndexImpl {
478485
size_t ntotal_soft_deleted{0};
479486
};
480487

481-
struct DynamicVamanaIndexLeanVecImplTrained : public DynamicVamanaIndexImpl {
482-
using DynamicVamanaIndexImpl::DynamicVamanaIndexImpl;
483-
DynamicVamanaIndexLeanVecImplTrained(
488+
struct DynamicVamanaIndexLeanVecImpl : public DynamicVamanaIndexImpl {
489+
DynamicVamanaIndexLeanVecImpl(
490+
std::unique_ptr<svs::DynamicVamana>&& impl,
491+
MetricType metric,
492+
StorageKind storage_kind
493+
)
494+
: DynamicVamanaIndexImpl{std::move(impl), metric, storage_kind}
495+
, leanvec_dims_{0}
496+
, leanvec_matrices_{std::nullopt} {
497+
check_storage_kind(storage_kind);
498+
}
499+
500+
DynamicVamanaIndexLeanVecImpl(
484501
size_t dim,
485502
MetricType metric,
486503
StorageKind storage_kind,
@@ -489,82 +506,12 @@ struct DynamicVamanaIndexLeanVecImplTrained : public DynamicVamanaIndexImpl {
489506
const VamanaIndex::SearchParams& default_search_params = {10, 10}
490507
)
491508
: DynamicVamanaIndexImpl{dim, metric, storage_kind, params, default_search_params}
492-
, training_data_{training_data} {}
493-
494-
template <typename Tag>
495-
svs::DynamicVamana* init_impl_t(
496-
Tag&& tag,
497-
MetricType metric,
498-
const svs::data::ConstSimpleDataView<float>& data,
499-
std::span<const size_t> labels,
500-
std::optional<LeanVecMatricesType> matrices
501-
) {
502-
auto threadpool =
503-
svs::threads::ThreadPoolHandle(svs::threads::OMPThreadPool(omp_get_max_threads()
504-
));
505-
506-
auto storage =
507-
make_storage(std::forward<Tag>(tag), data, threadpool, 0, std::move(matrices));
508-
509-
svs::DistanceDispatcher distance_dispatcher(to_svs_distance(metric));
510-
return distance_dispatcher([&](auto&& distance) {
511-
return new svs::DynamicVamana(svs::DynamicVamana::build<float>(
512-
this->vamana_build_parameters(),
513-
std::move(storage),
514-
std::move(labels),
515-
std::forward<decltype(distance)>(distance),
516-
std::move(threadpool)
517-
));
518-
});
509+
, leanvec_dims_{training_data.get_leanvec_dims()}
510+
, leanvec_matrices_{training_data.get_leanvec_matrices()} {
511+
check_storage_kind(storage_kind);
519512
}
520513

521-
template <typename F, typename... Args>
522-
static auto dispatch_storage_kind(StorageKind kind, F&& f, Args&&... args) {
523-
using SK = StorageKind;
524-
using namespace svs::runtime::storage;
525-
switch (kind) {
526-
case SK::LeanVec4x4:
527-
return f(LeanVec4x4Tag{}, std::forward<Args>(args)...);
528-
case SK::LeanVec4x8:
529-
return f(LeanVec4x8Tag{}, std::forward<Args>(args)...);
530-
case SK::LeanVec8x8:
531-
return f(LeanVec8x8Tag{}, std::forward<Args>(args)...);
532-
default:
533-
throw ANNEXCEPTION("not supported SVS leanvec storage kind");
534-
}
535-
}
536-
537-
void init_impl(data::ConstSimpleDataView<float> data, std::span<const size_t> labels)
538-
override {
539-
impl_.reset(DynamicVamanaIndexLeanVecImplTrained::dispatch_storage_kind(
540-
this->storage_kind_,
541-
[this](
542-
auto&& tag,
543-
MetricType metric,
544-
data::ConstSimpleDataView<float> data,
545-
std::span<const size_t> labels
546-
) {
547-
using Tag = std::decay_t<decltype(tag)>;
548-
return DynamicVamanaIndexLeanVecImplTrained::init_impl_t(
549-
std::forward<Tag>(tag),
550-
metric,
551-
data,
552-
labels,
553-
training_data_.get_leanvec_matrices()
554-
);
555-
},
556-
metric_type_,
557-
data,
558-
labels
559-
));
560-
}
561-
562-
LeanVecTrainingDataImpl training_data_;
563-
};
564-
565-
struct DynamicVamanaIndexLeanVecImplDims : public DynamicVamanaIndexImpl {
566-
using DynamicVamanaIndexImpl::DynamicVamanaIndexImpl;
567-
DynamicVamanaIndexLeanVecImplDims(
514+
DynamicVamanaIndexLeanVecImpl(
568515
size_t dim,
569516
MetricType metric,
570517
StorageKind storage_kind,
@@ -573,72 +520,70 @@ struct DynamicVamanaIndexLeanVecImplDims : public DynamicVamanaIndexImpl {
573520
const VamanaIndex::SearchParams& default_search_params = {10, 10}
574521
)
575522
: DynamicVamanaIndexImpl{dim, metric, storage_kind, params, default_search_params}
576-
, leanvec_dims_{leanvec_dims} {}
577-
578-
template <typename Tag>
579-
svs::DynamicVamana* init_impl_t(
580-
Tag&& tag,
581-
MetricType metric,
582-
const svs::data::ConstSimpleDataView<float>& data,
583-
std::span<const size_t> labels,
584-
size_t leanvec_dims
585-
) {
586-
auto threadpool =
587-
svs::threads::ThreadPoolHandle(svs::threads::OMPThreadPool(omp_get_max_threads()
588-
));
589-
590-
auto storage = make_storage(std::forward<Tag>(tag), data, threadpool, leanvec_dims);
591-
592-
svs::DistanceDispatcher distance_dispatcher(to_svs_distance(metric));
593-
return distance_dispatcher([&](auto&& distance) {
594-
return new svs::DynamicVamana(svs::DynamicVamana::build<float>(
595-
this->vamana_build_parameters(),
596-
std::move(storage),
597-
std::move(labels),
598-
std::forward<decltype(distance)>(distance),
599-
std::move(threadpool)
600-
));
601-
});
523+
, leanvec_dims_{leanvec_dims}
524+
, leanvec_matrices_{std::nullopt} {
525+
check_storage_kind(storage_kind);
602526
}
603527

604528
template <typename F, typename... Args>
605-
static auto dispatch_storage_kind(StorageKind kind, F&& f, Args&&... args) {
606-
using SK = StorageKind;
607-
using namespace svs::runtime::storage;
529+
static auto dispatch_leanvec_storage_kind(StorageKind kind, F&& f, Args&&... args) {
608530
switch (kind) {
609-
case SK::LeanVec4x4:
610-
return f(LeanVec4x4Tag{}, std::forward<Args>(args)...);
611-
case SK::LeanVec4x8:
612-
return f(LeanVec4x8Tag{}, std::forward<Args>(args)...);
613-
case SK::LeanVec8x8:
614-
return f(LeanVec8x8Tag{}, std::forward<Args>(args)...);
531+
case StorageKind::LeanVec4x4:
532+
return f(storage::LeanVec4x4Tag{}, std::forward<Args>(args)...);
533+
case StorageKind::LeanVec4x8:
534+
return f(storage::LeanVec4x8Tag{}, std::forward<Args>(args)...);
535+
case StorageKind::LeanVec8x8:
536+
return f(storage::LeanVec8x8Tag{}, std::forward<Args>(args)...);
615537
default:
616-
throw ANNEXCEPTION("not supported SVS leanvec storage kind");
538+
throw StatusException{
539+
ErrorCode::INVALID_ARGUMENT, "SVS LeanVec storage kind required"};
617540
}
618541
}
619542

620543
void init_impl(data::ConstSimpleDataView<float> data, std::span<const size_t> labels)
621544
override {
622-
impl_.reset(DynamicVamanaIndexLeanVecImplDims::dispatch_storage_kind(
545+
assert(storage::is_leanvec_storage(this->storage_kind_));
546+
impl_.reset(dispatch_leanvec_storage_kind(
623547
this->storage_kind_,
624548
[this](
625549
auto&& tag,
626-
MetricType metric,
627550
data::ConstSimpleDataView<float> data,
628551
std::span<const size_t> labels
629552
) {
630553
using Tag = std::decay_t<decltype(tag)>;
631-
return DynamicVamanaIndexLeanVecImplDims::init_impl_t(
632-
std::forward<Tag>(tag), metric, data, labels, leanvec_dims_
554+
return DynamicVamanaIndexImpl::build_impl(
555+
std::forward<Tag>(tag),
556+
this->metric_type_,
557+
this->vamana_build_parameters(),
558+
data,
559+
labels,
560+
this->leanvec_dims_,
561+
this->leanvec_matrices_
633562
);
634563
},
635-
metric_type_,
636564
data,
637565
labels
638566
));
639567
}
640568

569+
protected:
641570
size_t leanvec_dims_;
571+
std::optional<LeanVecMatricesType> leanvec_matrices_;
572+
573+
StorageKind check_storage_kind(StorageKind kind) {
574+
if (!storage::is_leanvec_storage(kind)) {
575+
throw StatusException(
576+
ErrorCode::INVALID_ARGUMENT, "SVS LeanVec storage kind required"
577+
);
578+
}
579+
if (!svs::detail::lvq_leanvec_enabled()) {
580+
throw StatusException(
581+
ErrorCode::NOT_IMPLEMENTED,
582+
"LeanVec storage kind requested but not supported by CPU"
583+
);
584+
}
585+
return kind;
586+
}
642587
};
643588

644589
} // namespace runtime

0 commit comments

Comments
 (0)