Commit 11a2ff2 (1 parent: f374b4f). Showing 5 changed files with 441 additions and 0 deletions.

* sparse_GT
* [WIP] add sparse datasets to Datasets
* sparse Dataset object
* [WIP] sparse dataset documentation
* cleanup
* fix gunzip flow
* basic sparse baseline and eval
* drop faiss-based intersection

Co-authored-by: Matthijs Douze <[email protected]>
Co-authored-by: Martin Aumueller <[email protected]>
@@ -0,0 +1,85 @@
import argparse
import time

import numpy as np
from scipy.sparse import csr_matrix, hstack
from multiprocessing.pool import ThreadPool
from tqdm import tqdm

from benchmark.dataset_io import usbin_write, read_sparse_matrix


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    def aa(*args, **kwargs):
        group.add_argument(*args, **kwargs)

    group = parser.add_argument_group('File location')
    aa('--base_csr_file', required=True, help="location of a .csr file representing the base data")
    aa('--query_csr_file', required=True, help="location of a .csr file representing the query data")
    aa('--output_file', required=True, help="location of the ground truth file to be generated")

    group = parser.add_argument_group('Computation options')
    aa('--k', default=10, type=int, help="number of nearest neighbors to compute")
    # aa("--maxRAM", default=100, type=int, help="set max RSS in GB (avoid OOM crash)")
    aa('--nt', type=int, help="number of threads in the thread pool. If omitted, a single thread is used.")

    args = parser.parse_args()

    print("args:", args)
    print('k:', args.k)

    data = read_sparse_matrix(args.base_csr_file)
    queries = read_sparse_matrix(args.query_csr_file)
    print('data:', data.shape)
    print('queries:', queries.shape)

    # pad the queries with virtual zero columns so their dimension matches the base data
    queries = hstack([queries, csr_matrix((queries.shape[0], data.shape[1] - queries.shape[1]))])
    print('after padding:')
    print('data:', data.shape)
    print('queries:', queries.shape)

    k = args.k
    nq = queries.shape[0]

    D = np.zeros((nq, k), dtype='float32')
    I = -np.ones((nq, k), dtype='int32')

    def process_single_row(i):
        # dot product of every base vector with query i (a sparse column vector)
        res = data.dot(queries.getrow(i).transpose())
        ra = res.toarray()
        # indices of the k largest dot products, in no particular order
        top_ind = np.argpartition(ra, -k, axis=0)[-k:][:, 0]

        # sort the top-k entries by decreasing dot product
        index_and_dot_prod = [(j, ra[j, 0]) for j in top_ind]
        index_and_dot_prod.sort(key=lambda a: a[1], reverse=True)

        return index_and_dot_prod

    start = time.time()
    if args.nt is None:
        # single thread
        print('computing ground truth for', nq, 'queries (single thread):')
        res = []
        for i in tqdm(range(nq)):
            res.append(process_single_row(i))
    else:
        print('computing ground truth for', nq, 'queries (' + str(args.nt) + ' threads):')
        with ThreadPool(processes=args.nt) as pool:
            # map the function over all query rows
            res = list(tqdm(pool.imap(process_single_row, range(nq)), total=nq))

    end = time.time()
    elapsed = end - start
    print(f'Elapsed {elapsed}s for {nq} queries ({nq / elapsed} QPS)')

    # rearrange to match the format expected by usbin_write
    for i in range(nq):
        I[i, :] = [p[0] for p in res[i]]
        D[i, :] = [p[1] for p in res[i]]

    print()
    print("Writing result to", args.output_file)

    usbin_write(I, D, args.output_file)
@@ -0,0 +1,55 @@
# Sparse dataset for the 2023 ANN challenge

## Goal
This is a dataset of sparse vectors: vectors of very high dimension (~30k), but with only a small number of nonzero entries.
A typical use case is text, where each dimension corresponds to a term in the vocabulary, and the nonzero values correspond to the terms that appear in the indexed document or paragraph.
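For illustration, a single such vector can be stored in compressed sparse row (CSR) form, keeping only the indices and values of the nonzero dimensions (the indices and values below are made up):

```python
import numpy as np
from scipy.sparse import csr_matrix

# a single ~30k-dimensional vector with only 5 nonzero entries (illustrative values)
dim = 30_000
nonzero_indices = np.array([17, 523, 2048, 11111, 29876])
nonzero_values = np.array([0.8, 1.3, 0.2, 2.1, 0.5], dtype=np.float32)

vector = csr_matrix((nonzero_values, nonzero_indices, [0, len(nonzero_values)]), shape=(1, dim))
print(vector.shape, vector.nnz)  # (1, 30000) 5
```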

## Dataset details

**Dataset**: sparse embedding of the MS-MARCO passage retrieval dataset.
The embeddings are produced by the SPLADE deep learning model (specifically, SPLADE CoCondenser EnsembleDistil, `naver/splade-cocondenser-ensembledistil`).

The base dataset contains 8.8M vectors with an average of ~130 nonzero entries per vector. All nonzero values are positive.

The common query set (`dev.small`) contains 6,980 queries with an average of ~49 nonzero entries per query.

Similarity is measured by maximum dot product, and the overall retrieval metric is recall@10.
For scoring the approximate algorithms, we will measure the maximal throughput attained, subject to the recall@10 being at least 90%.
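As a rough sketch (not the official evaluation code), recall@10 for a single query is simply the overlap between the algorithm's top-10 ids and the ground-truth top-10 ids:

```python
def recall_at_k(retrieved_ids, ground_truth_ids, k=10):
    """Fraction of the true top-k neighbors that the algorithm returned."""
    return len(set(retrieved_ids[:k]) & set(ground_truth_ids[:k])) / k

# e.g. 8 of the 10 true neighbors were found -> recall@10 = 0.8
print(recall_at_k([1, 2, 3, 4, 5, 6, 7, 8, 11, 12], list(range(10))))
```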

## Dataset location and format

The big-ann package contains convenience functions for loading the data and ground truth files.
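For example, the ground-truth script in this commit reads the `.csr` files with `read_sparse_matrix` from `benchmark.dataset_io`; a minimal loading sketch (the local file paths are assumptions) is:

```python
from benchmark.dataset_io import read_sparse_matrix

# paths are placeholders; point them at the downloaded and decompressed .csr files
data = read_sparse_matrix("data/base_small.csr")      # scipy CSR matrix, 100,000 rows
queries = read_sparse_matrix("data/queries.dev.csr")  # 6,980 rows

print(data.shape, data.nnz)
print(queries.shape, queries.nnz)
```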

The dataset, along with smaller slices for development (and their ground truth files), is available at the following locations:

| Name          | Description                | Download link | #rows | Ground truth |
|:--------------|----------------------------|---------------|-------|--------------|
| `full`        | Full base dataset          | [5.5 GB](https://storage.googleapis.com/ann-challenge-sparse-vectors/csr/base_full.csr.gz) | 8,841,823 | [545K](https://storage.googleapis.com/ann-challenge-sparse-vectors/csr/base_full.dev.gt) |
| `1M`          | 1M slice of base dataset   | [636.3 MB](https://storage.googleapis.com/ann-challenge-sparse-vectors/csr/base_1M.csr.gz) | 1,000,000 | [545K](https://storage.googleapis.com/ann-challenge-sparse-vectors/csr/base_1M.dev.gt) |
| `small`       | 100k slice of base dataset | [64.3 MB](https://storage.googleapis.com/ann-challenge-sparse-vectors/csr/base_small.csr.gz) | 100,000 | [545K](https://storage.googleapis.com/ann-challenge-sparse-vectors/csr/base_small.dev.gt) |
| `queries.dev` | Queries file               | [1.8 MB](https://storage.googleapis.com/ann-challenge-sparse-vectors/csr/queries.dev.csr.gz) | 6,980 | N/A |
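A minimal download-and-decompress sketch using the links above (local file names are placeholders):

```python
import gzip
import shutil
import urllib.request

url = "https://storage.googleapis.com/ann-challenge-sparse-vectors/csr/base_small.csr.gz"
urllib.request.urlretrieve(url, "base_small.csr.gz")

# gunzip to the .csr file that read_sparse_matrix expects
with gzip.open("base_small.csr.gz", "rb") as f_in, open("base_small.csr", "wb") as f_out:
    shutil.copyfileobj(f_in, f_out)
```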

---

TODO:

1. add results of baseline algorithm
2.

## Baseline algorithm

As a baseline algorithm, we propose a basic (but efficient) exact algorithm called linscan. It is based on an inverted index and can be made faster (and less precise) with an early-stopping condition. We (Pinecone) can contribute an open source implementation.
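The linscan implementation itself is not part of this commit; as a rough illustration of the inverted-index idea (not Pinecone's implementation), a query can be scored by accumulating dot-product contributions only from the dimensions where the query is nonzero:

```python
from collections import defaultdict

from scipy.sparse import csr_matrix


def build_inverted_index(data_csr: csr_matrix):
    # for each dimension, store the (row id, value) pairs of the vectors that are nonzero there
    index = defaultdict(list)
    coo = data_csr.tocoo()
    for row, dim, val in zip(coo.row, coo.col, coo.data):
        index[dim].append((row, val))
    return index


def query_inverted_index(index, query_csr, k):
    # accumulate dot products over the query's nonzero dimensions only
    scores = defaultdict(float)
    for dim, q_val in zip(query_csr.indices, query_csr.data):
        for row, val in index.get(dim, ()):
            scores[row] += q_val * val
    # return the k rows with the largest dot products
    return sorted(scores.items(), key=lambda item: item[1], reverse=True)[:k]
```

An early-stopping ("anytime") variant would visit the query dimensions in decreasing order of weight and stop once a work budget is exhausted, trading recall for throughput.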

Results of the baseline algorithm (linscan-anytime), both single-thread and multi-thread:

TODO (plot throughput/recall). Extract max throughput at 90% recall.

Link to the open source package (Rust with Python bindings):

TODO
@@ -0,0 +1,52 @@
import numpy as np
from scipy.sparse import csr_matrix


# Given a sparse vector x, returns another vector containing the minimal number of
# largest elements of x such that their sum is at least a times the sum of the elements of x.
#
# The goal is to sparsify the vector further while preserving as much of the
# original vector's mass as possible.
def largest_elements(x, a):
    # compute the sum of the elements of x
    x_sum = np.sum(x)

    # sort the nonzero values in decreasing order and take their cumulative fraction of the total
    ind = np.argsort(-x.data)
    cs = np.cumsum(x.data[ind] / x_sum)

    # rounding errors sometimes result in n_elements > x.nnz
    n_elements = min(sum(cs < a) + 1, x.nnz)

    new_ind = x.indices[ind[:n_elements]]
    new_data = x.data[ind[:n_elements]]
    return csr_matrix((new_data, new_ind, [0, n_elements]), shape=x.shape)


# A basic sparse index.
# Methods:
#   1. __init__: build from a CSR matrix of data.
#   2. query: search a single vector, with parameters:
#      - k (number of neighbors),
#      - alpha (fraction of the query's mass to keep; alpha=1 is exact search).
class BasicSparseIndex(object):
    def __init__(self, data_csr):
        self.data_csc = data_csr.tocsc()

    def query(self, q, k, alpha=1):  # single query, assumes q is a sparse row vector
        if alpha == 1:
            q2 = q.transpose()
        else:
            q2 = largest_elements(q, alpha).transpose()

        # perform (sparse) matrix-vector multiplication
        res = self.data_csc.dot(q2)

        # if there are fewer than k elements with a nonzero score, simply return them
        if res.nnz <= k:
            return list(zip(res.indices, res.data))

        # extract the top k from the sparse result directly
        indices = np.argpartition(res.data, -(k + 1))[-k:]
        results = []
        for index in indices:
            results.append((res.data[index], index))
        results.sort(reverse=True)
        return [(res.indices[b], a) for a, b in results]
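A possible usage sketch for the index above, mirroring the ground-truth script (the file paths are assumptions):

```python
from scipy.sparse import csr_matrix, hstack

from benchmark.dataset_io import read_sparse_matrix

data = read_sparse_matrix("base_small.csr")      # paths are placeholders
queries = read_sparse_matrix("queries.dev.csr")

# pad the queries with zero columns so their dimension matches the base data,
# as done in the ground-truth script
queries = hstack([queries, csr_matrix((queries.shape[0], data.shape[1] - queries.shape[1]))]).tocsr()

index = BasicSparseIndex(data)

q0 = queries.getrow(0)
exact = index.query(q0, k=10)              # alpha=1: exact maximum dot-product search
approx = index.query(q0, k=10, alpha=0.2)  # keep only the largest query entries covering ~20% of its mass
print(exact[:3])
print(approx[:3])
```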