-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathpredict.py
30 lines (24 loc) · 1.22 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import nmslib
import pandas as pd
from sentence_transformers import SentenceTransformer
import torch
import autoencoder
def search(query, k=5):
    """Print the ``k`` indexed functions closest to a natural-language query.

    The query is embedded with the module-level sentence encoder
    (``albert_model``), passed through ``autoencoder_model``, and the result
    is looked up in the module-level nmslib index ``search_index``. For each
    hit, the distance together with the ``url`` and ``code`` columns of
    ``train_dataframe`` is printed.

    Args:
        query: Text describing the code being looked for.
        k: Number of nearest neighbours to retrieve. Defaults to 5.

    Returns:
        None. Results are written to standard output.
    """
    embedding = albert_model.encode(query)

    # Use the device the autoencoder actually lives on instead of assuming
    # 'cuda', so the function keeps working wherever the model was loaded.
    device = next(autoencoder_model.parameters()).device
    with torch.no_grad():
        output = autoencoder_model(torch.tensor(embedding).to(device))

    # nmslib expects a host-side NumPy vector; it returns the ids of the
    # nearest neighbours and their distances in the index's space.
    idxs, dists = search_index.knnQuery(output.cpu().numpy(), k=k)

    # Look up and print the function details for every returned id.
    for idx, dist in zip(idxs, dists):
        code = train_dataframe['code'][idx]
        url = train_dataframe['url'][idx]
        print(f'cosine dist:{dist:.4f} \n {url} \n {code} \n---------------------------------\n')
# Load the resources used by search(): the function metadata table, the
# sentence encoder, the trained autoencoder and the nmslib index.

# Prefer the GPU when one is present, otherwise fall back to the CPU so that
# loading the models does not fail on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

train_dataframe = pd.read_csv('generated_resources/train_data.csv')

# NOTE: despite the variable name, the checkpoint loaded here is a BERT
# sentence-embedding model.
albert_model = SentenceTransformer('bert-base-nli-mean-tokens').to(device)

autoencoder_model = autoencoder.AutoEncoder(768, 256).to(device)
# map_location lets a checkpoint saved from a GPU be restored on a CPU host.
autoencoder_model.load_state_dict(
    torch.load('generated_resources/autoencoder_0.pt', map_location=device)
)
# Switch to inference mode.
autoencoder_model.eval()

search_index = nmslib.init(method='hnsw', space='cosinesimil')
search_index.loadIndex('generated_resources/final.nmslib')

# Demo query executed when the script runs.
search('trains a k nearest neighbors classifier for face recognition')