Skip to content

Commit 771b1a8

Browse files
Alexandr Guzhvafacebook-github-bot
Alexandr Guzhva
authored andcommitted
Introduce transposed centroid table to speedup ProductQuantizer::compute_codes() (facebookresearch#2562)
Summary: Pull Request resolved: facebookresearch#2562 Introduce a table of transposed centroids in ProductQuantizer that significantly speeds up ProductQuantizer::compute_codes() call for certain PQ parameters, so speeds up search queries. * ::sync_tranposed_centroids() call is used to fill the table * ::clear_transposed_centroids() call clear the table, so that the original baseline code is used for ::compute_codes() Reviewed By: mdouze Differential Revision: D40763338 fbshipit-source-id: 87b40e5dd2f8c3cadeb94c1cd9e8a4a5b6ffa97d
1 parent 02ef6b6 commit 771b1a8

6 files changed

+466
-8
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
import faiss
2+
import time
3+
import random
4+
5+
import faiss.contrib.datasets
6+
7+
8+
# copied from benchs/bench_all_ivf/bench_all_ivf.py
9+
def unwind_index_ivf(index):
10+
if isinstance(index, faiss.IndexPreTransform):
11+
assert index.chain.size() == 1
12+
vt = index.chain.at(0)
13+
index_ivf, vt2 = unwind_index_ivf(faiss.downcast_index(index.index))
14+
assert vt2 is None
15+
return index_ivf, vt
16+
if hasattr(faiss, "IndexRefine") and isinstance(index, faiss.IndexRefine):
17+
return unwind_index_ivf(faiss.downcast_index(index.base_index))
18+
if isinstance(index, faiss.IndexIVF):
19+
return index, None
20+
else:
21+
return None, None
22+
23+
24+
def test_bigann10m(index_file, index_parameters):
25+
ds = faiss.contrib.datasets.DatasetBigANN(nb_M=10)
26+
27+
xq = ds.get_queries()
28+
xb = ds.get_database()
29+
gt = ds.get_groundtruth()
30+
31+
nb, d = xb.shape
32+
nq, d = xq.shape
33+
34+
print("Reading index {}".format(index_file))
35+
index = faiss.read_index(index_file)
36+
37+
ps = faiss.ParameterSpace()
38+
ps.initialize(index)
39+
40+
index_ivf, vec_transform = unwind_index_ivf(index)
41+
42+
print('params regular transp_centroids regular R@1 R@10 R@100')
43+
for index_parameter in index_parameters:
44+
ps.set_index_parameters(index, index_parameter)
45+
46+
print(index_parameter.ljust(70), end=' ')
47+
48+
k = 100
49+
50+
# warmup
51+
D, I = index.search(xq, k)
52+
53+
# warmup
54+
D, I = index.search(xq, k)
55+
56+
# eval
57+
t2_0 = time.time()
58+
D, I = index.search(xq, k)
59+
t2_1 = time.time()
60+
61+
# eval
62+
index_ivf.pq.sync_transposed_centroids()
63+
t3_0 = time.time()
64+
D, I = index.search(xq, k)
65+
t3_1 = time.time()
66+
67+
# eval
68+
index_ivf.pq.clear_transposed_centroids()
69+
t4_0 = time.time()
70+
D, I = index.search(xq, k)
71+
t4_1 = time.time()
72+
73+
print(" %9.5f " % (t2_1 - t2_0), end=' ')
74+
print(" %9.5f " % (t3_1 - t3_0), end=' ')
75+
print(" %9.5f " % (t4_1 - t4_0), end=' ')
76+
77+
for rank in 1, 10, 100:
78+
n_ok = (I[:, :rank] == gt[:, :1]).sum()
79+
print("%.4f" % (n_ok / float(nq)), end=' ')
80+
print()
81+
82+
83+
if __name__ == "__main__":
84+
faiss.contrib.datasets.dataset_basedir = '/home/aguzhva/ANN_SIFT1B/'
85+
86+
# represents OPQ32_128,IVF65536_HNSW32,PQ32 index
87+
index_file_1 = "/home/aguzhva/ANN_SIFT1B/run_tests/bench_ivf/indexes/hnsw32/.faissindex"
88+
89+
nprobe_values = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
90+
quantizer_efsearch_values = [4, 8, 16, 32, 64, 128, 256, 512]
91+
ht_values = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62, 64, 66, 68, 70, 72, 74, 76, 78, 80, 82, 84, 86, 88, 90, 92, 94, 96, 98, 100, 102, 104, 106, 108, 110, 112, 114, 116, 118, 120, 122, 124, 126, 128, 256]
92+
93+
# represents OPQ32_128,IVF65536(IVF256,PQHDx4fs,RFlat),PQ32 index
94+
index_file_2 = "/home/aguzhva/ANN_SIFT1B/run_tests/bench_ivf/indexes/pq4/.faissindex"
95+
96+
quantizer_k_factor_rf_values = [1, 2, 4, 8, 16, 32, 64]
97+
quantizer_nprobe_values = [1, 2, 4, 8, 16, 32, 64, 128]
98+
99+
# test the first index
100+
index_parameters_1 = []
101+
for _ in range(0, 20):
102+
nprobe = random.choice(nprobe_values)
103+
quantizer_efsearch = random.choice(quantizer_efsearch_values)
104+
ht = random.choice(ht_values)
105+
index_parameters_1.append(
106+
"nprobe={},quantizer_efSearch={},ht={}".format(
107+
nprobe,
108+
quantizer_efsearch,
109+
ht)
110+
)
111+
112+
test_bigann10m(index_file_1, index_parameters_1)
113+
114+
# test the second index
115+
index_parameters_2 = []
116+
for _ in range(0, 20):
117+
nprobe = random.choice(nprobe_values)
118+
quantizer_k_factor_rf = random.choice(quantizer_k_factor_rf_values)
119+
quantizer_nprobe = random.choice(quantizer_nprobe_values)
120+
ht = random.choice(ht_values)
121+
index_parameters_2.append(
122+
"nprobe={},quantizer_k_factor_rf={},quantizer_nprobe={},ht={}".format(
123+
nprobe,
124+
quantizer_k_factor_rf,
125+
quantizer_nprobe,
126+
ht)
127+
)
128+
129+
test_bigann10m(index_file_2, index_parameters_2)

