diff --git a/src/python/paperai/api.py b/src/python/paperai/api.py index aac4090..3e8c909 100644 --- a/src/python/paperai/api.py +++ b/src/python/paperai/api.py @@ -10,6 +10,10 @@ from paperai.query import Query class API(txtai.api.API): + """ + Extended API on top of txtai to return enriched query results. + """ + def search(self, query, request): """ Extends txtai API to enrich results with content. @@ -24,23 +28,23 @@ def search(self, query, request): if self.embeddings: dbfile = os.path.join(self.config["path"], "articles.sqlite") - topn = int(request.query_params.get("topn", 10)) - threshold = float(request.query_params.get("threshold", 0.6)) + limit = self.limit(request.query_params.get("limit")) + threshold = float(request.query_params["threshold"]) if "threshold" in request.query_params else None with sqlite3.connect(dbfile) as db: cur = db.cursor() # Query for best matches - results = Query.search(self.embeddings, cur, query, topn, threshold) + results = Query.search(self.embeddings, cur, query, limit, threshold) # Get results grouped by document - documents = Query.documents(results, topn) + documents = Query.documents(results, limit) articles = [] # Print each result, sorted by max score descending for uid in sorted(documents, key=lambda k: sum([x[0] for x in documents[k]]), reverse=True): - cur.execute("SELECT Title, Published, Publication, Design, Size, Sample, Method, Entry, Id, Reference " + + cur.execute("SELECT Title, Published, Publication, Design, Size, Sample, Method, Entry, Id, Reference " + "FROM articles WHERE id = ?", [uid]) article = cur.fetchone() @@ -54,3 +58,5 @@ def search(self, query, request): articles.append(article) return articles + + return None diff --git a/test/python/testapi.py b/test/python/testapi.py index bee5f54..52cdd1b 100644 --- a/test/python/testapi.py +++ b/test/python/testapi.py @@ -2,7 +2,6 @@ API module tests """ -import hashlib import os import tempfile import unittest @@ -47,25 +46,20 @@ def start(): return client - @classmethod - def setUpClass(cls): + def testSearch(self): """ - Create API client on creation of class. + Test search via API """ # Build embeddings index Index.run(Utils.PATH, Utils.VECTORFILE) - cls.client = TestAPI.start() - - def testSearch(self): - """ - Test search via API - """ + # Connect to test instance + client = TestAPI.start() + # Run search params = urllib.parse.urlencode({"query": "+hypertension ci", "limit": 1}) + results= client.get("search?%s" % params).json() - values = ["%s%s" % (k, v) for k, v in sorted(self.client.get("search?%s" % params).json()[0].items())] - md5 = hashlib.md5(" ".join(values).encode()).hexdigest() - - self.assertEqual(md5, "07ee525ff2b50142c88fb50afcf46582") + # Check number of results + self.assertEqual(len(results), 1)