Skip to content

Commit b236273

Browse files
Subhadeep Karan authored and
facebook-github-bot committed
dynamic dispatch distances_simd (#4553)
Summary: `fvec_madd` is the first function used to test dispatching to AVX and AVX512. distances_simd.cpp is split into specialized files, distances_avx2.cpp and distances_avx512.cpp, which are compiled with the appropriate flags. Reviewed By: mnorris11. Differential Revision: D72937708.
1 parent 038db06 commit b236273

14 files changed

+4197
-3730
lines changed

faiss/utils/distances.h

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
#include <faiss/impl/platform_macros.h>
1717
#include <faiss/utils/Heap.h>
18+
#include <faiss/utils/simd_levels.h>
1819

1920
namespace faiss {
2021

@@ -27,15 +28,27 @@ struct IDSelector;
2728
/// Squared L2 distance between two vectors
2829
float fvec_L2sqr(const float* x, const float* y, size_t d);
2930

31+
template <SIMDLevel>
32+
float fvec_L2sqr(const float* x, const float* y, size_t d);
33+
3034
/// inner product
3135
float fvec_inner_product(const float* x, const float* y, size_t d);
3236

37+
template <SIMDLevel>
38+
float fvec_inner_product(const float* x, const float* y, size_t d);
39+
3340
/// L1 distance
3441
float fvec_L1(const float* x, const float* y, size_t d);
3542

43+
template <SIMDLevel>
44+
float fvec_L1(const float* x, const float* y, size_t d);
45+
3646
/// infinity distance
3747
float fvec_Linf(const float* x, const float* y, size_t d);
3848

49+
template <SIMDLevel>
50+
float fvec_Linf(const float* x, const float* y, size_t d);
51+
3952
/// Special version of inner product that computes 4 distances
4053
/// between x and yi, which is performance oriented.
4154
void fvec_inner_product_batch_4(
@@ -50,6 +63,19 @@ void fvec_inner_product_batch_4(
5063
float& dis2,
5164
float& dis3);
5265

66+
template <SIMDLevel>
67+
void fvec_inner_product_batch_4(
68+
const float* x,
69+
const float* y0,
70+
const float* y1,
71+
const float* y2,
72+
const float* y3,
73+
const size_t d,
74+
float& dis0,
75+
float& dis1,
76+
float& dis2,
77+
float& dis3);
78+
5379
/// Special version of L2sqr that computes 4 distances
5480
/// between x and yi, which is performance oriented.
5581
void fvec_L2sqr_batch_4(
@@ -64,6 +90,19 @@ void fvec_L2sqr_batch_4(
6490
float& dis2,
6591
float& dis3);
6692

93+
template <SIMDLevel>
94+
void fvec_L2sqr_batch_4(
95+
const float* x,
96+
const float* y0,
97+
const float* y1,
98+
const float* y2,
99+
const float* y3,
100+
const size_t d,
101+
float& dis0,
102+
float& dis1,
103+
float& dis2,
104+
float& dis3);
105+
67106
/** Compute pairwise distances between sets of vectors
68107
*
69108
* @param d dimension of the vectors
@@ -93,6 +132,14 @@ void fvec_inner_products_ny(
93132
size_t d,
94133
size_t ny);
95134

135+
template <SIMDLevel>
136+
void fvec_inner_products_ny(
137+
float* ip, /* output inner product */
138+
const float* x,
139+
const float* y,
140+
size_t d,
141+
size_t ny);
142+
96143
/* compute ny square L2 distance between x and a set of contiguous y vectors */
97144
void fvec_L2sqr_ny(
98145
float* dis,
@@ -101,6 +148,14 @@ void fvec_L2sqr_ny(
101148
size_t d,
102149
size_t ny);
103150

151+
template <SIMDLevel>
152+
void fvec_L2sqr_ny(
153+
float* dis,
154+
const float* x,
155+
const float* y,
156+
size_t d,
157+
size_t ny);
158+
104159
/* compute ny square L2 distance between x and a set of transposed contiguous
105160
y vectors. squared lengths of y should be provided as well */
106161
void fvec_L2sqr_ny_transposed(
@@ -112,6 +167,16 @@ void fvec_L2sqr_ny_transposed(
112167
size_t d_offset,
113168
size_t ny);
114169

170+
template <SIMDLevel>
171+
void fvec_L2sqr_ny_transposed(
172+
float* dis,
173+
const float* x,
174+
const float* y,
175+
const float* y_sqlen,
176+
size_t d,
177+
size_t d_offset,
178+
size_t ny);
179+
115180
/* compute ny square L2 distance between x and a set of contiguous y vectors
116181
and return the index of the nearest vector.
117182
return 0 if ny == 0. */
@@ -122,6 +187,14 @@ size_t fvec_L2sqr_ny_nearest(
122187
size_t d,
123188
size_t ny);
124189

190+
template <SIMDLevel>
191+
size_t fvec_L2sqr_ny_nearest(
192+
float* distances_tmp_buffer,
193+
const float* x,
194+
const float* y,
195+
size_t d,
196+
size_t ny);
197+
125198
/* compute ny square L2 distance between x and a set of transposed contiguous
126199
y vectors and return the index of the nearest vector.
127200
squared lengths of y should be provided as well
@@ -135,9 +208,22 @@ size_t fvec_L2sqr_ny_nearest_y_transposed(
135208
size_t d_offset,
136209
size_t ny);
137210

211+
template <SIMDLevel>
212+
size_t fvec_L2sqr_ny_nearest_y_transposed(
213+
float* distances_tmp_buffer,
214+
const float* x,
215+
const float* y,
216+
const float* y_sqlen,
217+
size_t d,
218+
size_t d_offset,
219+
size_t ny);
220+
138221
/** squared norm of a vector */
139222
float fvec_norm_L2sqr(const float* x, size_t d);
140223

224+
template <SIMDLevel>
225+
float fvec_norm_L2sqr(const float* x, size_t d);
226+
141227
/** compute the L2 norms for a set of vectors
142228
*
143229
* @param norms output norms, size nx
@@ -473,6 +559,10 @@ void compute_PQ_dis_tables_dsub2(
473559
*/
474560
void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c);
475561

562+
/* same statically */
563+
template <SIMDLevel>
564+
void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c);
565+
476566
/** same as fvec_madd, also return index of the min of the result table
477567
* @return index of the min of table c
478568
*/
@@ -483,4 +573,12 @@ int fvec_madd_and_argmin(
483573
const float* b,
484574
float* c);
485575

576+
template <SIMDLevel>
577+
int fvec_madd_and_argmin(
578+
size_t n,
579+
const float* a,
580+
float bf,
581+
const float* b,
582+
float* c);
583+
486584
} // namespace faiss

0 commit comments

Comments
 (0)