Skip to content

Commit 1f8fdec

Browse files
Merge pull request #757 from soerenreichardt/configurable-logging
Configure progress logging on GDS object
2 parents 3b5e45c + c944019 commit 1f8fdec

11 files changed

+122
-15
lines changed

changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
* Add `ttl` parameter to `GdsSessions.get_or_create` to control if and when an unused session will be automatically deleted.
1414
* Add concurrency control for remote write-back procedures using the `concurrency` parameter.
1515
* Add progress logging for remote write-back when using GDS Sessions.
16+
* Add a `show_progress` flag to the `GraphDataScience` and `AuraGraphDataScience` classes to disable displaying progress bars when running procedures.
1617

1718
## Bug fixes
1819

graphdatascience/graph_data_science.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def __init__(
3434
arrow_disable_server_verification: bool = True,
3535
arrow_tls_root_certs: Optional[bytes] = None,
3636
bookmarks: Optional[Any] = None,
37+
show_progress: bool = True,
3738
):
3839
"""
3940
Construct a new GraphDataScience object.
@@ -63,14 +64,16 @@ def __init__(
6364
GDS Arrow Flight server.
6465
bookmarks : Optional[Any], default None
6566
The Neo4j bookmarks to require a certain state before the next query gets executed.
67+
show_progress : bool, default True
68+
A flag to indicate whether to show progress bars for running procedures.
6669
"""
6770
if aura_ds:
6871
GraphDataScience._validate_endpoint(endpoint)
6972

7073
if isinstance(endpoint, QueryRunner):
7174
self._query_runner = endpoint
7275
else:
73-
self._query_runner = Neo4jQueryRunner.create(endpoint, auth, aura_ds, database, bookmarks)
76+
self._query_runner = Neo4jQueryRunner.create(endpoint, auth, aura_ds, database, bookmarks, show_progress)
7477

7578
self._server_version = self._query_runner.server_version()
7679

@@ -86,6 +89,7 @@ def __init__(
8689
None if arrow is True else arrow,
8790
)
8891

92+
self._query_runner.set_show_progress(show_progress)
8993
super().__init__(self._query_runner, namespace="gds", server_version=self._server_version)
9094

9195
@property
@@ -129,6 +133,17 @@ def set_bookmarks(self, bookmarks: Any) -> None:
129133
"""
130134
self._query_runner.set_bookmarks(bookmarks)
131135

136+
def set_show_progress(self, show_progress: bool) -> None:
137+
"""
138+
Set whether to show progress for running procedures.
139+
140+
Parameters
141+
----------
142+
show_progress : bool
143+
Whether to show progress for procedures.
144+
"""
145+
self._query_runner.set_show_progress(show_progress)
146+
132147
def database(self) -> Optional[str]:
133148
"""
134149
Get the database which queries are run against.

graphdatascience/query_runner/arrow_query_runner.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,9 @@ def close(self) -> None:
236236
def fallback_query_runner(self) -> QueryRunner:
237237
return self._fallback_query_runner
238238

239+
def set_show_progress(self, show_progress: bool) -> None:
240+
self._fallback_query_runner.set_show_progress(show_progress)
241+
239242
def create_graph_constructor(
240243
self, graph_name: str, concurrency: int, undirected_relationship_types: Optional[List[str]]
241244
) -> GraphConstructor:

graphdatascience/query_runner/neo4j_query_runner.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def create(
3333
aura_ds: bool = False,
3434
database: Optional[str] = None,
3535
bookmarks: Optional[Any] = None,
36+
show_progress: bool = True,
3637
) -> Neo4jQueryRunner:
3738
if isinstance(endpoint, str):
3839
config: Dict[str, Any] = {"user_agent": f"neo4j-graphdatascience-v{__version__}"}
@@ -51,7 +52,9 @@ def create(
5152
)
5253

5354
elif isinstance(endpoint, neo4j.Driver):
54-
query_runner = Neo4jQueryRunner(endpoint, auto_close=False, bookmarks=bookmarks, database=database)
55+
query_runner = Neo4jQueryRunner(
56+
endpoint, auto_close=False, bookmarks=bookmarks, database=database, show_progress=show_progress
57+
)
5558