faiss/impl/ProductQuantizer.cpp

+48-6
Original file line numberDiff line numberDiff line change
@@ -237,12 +237,26 @@ void compute_code(const ProductQuantizer& pq, const float* x, uint8_t* code) {
237237
for (size_t m = 0; m < pq.M; m++) {
238238
const float* xsub = x + m * pq.dsub;
239239

240-
uint64_t idxm = fvec_L2sqr_ny_nearest(
241-
distances.data(),
242-
xsub,
243-
pq.get_centroids(m, 0),
244-
pq.dsub,
245-
pq.ksub);
240+
uint64_t idxm = 0;
241+
if (pq.transposed_centroids.empty()) {
242+
// the regular version
243+
idxm = fvec_L2sqr_ny_nearest(
244+
distances.data(),
245+
xsub,
246+
pq.get_centroids(m, 0),
247+
pq.dsub,
248+
pq.ksub);
249+
} else {
250+
// transposed centroids are available, use'em
251+
idxm = fvec_L2sqr_ny_nearest_y_transposed(
252+
distances.data(),
253+
xsub,
254+
pq.transposed_centroids.data() + m * pq.ksub,
255+
pq.centroids_sq_lengths.data() + m * pq.ksub,
256+
pq.dsub,
257+
pq.M * pq.ksub,
258+
pq.ksub);
259+
}
246260

247261
encoder.encode(idxm);
248262
}
@@ -819,4 +833,32 @@ void ProductQuantizer::search_sdc(
819833
}
820834
}
821835

836+
void ProductQuantizer::sync_transposed_centroids() {
837+
transposed_centroids.resize(d * ksub);
838+
centroids_sq_lengths.resize(ksub * M);
839+
840+
for (size_t mi = 0; mi < M; mi++) {
841+
for (size_t ki = 0; ki < ksub; ki++) {
842+
float sqlen = 0;
843+
844+
for (size_t di = 0; di < dsub; di++) {
845+
const float q = centroids[(mi * ksub + ki) * dsub + di];
846+
847+
transposed_centroids[(di * M + mi) * ksub + ki] = q;
848+
sqlen += q * q;
849+
}
850+
851+
centroids_sq_lengths[mi * ksub + ki] = sqlen;
852+
}
853+
}
854+
}
855+
856+
void ProductQuantizer::clear_transposed_centroids() {
857+
transposed_centroids.clear();
858+
transposed_centroids.shrink_to_fit();
859+
860+
centroids_sq_lengths.clear();
861+
centroids_sq_lengths.shrink_to_fit();
862+
}
863+
822864
} // namespace faiss

faiss/impl/ProductQuantizer.h

+17-1
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,18 @@ struct ProductQuantizer : Quantizer {
4949
/// d / M)
5050
Index* assign_index;
5151

52-
/// Centroid table, size M * ksub * dsub
52+
/// Centroid table, size M * ksub * dsub.
53+
/// Layout: (M, ksub, dsub)
5354
std::vector<float> centroids;
5455

56+
/// Transposed centroid table, size M * ksub * dsub.
57+
/// Layout: (dsub, M, ksub)
58+
std::vector<float> transposed_centroids;
59+
60+
/// Squared lengths of centroids, size M * ksub
61+
/// Layout: (M, ksub)
62+
std::vector<float> centroids_sq_lengths;
63+
5564
/// return the centroids associated with subvector m
5665
float* get_centroids(size_t m, size_t i) {
5766
return &centroids[(m * ksub + i) * dsub];
@@ -165,6 +174,13 @@ struct ProductQuantizer : Quantizer {
165174
const size_t ncodes,
166175
float_maxheap_array_t* res,
167176
bool init_finalize_heap = true) const;
177+
178+
/// Sync transposed centroids with regular centroids. This call
179+
/// is needed if centroids were edited directly.
180+
void sync_transposed_centroids();
181+
182+
/// Clear transposed centroids table so ones are no longer used.
183+
void clear_transposed_centroids();
168184
};
169185

170186
// block size used in ProductQuantizer::compute_codes

faiss/utils/distances.h

+13
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,19 @@ size_t fvec_L2sqr_ny_nearest(
8383
size_t d,
8484
size_t ny);
8585

86+
/* compute ny square L2 distance between x and a set of transposed contiguous
87+
y vectors and return the index of the nearest vector.
88+
squared lengths of y should be provided as well
89+
return 0 if ny == 0. */
90+
size_t fvec_L2sqr_ny_nearest_y_transposed(
91+
float* distances_tmp_buffer,
92+
const float* x,
93+
const float* y,
94+
const float* y_sqlen,
95+
size_t d,
96+
size_t d_offset,
97+
size_t ny);
98+
8699
/** squared norm of a vector */
87100
float fvec_norm_L2sqr(const float* x, size_t d);
88101

0 commit comments

Comments
 (0)