Black

erikbern committed Apr 7, 2023
1 parent 47c5288 commit 0c67adf
Showing 58 changed files with 1,594 additions and 1,577 deletions.
1 change: 1 addition & 0 deletions ann_benchmarks/__init__.py
@@ -1,2 +1,3 @@
 from __future__ import absolute_import
+
 # from ann_benchmarks.main import *
3 changes: 1 addition & 2 deletions ann_benchmarks/algorithms/annoy.py
@@ -22,5 +22,4 @@ def query(self, v, n):
         return self._annoy.get_nns_by_vector(v.tolist(), n, self._search_k)

     def __str__(self):
-        return 'Annoy(n_trees=%d, search_k=%d)' % (self._n_trees,
-                                                   self._search_k)
+        return "Annoy(n_trees=%d, search_k=%d)" % (self._n_trees, self._search_k)
10 changes: 5 additions & 5 deletions ann_benchmarks/algorithms/balltree.py
@@ -8,15 +8,15 @@ class BallTree(BaseANN):
     def __init__(self, metric, leaf_size=20):
         self._leaf_size = leaf_size
         self._metric = metric
-        self.name = 'BallTree(leaf_size=%d)' % self._leaf_size
+        self.name = "BallTree(leaf_size=%d)" % self._leaf_size

     def fit(self, X):
-        if self._metric == 'angular':
-            X = sklearn.preprocessing.normalize(X, axis=1, norm='l2')
+        if self._metric == "angular":
+            X = sklearn.preprocessing.normalize(X, axis=1, norm="l2")
         self._tree = sklearn.neighbors.BallTree(X, leaf_size=self._leaf_size)

     def query(self, v, n):
-        if self._metric == 'angular':
-            v = sklearn.preprocessing.normalize([v], axis=1, norm='l2')[0]
+        if self._metric == "angular":
+            v = sklearn.preprocessing.normalize([v], axis=1, norm="l2")[0]
         dist, ind = self._tree.query([v], k=n)
         return ind[0]
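An aside on the angular branch above: after L2-normalization, Euclidean nearest neighbours coincide with cosine nearest neighbours, which is why the wrapper can hand normalized vectors to a plain BallTree. A small self-contained check of that equivalence — illustrative only, random data, not part of the commit:

```python
import numpy
import sklearn.neighbors
import sklearn.preprocessing

X = numpy.random.rand(500, 16)
v = numpy.random.rand(16)

Xn = sklearn.preprocessing.normalize(X, axis=1, norm="l2")
vn = sklearn.preprocessing.normalize([v], axis=1, norm="l2")[0]

# For unit vectors, ||a - b||^2 = 2 - 2*cos(a, b), so Euclidean and cosine
# orderings coincide.
tree = sklearn.neighbors.BallTree(Xn, leaf_size=20)
_, ind = tree.query([vn], k=5)

cosine_order = numpy.argsort(-Xn @ vn)[:5]
assert set(ind[0]) == set(cosine_order)
```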
4 changes: 2 additions & 2 deletions ann_benchmarks/algorithms/base.py
@@ -21,8 +21,8 @@ def query(self, q, n):

     def batch_query(self, X, n):
         """Provide all queries at once and let algorithm figure out
-           how to handle it. Default implementation uses a ThreadPool
-           to parallelize query processing."""
+        how to handle it. Default implementation uses a ThreadPool
+        to parallelize query processing."""
         pool = ThreadPool()
         self.res = pool.map(lambda q: self.query(q, n), X)

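The default `batch_query` above fans the queries out over a thread pool and stores one result list per query. A standalone sketch of the same idea — illustrative only, assuming a hypothetical `algo` object exposing `query(v, n)`, not part of this commit:

```python
from multiprocessing.pool import ThreadPool

def batch_query(algo, X, n):
    # Run one query per input vector on a thread pool and collect the results.
    pool = ThreadPool()
    try:
        return pool.map(lambda q: algo.query(q, n), X)
    finally:
        pool.close()
        pool.join()
```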
64 changes: 29 additions & 35 deletions ann_benchmarks/algorithms/bruteforce.py
@@ -7,65 +7,59 @@

 class BruteForce(BaseANN):
     def __init__(self, metric):
-        if metric not in ('angular', 'euclidean', 'hamming'):
-            raise NotImplementedError(
-                "BruteForce doesn't support metric %s" % metric)
+        if metric not in ("angular", "euclidean", "hamming"):
+            raise NotImplementedError("BruteForce doesn't support metric %s" % metric)
         self._metric = metric
-        self.name = 'BruteForce()'
+        self.name = "BruteForce()"

     def fit(self, X):
-        metric = {'angular': 'cosine', 'euclidean': 'l2',
-                  'hamming': 'hamming'}[self._metric]
-        self._nbrs = sklearn.neighbors.NearestNeighbors(
-            algorithm='brute', metric=metric)
+        metric = {"angular": "cosine", "euclidean": "l2", "hamming": "hamming"}[self._metric]
+        self._nbrs = sklearn.neighbors.NearestNeighbors(algorithm="brute", metric=metric)
         self._nbrs.fit(X)

     def query(self, v, n):
-        return list(self._nbrs.kneighbors(
-            [v], return_distance=False, n_neighbors=n)[0])
+        return list(self._nbrs.kneighbors([v], return_distance=False, n_neighbors=n)[0])

     def query_with_distances(self, v, n):
-        (distances, positions) = self._nbrs.kneighbors(
-            [v], return_distance=True, n_neighbors=n)
+        (distances, positions) = self._nbrs.kneighbors([v], return_distance=True, n_neighbors=n)
         return zip(list(positions[0]), list(distances[0]))


 class BruteForceBLAS(BaseANN):
     """kNN search that uses a linear scan = brute force."""

     def __init__(self, metric, precision=numpy.float32):
-        if metric not in ('angular', 'euclidean', 'hamming', 'jaccard'):
-            raise NotImplementedError(
-                "BruteForceBLAS doesn't support metric %s" % metric)
-        elif metric == 'hamming' and precision != numpy.bool_:
+        if metric not in ("angular", "euclidean", "hamming", "jaccard"):
+            raise NotImplementedError("BruteForceBLAS doesn't support metric %s" % metric)
+        elif metric == "hamming" and precision != numpy.bool_:
             raise NotImplementedError(
-                "BruteForceBLAS doesn't support precision"
-                " %s with Hamming distances" % precision)
+                "BruteForceBLAS doesn't support precision" " %s with Hamming distances" % precision
+            )
         self._metric = metric
         self._precision = precision
-        self.name = 'BruteForceBLAS()'
+        self.name = "BruteForceBLAS()"

     def fit(self, X):
         """Initialize the search index."""
-        if self._metric == 'angular':
+        if self._metric == "angular":
             # precompute (squared) length of each vector
-            lens = (X ** 2).sum(-1)
+            lens = (X**2).sum(-1)
             # normalize index vectors to unit length
             X /= numpy.sqrt(lens)[..., numpy.newaxis]
             self.index = numpy.ascontiguousarray(X, dtype=self._precision)
-        elif self._metric == 'hamming':
+        elif self._metric == "hamming":
             # Regarding bitvectors as vectors in l_2 is faster for blas
             X = X.astype(numpy.float32)
             # precompute (squared) length of each vector
-            lens = (X ** 2).sum(-1)
+            lens = (X**2).sum(-1)
             self.index = numpy.ascontiguousarray(X, dtype=numpy.float32)
             self.lengths = numpy.ascontiguousarray(lens, dtype=numpy.float32)
-        elif self._metric == 'euclidean':
+        elif self._metric == "euclidean":
             # precompute (squared) length of each vector
-            lens = (X ** 2).sum(-1)
+            lens = (X**2).sum(-1)
             self.index = numpy.ascontiguousarray(X, dtype=self._precision)
             self.lengths = numpy.ascontiguousarray(lens, dtype=self._precision)
-        elif self._metric == 'jaccard':
+        elif self._metric == "jaccard":
             self.index = X
         else:
             # shouldn't get past the constructor!
@@ -78,33 +72,33 @@ def query_with_distances(self, v, n):
         """Find indices of `n` most similar vectors from the index to query
         vector `v`."""

-        if self._metric != 'jaccard':
+        if self._metric != "jaccard":
             # use same precision for query as for index
             v = numpy.ascontiguousarray(v, dtype=self.index.dtype)

         # HACK we ignore query length as that's a constant
         # not affecting the final ordering
-        if self._metric == 'angular':
+        if self._metric == "angular":
             # argmax_a cossim(a, b) = argmax_a dot(a, b) / |a||b| = argmin_a -dot(a, b) # noqa
             dists = -numpy.dot(self.index, v)
-        elif self._metric == 'euclidean':
+        elif self._metric == "euclidean":
             # argmin_a (a - b)^2 = argmin_a a^2 - 2ab + b^2 = argmin_a a^2 - 2ab # noqa
             dists = self.lengths - 2 * numpy.dot(self.index, v)
-        elif self._metric == 'hamming':
+        elif self._metric == "hamming":
             # Just compute hamming distance using euclidean distance
             dists = self.lengths - 2 * numpy.dot(self.index, v)
-        elif self._metric == 'jaccard':
-            dists = [pd[self._metric]['distance'](v, e) for e in self.index]
+        elif self._metric == "jaccard":
+            dists = [pd[self._metric]["distance"](v, e) for e in self.index]
         else:
             # shouldn't get past the constructor!
             assert False, "invalid metric"
         # partition-sort by distance, get `n` closest
         nearest_indices = numpy.argpartition(dists, n)[:n]
-        indices = [idx for idx in nearest_indices if pd[self._metric]
-                   ["distance_valid"](dists[idx])]
+        indices = [idx for idx in nearest_indices if pd[self._metric]["distance_valid"](dists[idx])]

         def fix(index):
             ep = self.index[index]
             ev = v
-            return (index, pd[self._metric]['distance'](ep, ev))
+            return (index, pd[self._metric]["distance"](ep, ev))

         return map(fix, indices)
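A note on the identity the Euclidean/Hamming branches above rely on: since the query's norm is constant across all index vectors, ranking by `lengths - 2 * dot(index, v)` gives the same nearest neighbours as the full squared Euclidean distance. A quick self-contained check — illustrative only, random made-up data, not part of the commit:

```python
import numpy

X = numpy.random.rand(1000, 64).astype(numpy.float32)
v = numpy.random.rand(64).astype(numpy.float32)

lengths = (X**2).sum(-1)
surrogate = lengths - 2 * numpy.dot(X, v)     # what BruteForceBLAS ranks by
true_sq_dists = ((X - v) ** 2).sum(-1)        # full squared Euclidean distance

# The two differ only by the constant ||v||^2, so the n closest indices agree.
n = 10
assert set(numpy.argpartition(surrogate, n)[:n]) == set(numpy.argpartition(true_sq_dists, n)[:n])
```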
2 changes: 1 addition & 1 deletion ann_benchmarks/algorithms/ckdtree.py
@@ -7,7 +7,7 @@ class CKDTree(BaseANN):
     def __init__(self, metric, leaf_size=20):
         self._leaf_size = leaf_size
         self._metric = metric
-        self.name = 'CKDTree(leaf_size=%d)' % self._leaf_size
+        self.name = "CKDTree(leaf_size=%d)" % self._leaf_size

     def fit(self, X):
         self._tree = cKDTree(X, leafsize=self._leaf_size)
15 changes: 7 additions & 8 deletions ann_benchmarks/algorithms/datasketch.py
@@ -8,33 +8,32 @@

 class DataSketch(BaseANN):
     def __init__(self, metric, n_perm, n_rep):
-        if metric not in ('jaccard'):
-            raise NotImplementedError(
-                "Datasketch doesn't support metric %s" % metric)
+        if metric not in ("jaccard"):
+            raise NotImplementedError("Datasketch doesn't support metric %s" % metric)
         self._n_perm = n_perm
         self._n_rep = n_rep
         self._metric = metric
-        self.name = 'Datasketch(n_perm=%d, n_rep=%d)' % (n_perm, n_rep)
+        self.name = "Datasketch(n_perm=%d, n_rep=%d)" % (n_perm, n_rep)

     def fit(self, X):
         self._index = MinHashLSHForest(num_perm=self._n_perm, l=self._n_rep)
         for i, x in enumerate(X):
             m = MinHash(num_perm=self._n_perm)
             if x.dtype == np.bool_:
                 for e in np.flatnonzero(x):
-                    m.update(str(e).encode('utf8'))
+                    m.update(str(e).encode("utf8"))
             else:
                 for e in x:
-                    m.update(str(e).encode('utf8'))
+                    m.update(str(e).encode("utf8"))
             self._index.add(str(i), m)
         self._index.index()

     def query(self, v, n):
         m = MinHash(num_perm=self._n_perm)
         if v.dtype == np.bool_:
             for e in np.flatnonzero(v):
-                m.update(str(e).encode('utf8'))
+                m.update(str(e).encode("utf8"))
         else:
             for e in v:
-                m.update(str(e).encode('utf8'))
+                m.update(str(e).encode("utf8"))
         return map(int, self._index.query(m, n))
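For reference, the wrapper above uses the `datasketch` API (`MinHash`, `MinHashLSHForest`, `add`, `index`, `query`). A minimal, self-contained usage sketch — not part of this commit, with a toy made-up dataset:

```python
import numpy as np
from datasketch import MinHash, MinHashLSHForest

data = [np.array([1, 3, 5]), np.array([1, 3, 7]), np.array([2, 4, 8])]

forest = MinHashLSHForest(num_perm=128)
for i, x in enumerate(data):
    m = MinHash(num_perm=128)
    for e in x:
        m.update(str(e).encode("utf8"))
    forest.add(str(i), m)
forest.index()

q = MinHash(num_perm=128)
for e in [1, 3, 5]:
    q.update(str(e).encode("utf8"))
print(list(map(int, forest.query(q, 2))))  # approximate top-2 by Jaccard similarity
```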
49 changes: 21 additions & 28 deletions ann_benchmarks/algorithms/definitions.py
@@ -7,14 +7,12 @@


 Definition = collections.namedtuple(
-    'Definition',
-    ['algorithm', 'constructor', 'module', 'docker_tag',
-     'arguments', 'query_argument_groups', 'disabled'])
+    "Definition", ["algorithm", "constructor", "module", "docker_tag", "arguments", "query_argument_groups", "disabled"]
+)


 def instantiate_algorithm(definition):
-    print('Trying to instantiate %s.%s(%s)' %
-          (definition.module, definition.constructor, definition.arguments))
+    print("Trying to instantiate %s.%s(%s)" % (definition.module, definition.constructor, definition.arguments))
     module = importlib.import_module(definition.module)
     constructor = getattr(module, definition.constructor)
     return constructor(*definition.arguments)
@@ -55,8 +53,7 @@ def _generate_combinations(args):

 def _substitute_variables(arg, vs):
     if isinstance(arg, dict):
-        return dict([(k, _substitute_variables(v, vs))
-                     for k, v in arg.items()])
+        return dict([(k, _substitute_variables(v, vs)) for k, v in arg.items()])
     elif isinstance(arg, list):
         return [_substitute_variables(a, vs) for a in arg]
     elif isinstance(arg, str) and arg in vs:
@@ -73,13 +70,13 @@ def _get_definitions(definition_file):
 def list_algorithms(definition_file):
     definitions = _get_definitions(definition_file)

-    print('The following algorithms are supported...')
+    print("The following algorithms are supported...")
     for point in definitions:
         print('\t... for the point type "%s"...' % point)
         for metric in definitions[point]:
             print('\t\t... and the distance metric "%s":' % metric)
             for algorithm in definitions[point][metric]:
-                print('\t\t\t%s' % algorithm)
+                print("\t\t\t%s" % algorithm)


 def get_unique_algorithms(definition_file):
@@ -92,8 +89,7 @@ def get_unique_algorithms(definition_file):
     return list(sorted(algos))


-def get_definitions(definition_file, dimension, point_type="float",
-                    distance_metric="euclidean", count=10):
+def get_definitions(definition_file, dimension, point_type="float", distance_metric="euclidean", count=10):
     definitions = _get_definitions(definition_file)

     algorithm_definitions = {}
@@ -103,10 +99,9 @@ def get_definitions(definition_file, dimension, point_type="float",

     definitions = []
     for (name, algo) in algorithm_definitions.items():
-        for k in ['docker-tag', 'module', 'constructor']:
+        for k in ["docker-tag", "module", "constructor"]:
             if k not in algo:
-                raise Exception(
-                    'algorithm %s does not define a "%s" property' % (name, k))
+                raise Exception('algorithm %s does not define a "%s" property' % (name, k))

         base_args = []
         if "base-args" in algo:
@@ -150,20 +145,18 @@
                 else:
                     aargs.append(arg_group)

-                vs = {
-                    "@count": count,
-                    "@metric": distance_metric,
-                    "@dimension": dimension
-                }
+                vs = {"@count": count, "@metric": distance_metric, "@dimension": dimension}
                 aargs = [_substitute_variables(arg, vs) for arg in aargs]
-                definitions.append(Definition(
-                    algorithm=name,
-                    docker_tag=algo['docker-tag'],
-                    module=algo['module'],
-                    constructor=algo['constructor'],
-                    arguments=aargs,
-                    query_argument_groups=query_args,
-                    disabled=algo.get('disabled', False)
-                ))
+                definitions.append(
+                    Definition(
+                        algorithm=name,
+                        docker_tag=algo["docker-tag"],
+                        module=algo["module"],
+                        constructor=algo["constructor"],
+                        arguments=aargs,
+                        query_argument_groups=query_args,
+                        disabled=algo.get("disabled", False),
+                    )
+                )

     return definitions
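To illustrate the hunk above: `get_definitions` substitutes the `@count`, `@metric`, and `@dimension` placeholders into each algorithm's argument list via `_substitute_variables`. A self-contained sketch of that substitution — the final two branches of the helper are reconstructed here (they fall outside the diff view) and the sample values are invented:

```python
def _substitute_variables(arg, vs):
    # Recursively replace "@..." placeholder strings with concrete values.
    if isinstance(arg, dict):
        return dict([(k, _substitute_variables(v, vs)) for k, v in arg.items()])
    elif isinstance(arg, list):
        return [_substitute_variables(a, vs) for a in arg]
    elif isinstance(arg, str) and arg in vs:
        return vs[arg]
    else:
        return arg

vs = {"@count": 10, "@metric": "euclidean", "@dimension": 128}
args = ["@metric", {"n_trees": 100, "search_k": "@count"}]
print(_substitute_variables(args, vs))
# -> ['euclidean', {'n_trees': 100, 'search_k': 10}]
```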