5659
else:
5760
raise ValueError(f"Invalid endpoint type: {type(endpoint)}")
@@ -80,6 +83,7 @@ def __init__(
8083
database: Optional[str] = neo4j.DEFAULT_DATABASE,
8184
auto_close: bool = False,
8285
bookmarks: Optional[Any] = None,
86+
show_progress: bool = True,
8387
):
8488
self._driver = driver
8589
self._config = config
@@ -89,6 +93,7 @@ def __init__(
8993
self._bookmarks = bookmarks
9094
self._last_bookmarks: Optional[Any] = None
9195
self._server_version = None
96+
self._show_progress = show_progress
9297
self._progress_logger = QueryProgressLogger(
9398
self.__run_cypher_simplified_for_query_progress_logger, self.server_version
9499
)
@@ -175,12 +180,15 @@ def call_procedure(
175180
def run_cypher_query() -> DataFrame:
176181
return self.run_cypher(query, params, database, custom_error)
177182

178-
if logging:
183+
if self._resolve_show_progress(logging):
179184
job_id = self._progress_logger.extract_or_create_job_id(params)
180185
return self._progress_logger.run_with_progress_logging(run_cypher_query, job_id, database)
181186
else:
182187
return run_cypher_query()
183188

189+
def _resolve_show_progress(self, show_progress: bool) -> bool:
190+
return self._show_progress and show_progress
191+
184192
def server_version(self) -> ServerVersion:
185193
if self._server_version:
186194
return self._server_version
@@ -256,6 +264,9 @@ def create_graph_constructor(
256264
self, graph_name, concurrency, undirected_relationship_types, self.server_version()
257265
)
258266

267+
def set_show_progress(self, show_progress: bool) -> None:
268+
self._show_progress = show_progress
269+
259270
@staticmethod
260271
def handle_driver_exception(session: neo4j.Session, e: Exception) -> None:
261272
reg_gds_hit = re.search(

graphdatascience/query_runner/query_runner.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,5 +76,9 @@ def bookmarks(self) -> Optional[Any]:
7676
def last_bookmarks(self) -> Optional[Any]:
7777
pass
7878

79+
@abstractmethod
80+
def set_show_progress(self, show_progress: bool) -> None:
81+
pass
82+
7983
def set_server_version(self, _: ServerVersion) -> None:
8084
pass

graphdatascience/query_runner/session_query_runner.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,20 +24,22 @@ class SessionQueryRunner(QueryRunner):
2424

2525
@staticmethod
2626
def create(
27-
gds_query_runner: QueryRunner, db_query_runner: QueryRunner, arrow_client: GdsArrowClient
27+
gds_query_runner: QueryRunner, db_query_runner: QueryRunner, arrow_client: GdsArrowClient, show_progress: bool
2828
) -> SessionQueryRunner:
29-
return SessionQueryRunner(gds_query_runner, db_query_runner, arrow_client)
29+
return SessionQueryRunner(gds_query_runner, db_query_runner, arrow_client, show_progress)
3030

3131
def __init__(
3232
self,
3333
gds_query_runner: QueryRunner,
3434
db_query_runner: QueryRunner,
3535
arrow_client: GdsArrowClient,
36+
show_progress: bool,
3637
):
3738
self._gds_query_runner = gds_query_runner
3839
self._db_query_runner = db_query_runner
3940
self._gds_arrow_client = arrow_client
4041
self._resolved_protocol_version = ProtocolVersionResolver(db_query_runner).resolve()
42+
self._show_progress = show_progress
4143
self._progress_logger = QueryProgressLogger(
4244
lambda query, database: self._gds_query_runner.run_cypher(query=query, database=database),
4345
self._gds_query_runner.server_version,
@@ -112,6 +114,10 @@ def create_graph_constructor(
112114
) -> GraphConstructor:
113115
return self._gds_query_runner.create_graph_constructor(graph_name, concurrency, undirected_relationship_types)
114116

117+
def set_show_progress(self, show_progress: bool) -> None:
118+
self._show_progress = show_progress
119+
self._gds_query_runner.set_show_progress(show_progress)
120+
115121
def close(self) -> None:
116122
self._gds_arrow_client.close()
117123
self._gds_query_runner.close()
@@ -184,7 +190,7 @@ def _remote_write_back(
184190
def run_write_back():
185191
return write_protocol.run_write_back(self._db_query_runner, write_back_params, yields)
186192

187-
if logging:
193+
if self._resolve_show_progress(logging):
188194
database_write_result = self._progress_logger.run_with_progress_logging(run_write_back, job_id, database)
189195
else:
190196
database_write_result = run_write_back()
@@ -203,6 +209,9 @@ def run_write_back():
203209

204210
return gds_write_result
205211

212+
def _resolve_show_progress(self, show_progress: bool) -> bool:
213+
return self._show_progress and show_progress
214+
206215
def _inject_arrow_config(self, params: Dict[str, Any]) -> None:
207216
host, port = self._gds_arrow_client.connection_info()
208217
token = self._gds_arrow_client.request_token()

graphdatascience/session/aura_graph_data_science.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def create(
3535
arrow_disable_server_verification: bool = False,
3636
arrow_tls_root_certs: Optional[bytes] = None,
3737
bookmarks: Optional[Any] = None,
38+
show_progress: bool = True,
3839
):
3940
# we need to explicitly set this as the default value is None
4041
# database in the session is always neo4j
@@ -43,6 +44,7 @@ def create(
4344
auth=gds_session_connection_info.auth(),
4445
aura_ds=True,
4546
database="neo4j",
47+
show_progress=show_progress,
4648
)
4749

4850
arrow_info = ArrowInfo.create(session_bolt_query_runner)
@@ -66,14 +68,12 @@ def create(
6668
)
6769

6870
db_bolt_query_runner = Neo4jQueryRunner.create(
69-
db_connection_info.uri,
70-
db_connection_info.auth(),
71-
aura_ds=True,
71+
db_connection_info.uri, db_connection_info.auth(), aura_ds=True, show_progress=False
7272
)
7373
db_bolt_query_runner.set_bookmarks(bookmarks)
7474

7575
session_query_runner = SessionQueryRunner.create(
76-
session_arrow_query_runner, db_bolt_query_runner, session_arrow_client
76+
session_arrow_query_runner, db_bolt_query_runner, session_arrow_client, show_progress
7777
)
7878

7979
gds_version = session_bolt_query_runner.server_version()
@@ -159,6 +159,17 @@ def set_bookmarks(self, bookmarks: Any) -> None:
159159
"""
160160
self._query_runner.set_bookmarks(bookmarks)
161161

162+
def set_show_progress(self, show_progress: bool) -> None:
163+
"""
164+
Set whether to show progress for running procedures.
165+
166+
Parameters
167+
----------
168+
show_progress : bool
169+
Whether to show progress for procedures.
170+
"""
171+
self._query_runner.set_show_progress(show_progress)
172+
162173
def database(self) -> Optional[str]:
163174
"""
164175
Get the database which cypher queries are run against.
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from neo4j import Driver
2+
from pandas import DataFrame
3+
4+
from graphdatascience import ServerVersion
5+
from graphdatascience.query_runner.neo4j_query_runner import Neo4jQueryRunner
6+
from graphdatascience.query_runner.session_query_runner import SessionQueryRunner
7+
from graphdatascience.tests.unit.conftest import CollectingQueryRunner
8+
from graphdatascience.tests.unit.test_session_query_runner import FakeArrowClient
9+
10+
11+
def test_disabled_progress_logging(neo4j_driver: Driver):
12+
query_runner = Neo4jQueryRunner.create(neo4j_driver, show_progress=False)
13+
assert query_runner._resolve_show_progress(True) is False
14+
assert query_runner._resolve_show_progress(False) is False
15+
16+
17+
def test_enabled_progress_logging(neo4j_driver: Driver):
18+
query_runner = Neo4jQueryRunner.create(neo4j_driver, show_progress=True)
19+
assert query_runner._resolve_show_progress(True) is True
20+
assert query_runner._resolve_show_progress(False) is False
21+
22+
23+
def test_disabled_progress_logging_session(neo4j_driver: Driver):
24+
version = ServerVersion(2, 7, 0)
25+
db_query_runner = CollectingQueryRunner(version, result_mock=DataFrame([{"version": "v1"}]))
26+
gds_query_runner = CollectingQueryRunner(version)
27+
query_runner = SessionQueryRunner.create(
28+
gds_query_runner,
29+
db_query_runner,
30+
FakeArrowClient(), # type: ignore
31+
show_progress=False,
32+
)
33+
assert query_runner._resolve_show_progress(True) is False
34+
assert query_runner._resolve_show_progress(False) is False
35+
36+
37+
def test_enabled_progress_logging_session(neo4j_driver: Driver):
38+
version = ServerVersion(2, 7, 0)
39+
db_query_runner = CollectingQueryRunner(version, result_mock=DataFrame([{"version": "v1"}]))
40+
gds_query_runner = CollectingQueryRunner(version)
41+
query_runner = SessionQueryRunner.create(
42+
gds_query_runner,
43+
db_query_runner,
44+
FakeArrowClient(), # type: ignore
45+
show_progress=True,
46+
)
47+
assert query_runner._resolve_show_progress(True) is True
48+
assert query_runner._resolve_show_progress(False) is False

graphdatascience/tests/integration/test_remote_graph_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def test_remote_write_back_node_similarity(gds_with_cloud_setup: AuraGraphDataSc
7373
G, writeRelationshipType="SIMILAR", writeProperty="score", similarityCutoff=0
7474
)
7575

76-
assert result["relationshipsWritten"] == 4
76+
assert result["relationshipsWritten"] == 2
7777

7878

7979
@pytest.mark.cloud_architecture

graphdatascience/tests/unit/conftest.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,9 @@ def bookmarks(self) -> Optional[Any]:
115115
def last_bookmarks(self) -> Optional[Any]:
116116
return None
117117

118+
def set_show_progress(self, show_progress: bool) -> None:
119+
pass
120+
118121
def create_graph_constructor(
119122
self, graph_name: str, concurrency: int, undirected_relationship_types: Optional[List[str]]
120123
) -> GraphConstructor:

graphdatascience/tests/unit/test_session_query_runner.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def test_extracts_parameters_projection_v1() -> None:
2121
db_query_runner = CollectingQueryRunner(version, result_mock=DataFrame([{"version": "v1"}]))
2222
gds_query_runner = CollectingQueryRunner(version)
2323
gds_query_runner.set__mock_result(DataFrame([{"databaseLocation": "remote"}]))
24-
qr = SessionQueryRunner.create(gds_query_runner, db_query_runner, FakeArrowClient()) # type: ignore
24+
qr = SessionQueryRunner.create(gds_query_runner, db_query_runner, FakeArrowClient(), True) # type: ignore
2525

2626
qr.call_procedure(
2727
endpoint="gds.arrow.project",
@@ -68,6 +68,7 @@ def test_extracts_parameters_projection_v2() -> None:
6868
gds_query_runner,
6969
db_query_runner,
7070
FakeArrowClient(), # type: ignore
71+
True,
7172
)
7273

7374
qr.call_procedure(
@@ -112,7 +113,7 @@ def test_extracts_parameters_algo_write_v1() -> None:
112113
db_query_runner = CollectingQueryRunner(version, result_mock=DataFrame([{"version": "v1"}]))
113114
gds_query_runner = CollectingQueryRunner(version)
114115
gds_query_runner.set__mock_result(DataFrame([{"databaseLocation": "remote"}]))
115-
qr = SessionQueryRunner.create(gds_query_runner, db_query_runner, FakeArrowClient()) # type: ignore
116+
qr = SessionQueryRunner.create(gds_query_runner, db_query_runner, FakeArrowClient(), True) # type: ignore
116117

117118
qr.call_procedure(endpoint="gds.degree.write", params=CallParameters(graph_name="g", config={"jobId": "my-job"}))
118119

@@ -141,6 +142,7 @@ def test_extracts_parameters_algo_write_v2() -> None:
141142
gds_query_runner,
142143
db_query_runner,
143144
FakeArrowClient(), # type: ignore
145+
True,
144146
)
145147

146148
qr.call_procedure(
@@ -169,7 +171,7 @@ def test_arrow_and_write_configuration() -> None:
169171
db_query_runner = CollectingQueryRunner(version, result_mock=DataFrame([{"version": "v1"}]))
170172
gds_query_runner = CollectingQueryRunner(version)
171173
gds_query_runner.set__mock_result(DataFrame([{"databaseLocation": "remote"}]))
172-
qr = SessionQueryRunner.create(gds_query_runner, db_query_runner, FakeArrowClient()) # type: ignore
174+
qr = SessionQueryRunner.create(gds_query_runner, db_query_runner, FakeArrowClient(), True) # type: ignore
173175

174176
qr.call_procedure(
175177
endpoint="gds.degree.write",
@@ -206,7 +208,7 @@ def test_arrow_and_write_configuration_graph_write() -> None:
206208
db_query_runner = CollectingQueryRunner(version, result_mock=DataFrame([{"version": "v1"}]))
207209
gds_query_runner = CollectingQueryRunner(version)
208210
gds_query_runner.set__mock_result(DataFrame([{"databaseLocation": "remote"}]))
209-
qr = SessionQueryRunner.create(gds_query_runner, db_query_runner, FakeArrowClient()) # type: ignore
211+
qr = SessionQueryRunner.create(gds_query_runner, db_query_runner, FakeArrowClient(), True) # type: ignore
210212

211213
qr.call_procedure(
212214
endpoint="gds.graph.nodeProperties.write",

0 commit comments

Comments
 (0)