diff --git a/pydatalab/src/pydatalab/routes/v0_1/graphs.py b/pydatalab/src/pydatalab/routes/v0_1/graphs.py index 89529d874..6fc338cbd 100644 --- a/pydatalab/src/pydatalab/routes/v0_1/graphs.py +++ b/pydatalab/src/pydatalab/routes/v0_1/graphs.py @@ -19,6 +19,9 @@ def get_graph_cy_format( hide_collections: bool = True, ): collection_id = request.args.get("collection_id", type=str) + hide_collections = request.args.get( + "hide_collections", default=False, type=lambda v: v.lower() == "true" + ) if item_id is None: if collection_id is not None: @@ -42,40 +45,61 @@ def get_graph_cy_format( else: query = {} all_documents = flask_mongo.db.items.find( - {**query, **get_default_permissions(user_only=False)}, + {**query, **get_default_permissions(user_only=True)}, projection={"item_id": 1, "name": 1, "type": 1, "relationships": 1}, ) node_ids: set[str] = {document["item_id"] for document in all_documents} all_documents.rewind() else: - all_documents = list( + main_item = flask_mongo.db.items.find_one( + { + "item_id": item_id, + **get_default_permissions(user_only=True), + }, + projection={"item_id": 1, "name": 1, "type": 1, "relationships": 1}, + ) + + if not main_item: + return ( + jsonify(status="error", message=f"Item {item_id} not found or no permission"), + 404, + ) + + all_documents = [main_item] + node_ids = {item_id} + + for relationship in main_item.get("relationships", []) or []: + if relationship.get("item_id"): + node_ids.add(relationship["item_id"]) + + incoming_items = list( flask_mongo.db.items.find( { - "$or": [{"item_id": item_id}, {"relationships.item_id": item_id}], + "relationships.item_id": item_id, **get_default_permissions(user_only=False), }, projection={"item_id": 1, "name": 1, "type": 1, "relationships": 1}, ) ) - node_ids = {document["item_id"] for document in all_documents} | { - relationship.get("item_id") - for document in all_documents - for relationship in document.get("relationships", []) - } - if len(node_ids) > 1: - or_query = [{"item_id": id} for id in node_ids if id != item_id] - next_shell = flask_mongo.db.items.find( - { - "$or": or_query, - **get_default_permissions(user_only=False), - }, - projection={"item_id": 1, "name": 1, "type": 1, "relationships": 1}, - ) + for doc in incoming_items: + node_ids.add(doc["item_id"]) - all_documents.extend(next_shell) - node_ids = node_ids | {document["item_id"] for document in all_documents} + all_documents.extend(incoming_items) + + ids_to_fetch = node_ids - {doc["item_id"] for doc in all_documents} + if ids_to_fetch: + referenced_items = list( + flask_mongo.db.items.find( + { + "item_id": {"$in": list(ids_to_fetch)}, + **get_default_permissions(user_only=False), + }, + projection={"item_id": 1, "name": 1, "type": 1, "relationships": 1}, + ) + ) + all_documents.extend(referenced_items) nodes = [] edges = [] @@ -166,12 +190,18 @@ def get_graph_cy_format( } ) - whitelist = {edge["data"]["source"] for edge in edges} | {item_id} + whitelist = {edge["data"]["source"] for edge in edges} | { + edge["data"]["target"] for edge in edges + } + if item_id: + whitelist.add(item_id) nodes = [ node for node in nodes - if node["data"]["type"] in ("samples", "cells") or node["data"]["id"] in whitelist + if node["data"]["type"] in ("samples", "cells") + or node["data"]["id"] in whitelist + or node["data"]["id"].startswith("Collection:") ] return (jsonify(status="success", nodes=nodes, edges=edges), 200) diff --git a/pydatalab/tests/server/test_graph.py b/pydatalab/tests/server/test_graph.py index e916af7b2..b722443e4 100644 --- a/pydatalab/tests/server/test_graph.py +++ b/pydatalab/tests/server/test_graph.py @@ -129,8 +129,8 @@ def test_simple_graph(admin_client): assert len(graph["edges"]) == 2 graph = admin_client.get("/item-graph/parent").json - assert len(graph["nodes"]) == 6 - assert len(graph["edges"]) == 5 + assert len(graph["nodes"]) == 7 + assert len(graph["edges"]) == 8 samples = sample_list.json["responses"] @@ -156,5 +156,5 @@ def test_simple_graph(admin_client): assert len(graph["edges"]) == 4 graph = admin_client.get("/item-graph/parent").json - assert len(graph["nodes"]) == 6 - assert len(graph["edges"]) == 5 + assert len(graph["nodes"]) == 7 + assert len(graph["edges"]) == 10 diff --git a/pydatalab/tests/server/test_item_graph.py b/pydatalab/tests/server/test_item_graph.py index 4f00e40a7..70a8665c8 100644 --- a/pydatalab/tests/server/test_item_graph.py +++ b/pydatalab/tests/server/test_item_graph.py @@ -8,12 +8,12 @@ from pydatalab.models import Sample, StartingMaterial -def test_single_starting_material(admin_client): +def test_single_starting_material(admin_client, client): item_id = "material" material = StartingMaterial(item_id=item_id) - creation = admin_client.post( + creation = client.post( "/new-sample/", json={"new_sample_data": json.loads(material.json())}, ) @@ -21,11 +21,11 @@ def test_single_starting_material(admin_client): assert creation.status_code == 201 # A single material without connections should be ignored - graph = admin_client.get("/item-graph").json + graph = client.get("/item-graph").json assert len(graph["nodes"]) == 0 # Unless it is asked for directly - graph = admin_client.get(f"/item-graph/{item_id}").json + graph = client.get(f"/item-graph/{item_id}").json assert len(graph["nodes"]) == 1 # Now make a sample and connect it to the starting material; check that the @@ -37,13 +37,90 @@ def test_single_starting_material(admin_client): ], ) - creation = admin_client.post( + creation = client.post( "/new-sample/", json={"new_sample_data": json.loads(parent.json())}, ) assert creation.status_code == 201 - graph = admin_client.get("/item-graph").json + graph = client.get("/item-graph").json + assert len(graph["nodes"]) == 2 + assert len(graph["edges"]) == 1 + + # From both the starting material and the sample + graph = client.get(f"/item-graph/{item_id}").json + assert len(graph["nodes"]) == 2 + assert len(graph["edges"]) == 1 + + # From both the starting material and the sample + graph = client.get("/item-graph/parent").json + assert len(graph["nodes"]) == 2 + assert len(graph["edges"]) == 1 + + # Now add a few more samples in a chain and check that only the relevant ones are shown + child = Sample( + item_id="child", + synthesis_constituents=[ + {"item": {"item_id": "parent", "type": "samples"}, "quantity": None} + ], + ) + + creation = client.post( + "/new-sample/", + json={"new_sample_data": json.loads(child.json())}, + ) + + grandchild = Sample( + item_id="grandchild", + synthesis_constituents=[ + {"item": {"item_id": "child", "type": "samples"}, "quantity": None} + ], + ) + + creation = client.post( + "/new-sample/", + json={"new_sample_data": json.loads(grandchild.json())}, + ) + + great_grandchild = Sample( + item_id="great-grandchild", + synthesis_constituents=[ + {"item": {"item_id": "grandchild", "type": "samples"}, "quantity": None} + ], + ) + + creation = client.post( + "/new-sample/", + json={"new_sample_data": json.loads(great_grandchild.json())}, + ) + + graph = client.get("/item-graph").json + assert len(graph["nodes"]) == 5 + assert len(graph["edges"]) == 4 + + # Check for bug where this behaviour was inconsistent between admin and non-admin users + graph = admin_client.get("/item-graph/great-grandchild").json + assert len(graph["nodes"]) == 2 + assert len(graph["edges"]) == 1 + + # Add an admin only item and check that the non-admin user still sees the same graph + admin_great_great_grandchild = Sample( + item_id="admin-great-great-grandchild", + synthesis_constituents=[ + {"item": {"item_id": "great-grandchild", "type": "samples"}, "quantity": None} + ], + ) + + creation = admin_client.post( + "/new-sample/", json={"new_sample_data": json.loads(admin_great_great_grandchild.json())} + ) + + graph = admin_client.get("/item-graph/great-grandchild").json + assert len(graph["nodes"]) == 3 + assert len(graph["edges"]) == 2 + + # Current broken behaviour: non-admin users see the full graph, but not the things they don't have permission for + graph = client.get("/item-graph/great-grandchild").json assert len(graph["nodes"]) == 2 assert len(graph["edges"]) == 1