1515
1616#include < faiss/impl/platform_macros.h>
1717#include < faiss/utils/Heap.h>
18+ #include < faiss/utils/simd_levels.h>
1819
1920namespace faiss {
2021
@@ -27,15 +28,27 @@ struct IDSelector;
2728// / Squared L2 distance between two vectors
/// x and y are read-only float arrays; the result is returned as a float.
/// NOTE(review): d is presumably the common number of components of x and y
/// -- confirm against the implementation.
2829float fvec_L2sqr (const float * x, const float * y, size_t d);
2930
/// Function-template declaration of fvec_L2sqr that takes a SIMDLevel value
/// as a non-type template argument. Function parameters and return type are
/// the same as for the non-template fvec_L2sqr.
/// NOTE(review): presumably each SIMDLevel value maps to a SIMD-specific
/// implementation defined in another translation unit -- confirm.
31+ template <SIMDLevel>
32+ float fvec_L2sqr (const float * x, const float * y, size_t d);
33+
3034// / inner product
/// x and y are read-only float arrays; the result is returned as a float.
/// NOTE(review): d is presumably the common number of components of x and y
/// -- confirm against the implementation.
3135float fvec_inner_product (const float * x, const float * y, size_t d);
3236
/// Function-template declaration of fvec_inner_product, parameterized by a
/// SIMDLevel non-type template argument. The argument list and return type
/// mirror the non-template fvec_inner_product.
/// NOTE(review): presumably specialized per SIMD level elsewhere -- confirm.
37+ template <SIMDLevel>
38+ float fvec_inner_product (const float * x, const float * y, size_t d);
39+
3340// / L1 distance
/// x and y are read-only float arrays; the result is returned as a float.
/// NOTE(review): d is presumably the common number of components of x and y
/// -- confirm against the implementation.
3441float fvec_L1 (const float * x, const float * y, size_t d);
3542
/// Function-template declaration of fvec_L1, parameterized by a SIMDLevel
/// non-type template argument. The argument list and return type mirror the
/// non-template fvec_L1.
/// NOTE(review): presumably specialized per SIMD level elsewhere -- confirm.
43+ template <SIMDLevel>
44+ float fvec_L1 (const float * x, const float * y, size_t d);
45+
3646// / infinity distance
/// x and y are read-only float arrays; the result is returned as a float.
/// NOTE(review): d is presumably the common number of components of x and y
/// -- confirm against the implementation.
3747float fvec_Linf (const float * x, const float * y, size_t d);
3848
/// Function-template declaration of fvec_Linf, parameterized by a SIMDLevel
/// non-type template argument. The argument list and return type mirror the
/// non-template fvec_Linf.
/// NOTE(review): presumably specialized per SIMD level elsewhere -- confirm.
49+ template <SIMDLevel>
50+ float fvec_Linf (const float * x, const float * y, size_t d);
51+
3952// / Special version of inner product that computes 4 distances
4053// / between x and yi, which is performance oriented.
4154void fvec_inner_product_batch_4 (
@@ -50,6 +63,19 @@ void fvec_inner_product_batch_4(
5063 float & dis2,
5164 float & dis3);
5265
/// Function-template declaration of fvec_inner_product_batch_4, parameterized
/// by a SIMDLevel non-type template argument. The non-template declaration is
/// described as a performance-oriented variant that computes 4 distances
/// between x and the yi.
///
/// x, y0..y3 are read-only float arrays and d is a size_t passed by value.
/// dis0..dis3 are non-const float references.
/// NOTE(review): presumably dis_i receives the result for (x, y_i) and d is
/// the shared vector dimension -- confirm against the implementation.
66+ template <SIMDLevel>
67+ void fvec_inner_product_batch_4 (
68+ const float * x,
69+ const float * y0,
70+ const float * y1,
71+ const float * y2,
72+ const float * y3,
73+ const size_t d,
74+ float & dis0,
75+ float & dis1,
76+ float & dis2,
77+ float & dis3);
78+
5379// / Special version of L2sqr that computes 4 distances
5480// / between x and yi, which is performance oriented.
5581void fvec_L2sqr_batch_4 (
@@ -64,6 +90,19 @@ void fvec_L2sqr_batch_4(
6490 float & dis2,
6591 float & dis3);
6692
/// Function-template declaration of fvec_L2sqr_batch_4, parameterized by a
/// SIMDLevel non-type template argument. The non-template declaration is
/// described as a performance-oriented variant of L2sqr that computes 4
/// distances between x and the yi.
///
/// x, y0..y3 are read-only float arrays and d is a size_t passed by value.
/// dis0..dis3 are non-const float references.
/// NOTE(review): presumably dis_i receives the result for (x, y_i) and d is
/// the shared vector dimension -- confirm against the implementation.
93+ template <SIMDLevel>
94+ void fvec_L2sqr_batch_4 (
95+ const float * x,
96+ const float * y0,
97+ const float * y1,
98+ const float * y2,
99+ const float * y3,
100+ const size_t d,
101+ float & dis0,
102+ float & dis1,
103+ float & dis2,
104+ float & dis3);
105+
67106/* * Compute pairwise distances between sets of vectors
68107 *
69108 * @param d dimension of the vectors
@@ -93,6 +132,14 @@ void fvec_inner_products_ny(
93132 size_t d,
94133 size_t ny);
95134
/// Function-template declaration of fvec_inner_products_ny, parameterized by
/// a SIMDLevel non-type template argument.
///
/// ip is the writable output array (see the inline comment); x and y are
/// read-only float arrays; d and ny are size_t counts.
/// NOTE(review): presumably y holds ny contiguous vectors of dimension d and
/// ip receives one inner product per vector -- confirm.
135+ template <SIMDLevel>
136+ void fvec_inner_products_ny (
137+ float * ip, /* output inner product */
138+ const float * x,
139+ const float * y,
140+ size_t d,
141+ size_t ny);
142+
96143/* compute ny square L2 distance between x and a set of contiguous y vectors */
97144void fvec_L2sqr_ny (
98145 float * dis,
@@ -101,6 +148,14 @@ void fvec_L2sqr_ny(
101148 size_t d,
102149 size_t ny);
103150
/// Function-template declaration of fvec_L2sqr_ny, parameterized by a
/// SIMDLevel non-type template argument. The non-template declaration is
/// described as computing ny square L2 distances between x and a set of
/// contiguous y vectors.
///
/// dis is a writable float array; x and y are read-only float arrays; d and
/// ny are size_t counts.
/// NOTE(review): presumably dis has room for ny results and d is the vector
/// dimension -- confirm.
151+ template <SIMDLevel>
152+ void fvec_L2sqr_ny (
153+ float * dis,
154+ const float * x,
155+ const float * y,
156+ size_t d,
157+ size_t ny);
158+
104159/* compute ny square L2 distance between x and a set of transposed contiguous
105160 y vectors. squared lengths of y should be provided as well */
106161void fvec_L2sqr_ny_transposed (
@@ -112,6 +167,16 @@ void fvec_L2sqr_ny_transposed(
112167 size_t d_offset,
113168 size_t ny);
114169
/// Function-template declaration of fvec_L2sqr_ny_transposed, parameterized
/// by a SIMDLevel non-type template argument. The non-template declaration is
/// described as computing ny square L2 distances between x and a set of
/// transposed contiguous y vectors, with the squared lengths of y supplied by
/// the caller.
///
/// dis is a writable float array; x, y and y_sqlen are read-only float
/// arrays; d, d_offset and ny are size_t values.
/// NOTE(review): y_sqlen presumably carries the squared lengths mentioned
/// above, and d_offset presumably is the stride between consecutive
/// dimensions in the transposed layout -- confirm both.
170+ template <SIMDLevel>
171+ void fvec_L2sqr_ny_transposed (
172+ float * dis,
173+ const float * x,
174+ const float * y,
175+ const float * y_sqlen,
176+ size_t d,
177+ size_t d_offset,
178+ size_t ny);
179+
115180/* compute ny square L2 distance between x and a set of contiguous y vectors
116181 and return the index of the nearest vector.
117182 return 0 if ny == 0. */
@@ -122,6 +187,14 @@ size_t fvec_L2sqr_ny_nearest(
122187 size_t d,
123188 size_t ny);
124189
/// Function-template declaration of fvec_L2sqr_ny_nearest, parameterized by a
/// SIMDLevel non-type template argument. The non-template declaration is
/// described as computing ny square L2 distances between x and a set of
/// contiguous y vectors and returning the index of the nearest vector
/// (0 if ny == 0).
///
/// distances_tmp_buffer is a writable float array; x and y are read-only
/// float arrays; d and ny are size_t counts; the return type is size_t.
/// NOTE(review): distances_tmp_buffer is presumably scratch space for the ny
/// intermediate distances -- confirm the required size.
190+ template <SIMDLevel>
191+ size_t fvec_L2sqr_ny_nearest (
192+ float * distances_tmp_buffer,
193+ const float * x,
194+ const float * y,
195+ size_t d,
196+ size_t ny);
197+
125198/* compute ny square L2 distance between x and a set of transposed contiguous
126199 y vectors and return the index of the nearest vector.
127200 squared lengths of y should be provided as well
@@ -135,9 +208,22 @@ size_t fvec_L2sqr_ny_nearest_y_transposed(
135208 size_t d_offset,
136209 size_t ny);
137210
/// Function-template declaration of fvec_L2sqr_ny_nearest_y_transposed,
/// parameterized by a SIMDLevel non-type template argument. The non-template
/// declaration is described as computing ny square L2 distances between x and
/// a set of transposed contiguous y vectors and returning the index of the
/// nearest vector, with the squared lengths of y supplied by the caller.
///
/// distances_tmp_buffer is a writable float array; x, y and y_sqlen are
/// read-only float arrays; d, d_offset and ny are size_t values; the return
/// type is size_t.
/// NOTE(review): distances_tmp_buffer is presumably scratch space and
/// d_offset presumably the stride of the transposed layout -- confirm.
211+ template <SIMDLevel>
212+ size_t fvec_L2sqr_ny_nearest_y_transposed (
213+ float * distances_tmp_buffer,
214+ const float * x,
215+ const float * y,
216+ const float * y_sqlen,
217+ size_t d,
218+ size_t d_offset,
219+ size_t ny);
220+
138221/* * squared norm of a vector */
/// x is a read-only float array; the result is returned as a float.
/// NOTE(review): d is presumably the number of components of x -- confirm.
139222float fvec_norm_L2sqr (const float * x, size_t d);
140223
/// Function-template declaration of fvec_norm_L2sqr, parameterized by a
/// SIMDLevel non-type template argument. The argument list and return type
/// mirror the non-template fvec_norm_L2sqr, which is described as the squared
/// norm of a vector.
/// NOTE(review): presumably specialized per SIMD level elsewhere -- confirm.
224+ template <SIMDLevel>
225+ float fvec_norm_L2sqr (const float * x, size_t d);
226+
141227/* * compute the L2 norms for a set of vectors
142228 *
143229 * @param norms output norms, size nx
@@ -473,6 +559,10 @@ void compute_PQ_dis_tables_dsub2(
473559 */
/// n is a size_t count; a and b are read-only float arrays; bf is a float
/// scalar passed by value; c is a writable float array. Returns nothing.
/// NOTE(review): presumably computes c[i] = a[i] + bf * b[i] for i < n --
/// confirm against the implementation.
474560void fvec_madd (size_t n, const float * a, float bf, const float * b, float * c);
475561
562+ /* Same operation as the non-template fvec_madd, declared as a function
   template with a SIMDLevel non-type template argument so that the variant is
   fixed statically (at compile time) instead of at run time.
   NOTE(review): presumably specialized per SIMD level elsewhere -- confirm. */
563+ template <SIMDLevel>
564+ void fvec_madd (size_t n, const float * a, float bf, const float * b, float * c);
565+
476566/* * same as fvec_madd, also return index of the min of the result table
477567 * @return index of the min of table c
478568 */
@@ -483,4 +573,12 @@ int fvec_madd_and_argmin(
483573 const float * b,
484574 float * c);
485575
/// Function-template declaration of fvec_madd_and_argmin, parameterized by a
/// SIMDLevel non-type template argument. The non-template declaration is
/// described as the same as fvec_madd, additionally returning the index of
/// the minimum of the result table c.
///
/// n is a size_t count; a and b are read-only float arrays; bf is a float
/// scalar; c is a writable float array; the return type is int.
/// NOTE(review): behaviour for n == 0 is not visible here -- confirm.
576+ template <SIMDLevel>
577+ int fvec_madd_and_argmin (
578+ size_t n,
579+ const float * a,
580+ float bf,
581+ const float * b,
582+ float * c);
583+
486584} // namespace faiss
0 commit comments