Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update query modules signatures #335

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 111 additions & 0 deletions gqlalchemy/graph_algorithms/query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,29 @@ def community_detection_get(
(weight, coloring, min_graph_shrink, community_alg_threshold, coloring_alg_threshold),
)

def community_detection_get_subgraph(
self,
subgraph_nodes: List[Any],
subgraph_relationships: List[Any],
weight: str,
coloring: bool,
min_graph_shrink: int,
community_alg_threshold: float,
coloring_alg_threshold: float,
) -> DeclarativeBase:
return self.call(
"community_detection.get_subgraph",
(
subgraph_nodes,
subgraph_relationships,
weight,
coloring,
min_graph_shrink,
community_alg_threshold,
coloring_alg_threshold,
),
)

def community_detection_online_get(self) -> DeclarativeBase:
return self.call("community_detection_online.get")

Expand Down Expand Up @@ -854,6 +877,58 @@ def kmeans_set_clusters(
(n_clusters, embedding_property, cluster_property, init, n_init, max_iter, tol, algorithm, random_state),
)

def leiden_community_detection_get(
self, weight_property: str, gamma: float, theta: float, resolution_parameter: float, number_of_iterations: int
) -> DeclarativeBase:
return self.call(
"leiden_community_detection.get",
(weight_property, gamma, theta, resolution_parameter, number_of_iterations),
)

def leiden_community_detection_get_subgraph(
self,
subgraph_nodes: List[Any],
subgraph_relationships: List[Any],
weight_property: str,
gamma: float,
theta: float,
resolution_parameter: float,
number_of_iterations: int,
) -> DeclarativeBase:
return self.call(
"leiden_community_detection.get_subgraph",
(
subgraph_nodes,
subgraph_relationships,
weight_property,
gamma,
theta,
resolution_parameter,
number_of_iterations,
),
)

def link_prediction_get_training_results(self) -> DeclarativeBase:
return self.call("link_prediction.get_training_results")

def link_prediction_load_model(self, path: str) -> DeclarativeBase:
return self.call("link_prediction.load_model", (path))

def link_prediction_predict(self, src_vertex: Any, dest_vertex: Any) -> DeclarativeBase:
return self.call("link_prediction.predict", (src_vertex, dest_vertex))

def link_prediction_recommend(self, src_vertex: Any, dest_vertices: List[Any], k: int) -> DeclarativeBase:
return self.call("link_prediction.recommend", (src_vertex, dest_vertices, k))

def link_prediction_reset_parameters(self) -> DeclarativeBase:
return self.call("link_prediction.reset_parameters")

def link_prediction_set_model_parameters(self, parameters: dict) -> DeclarativeBase:
return self.call("link_prediction.set_model_parameters", (parameters))

def link_prediction_train(self) -> DeclarativeBase:
return self.call("link_prediction.train")

def llm_util_schema(self, output_type: str) -> DeclarativeBase:
return self.call("llm_util.schema", (output_type))

Expand Down Expand Up @@ -1099,6 +1174,27 @@ def node2vec_online_set_word2vec_learner(
def node2vec_online_update(self, edges: List[Any]) -> DeclarativeBase:
return self.call("node2vec_online.update", (edges))

def node_classification_get_training_data(self) -> DeclarativeBase:
return self.call("node_classification.get_training_data")

def node_classification_load_model(self, num: int) -> DeclarativeBase:
return self.call("node_classification.load_model", (num))

def node_classification_predict(self, vertex: Any) -> DeclarativeBase:
return self.call("node_classification.predict", (vertex))

def node_classification_reset(self) -> DeclarativeBase:
return self.call("node_classification.reset")

def node_classification_save_model(self) -> DeclarativeBase:
return self.call("node_classification.save_model")

def node_classification_set_model_parameters(self, params: Any) -> DeclarativeBase:
return self.call("node_classification.set_model_parameters", (params))

def node_classification_train(self, num_epochs: int) -> DeclarativeBase:
return self.call("node_classification.train", (num_epochs))

def node_similarity_cosine(self, node1: Any, node2: Any, mode: str = "cartesian") -> DeclarativeBase:
return self.call("node_similarity.cosine", (node1, node2, mode))

Expand Down Expand Up @@ -1294,6 +1390,15 @@ def set_property_copyPropertyRel2Rel(
def temporal_format(self, temporal: Any, format: str) -> DeclarativeBase:
return self.call("temporal.format", (temporal, format))

def text_format(self, text: str, params: List[Any]) -> DeclarativeBase:
return self.call("text.format", (text, params))

def text_join(self, strings: List[str], delimiter: str) -> DeclarativeBase:
return self.call("text.join", (strings, delimiter))

def text_regexGroups(self, input: str, regex: str) -> DeclarativeBase:
return self.call("text.regexGroups", (input, regex))

def text_search_aggregate(self, index_name: str, search_query: str, aggregation_query: str) -> DeclarativeBase:
return self.call("text_search.aggregate", (index_name, search_query, aggregation_query))

Expand Down Expand Up @@ -1350,6 +1455,12 @@ def util_module_md5(self, values: Any) -> DeclarativeBase:
def uuid_generator_get(self) -> DeclarativeBase:
return self.call("uuid_generator.get")

def vector_search_search(self, index_name: str, result_set_size: int, query_vector: List[Any]) -> DeclarativeBase:
return self.call("vector_search.search", (index_name, result_set_size, query_vector))

def vector_search_show_index_info(self) -> DeclarativeBase:
return self.call("vector_search.show_index_info")

def vrp_route(self, depot_node: Any, number_of_vehicles: Optional[int] = None) -> DeclarativeBase:
return self.call("vrp.route", (depot_node, number_of_vehicles))

Expand Down
Loading