Skip to content

Commit

Permalink
Update API
Browse files Browse the repository at this point in the history
  • Loading branch information
davidmezzetti committed Apr 21, 2021
1 parent 55d02f4 commit 9c95f96
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 19 deletions.
16 changes: 11 additions & 5 deletions src/python/paperai/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
from paperai.query import Query

class API(txtai.api.API):
"""
Extended API on top of txtai to return enriched query results.
"""

def search(self, query, request):
"""
Extends txtai API to enrich results with content.
Expand All @@ -24,23 +28,23 @@ def search(self, query, request):

if self.embeddings:
dbfile = os.path.join(self.config["path"], "articles.sqlite")
topn = int(request.query_params.get("topn", 10))
threshold = float(request.query_params.get("threshold", 0.6))
limit = self.limit(request.query_params.get("limit"))
threshold = float(request.query_params["threshold"]) if "threshold" in request.query_params else None

with sqlite3.connect(dbfile) as db:
cur = db.cursor()

# Query for best matches
results = Query.search(self.embeddings, cur, query, topn, threshold)
results = Query.search(self.embeddings, cur, query, limit, threshold)

# Get results grouped by document
documents = Query.documents(results, topn)
documents = Query.documents(results, limit)

articles = []

# Print each result, sorted by max score descending
for uid in sorted(documents, key=lambda k: sum([x[0] for x in documents[k]]), reverse=True):
cur.execute("SELECT Title, Published, Publication, Design, Size, Sample, Method, Entry, Id, Reference " +
cur.execute("SELECT Title, Published, Publication, Design, Size, Sample, Method, Entry, Id, Reference " +
"FROM articles WHERE id = ?", [uid])
article = cur.fetchone()

Expand All @@ -54,3 +58,5 @@ def search(self, query, request):
articles.append(article)

return articles

return None
22 changes: 8 additions & 14 deletions test/python/testapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
API module tests
"""

import hashlib
import os
import tempfile
import unittest
Expand Down Expand Up @@ -47,25 +46,20 @@ def start():

return client

@classmethod
def setUpClass(cls):
def testSearch(self):
"""
Create API client on creation of class.
Test search via API
"""

# Build embeddings index
Index.run(Utils.PATH, Utils.VECTORFILE)

cls.client = TestAPI.start()

def testSearch(self):
"""
Test search via API
"""
# Connect to test instance
client = TestAPI.start()

# Run search
params = urllib.parse.urlencode({"query": "+hypertension ci", "limit": 1})
results= client.get("search?%s" % params).json()

values = ["%s%s" % (k, v) for k, v in sorted(self.client.get("search?%s" % params).json()[0].items())]
md5 = hashlib.md5(" ".join(values).encode()).hexdigest()

self.assertEqual(md5, "07ee525ff2b50142c88fb50afcf46582")
# Check number of results
self.assertEqual(len(results), 1)

0 comments on commit 9c95f96

Please sign in to comment.