Skip to content

Add sentence embeddings QM #537

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions python/embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import mgp
import subprocess
import sys


# Property keys that are left out of the text fed to the embedding model.
# "embedding" is the property compute_embeddings itself writes, so excluding it
# keeps a previously stored vector out of the next embedding's input text.
EXCLUDE_PROPERTIES = {"embedding"}

# Module-level logger used by compute_embeddings to report failures.
logger: mgp.Logger = mgp.Logger()

def _load_sentence_transformer():
    """Return the ``SentenceTransformer`` class, installing its package if needed.

    The import is attempted first. If the package is missing, pip is
    bootstrapped with ``ensurepip`` and ``sentence-transformers`` is installed
    with the interpreter reported by ``sys.executable``; the import is then
    retried.

    Returns:
        The ``SentenceTransformer`` class, or ``None`` when the package could
        not be installed or imported. The reason is written to ``logger``.
    """
    try:
        from sentence_transformers import SentenceTransformer

        return SentenceTransformer
    except ImportError:
        pass

    # Function-scope import: only needed on the (rare) install path.
    import importlib

    # Make sure pip is there. ``-m`` is required so the interpreter runs the
    # module; without it "ensurepip" is treated as a script path and the call
    # always fails.
    try:
        subprocess.check_call([sys.executable, "-m", "ensurepip"])
    except (subprocess.CalledProcessError, OSError):
        logger.error("Failed to ensure pip is available")
        return None

    # Install the sentence-transformers package
    try:
        subprocess.check_call(
            [sys.executable, "-m", "pip", "install", "sentence-transformers"]
        )
    except (subprocess.CalledProcessError, OSError):
        logger.error("Failed to install the sentence-transformers package")
        return None

    # The package appeared on disk after this process started, so the import
    # system's finder caches must be refreshed before it can be found.
    importlib.invalidate_caches()
    try:
        from sentence_transformers import SentenceTransformer
    except ImportError as e:
        logger.error(f"Failed to import the sentence-transformers package: {e}")
        return None
    return SentenceTransformer


@mgp.write_proc
def compute_embeddings(ctx: mgp.ProcCtx, node: mgp.Vertex) -> mgp.Record(embedding_string=str, success=bool):
    """Compute a sentence embedding for every vertex and store it on the vertex.

    For each vertex in ``ctx.graph`` a text is built from its label names and
    its ``key: value`` properties (keys in ``EXCLUDE_PROPERTIES`` are skipped).
    The text is encoded with the ``all-MiniLM-L6-v2`` model and the result,
    converted with ``.tolist()``, is written to the vertex's ``"embedding"``
    property.

    Args:
        ctx: Procedure context; its ``graph.vertices`` are iterated and
            modified.
        node: Part of the procedure signature but not read by this
            implementation; every vertex of the graph is processed.

    Returns:
        A record with ``success=True`` and ``embedding_string`` set to the
        text built for the last processed vertex (``""`` when the graph has no
        vertices). On any failure the error is logged and a record with
        ``embedding_string=""`` and ``success=False`` is returned.
    """
    sentence_transformer_cls = _load_sentence_transformer()
    if sentence_transformer_cls is None:
        return mgp.Record(embedding_string="", success=False)

    try:
        model = sentence_transformer_cls("all-MiniLM-L6-v2")
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        return mgp.Record(embedding_string="", success=False)

    # Bound before the loop so an empty graph is reported as a success with an
    # empty string instead of failing on an unbound name.
    node_data = ""
    try:
        for vertex in ctx.graph.vertices:
            # Example of the generated text:
            # Test id: 555 name: Pero last_name: Peric nums: (1, 2, 3) birthday: 1947-07-30 maps: {'day': 30, 'month': 7, 'year': 1947} lap: 0:02:02.000033

            # TODO: parametrize the excluded properties
            labels_text = " ".join(label.name for label in vertex.labels)
            properties_text = " ".join(
                f"{key}: {value}"
                for key, value in vertex.properties.items()
                if key not in EXCLUDE_PROPERTIES
            )
            node_data = labels_text + " " + properties_text

            # Compute the embedding
            node_embedding = model.encode(node_data)

            # TODO: parametrize the property name
            vertex.properties["embedding"] = node_embedding.tolist()
    except Exception as e:
        # Handle exceptions by returning failure status
        logger.error(f"Failed to compute embedding for node: {e}")
        return mgp.Record(embedding_string="", success=False)

    return mgp.Record(embedding_string=node_data, success=True)
Loading