Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 41 additions & 7 deletions pydatalab/src/pydatalab/routes/v0_1/graphs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from flask import Blueprint, jsonify, request

from pydatalab.logger import LOGGER
from pydatalab.mongo import flask_mongo
from pydatalab.permissions import active_users_or_get_only, get_default_permissions

Expand All @@ -19,6 +20,9 @@ def get_graph_cy_format(
hide_collections: bool = True,
):
collection_id = request.args.get("collection_id", type=str)
hide_collections = request.args.get(
"hide_collections", default=False, type=lambda v: v.lower() == "true"
)

if item_id is None:
if collection_id is not None:
Expand All @@ -42,28 +46,37 @@ def get_graph_cy_format(
else:
query = {}
all_documents = flask_mongo.db.items.find(
{**query, **get_default_permissions(user_only=False)},
{**query, **get_default_permissions(user_only=True)},
projection={"item_id": 1, "name": 1, "type": 1, "relationships": 1},
)
node_ids: set[str] = {document["item_id"] for document in all_documents}
all_documents.rewind()

else:
LOGGER.debug("!!!!!!!!!!")
all_documents = list(
flask_mongo.db.items.find(
{
"$or": [{"item_id": item_id}, {"relationships.item_id": item_id}],
**get_default_permissions(user_only=False),
"$and": [
{"$or": [{"item_id": item_id}, {"relationships.item_id": item_id}]},
{**get_default_permissions(user_only=False)},
],
},
projection={"item_id": 1, "name": 1, "type": 1, "relationships": 1},
)
)

LOGGER.debug("Found %d documents related to item %s", len(all_documents), item_id)

node_ids = {document["item_id"] for document in all_documents} | {
relationship.get("item_id")
for document in all_documents
for relationship in document.get("relationships", [])
}

LOGGER.debug("Found %d unique node IDs related to item %s", len(node_ids), item_id)
LOGGER.debug("Node IDs: %s", node_ids)

if len(node_ids) > 1:
or_query = [{"item_id": id} for id in node_ids if id != item_id]
next_shell = flask_mongo.db.items.find(
Expand All @@ -74,8 +87,25 @@ def get_graph_cy_format(
projection={"item_id": 1, "name": 1, "type": 1, "relationships": 1},
)

all_documents.extend(next_shell)
node_ids = node_ids | {document["item_id"] for document in all_documents}
all_documents.extend(incoming_items)

ids_to_fetch = node_ids - {doc["item_id"] for doc in all_documents}
if ids_to_fetch:
referenced_items = list(
flask_mongo.db.items.find(
{
"item_id": {"$in": list(ids_to_fetch)},
**get_default_permissions(user_only=False),
},
projection={"item_id": 1, "name": 1, "type": 1, "relationships": 1},
)
)
all_documents.extend(referenced_items)

LOGGER.debug(
"Found %d unique node IDs related to item %s after filtering", len(node_ids), item_id
)
LOGGER.debug("Node IDs: %s", node_ids)

nodes = []
edges = []
Expand Down Expand Up @@ -166,12 +196,16 @@ def get_graph_cy_format(
}
)

whitelist = {edge["data"]["source"] for edge in edges} | {item_id}
whitelist = {edge["data"]["source"] for edge in edges}
if item_id:
whitelist.add(item_id)

nodes = [
node
for node in nodes
if node["data"]["type"] in ("samples", "cells") or node["data"]["id"] in whitelist
if node["data"]["type"] in ("samples", "cells")
or node["data"]["id"] in whitelist
or node["data"]["id"].startswith("Collection:")
]

return (jsonify(status="success", nodes=nodes, edges=edges), 200)
8 changes: 4 additions & 4 deletions pydatalab/tests/server/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,8 @@ def test_simple_graph(admin_client):
assert len(graph["edges"]) == 2

graph = admin_client.get("/item-graph/parent").json
assert len(graph["nodes"]) == 6
assert len(graph["edges"]) == 5
assert len(graph["nodes"]) == 7
assert len(graph["edges"]) == 8

samples = sample_list.json["responses"]

Expand All @@ -156,5 +156,5 @@ def test_simple_graph(admin_client):
assert len(graph["edges"]) == 4

graph = admin_client.get("/item-graph/parent").json
assert len(graph["nodes"]) == 6
assert len(graph["edges"]) == 5
assert len(graph["nodes"]) == 7
assert len(graph["edges"]) == 10
89 changes: 83 additions & 6 deletions pydatalab/tests/server/test_item_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,24 @@
from pydatalab.models import Sample, StartingMaterial


def test_single_starting_material(admin_client):
def test_single_starting_material(admin_client, client):
item_id = "material"

material = StartingMaterial(item_id=item_id)

creation = admin_client.post(
creation = client.post(
"/new-sample/",
json={"new_sample_data": json.loads(material.json())},
)

assert creation.status_code == 201

# A single material without connections should be ignored
graph = admin_client.get("/item-graph").json
graph = client.get("/item-graph").json
assert len(graph["nodes"]) == 0

# Unless it is asked for directly
graph = admin_client.get(f"/item-graph/{item_id}").json
graph = client.get(f"/item-graph/{item_id}").json
assert len(graph["nodes"]) == 1

# Now make a sample and connect it to the starting material; check that the
Expand All @@ -37,13 +37,90 @@ def test_single_starting_material(admin_client):
],
)

creation = admin_client.post(
creation = client.post(
"/new-sample/",
json={"new_sample_data": json.loads(parent.json())},
)

assert creation.status_code == 201

graph = admin_client.get("/item-graph").json
graph = client.get("/item-graph").json
assert len(graph["nodes"]) == 2
assert len(graph["edges"]) == 1

# From both the starting material and the sample
graph = client.get(f"/item-graph/{item_id}").json
assert len(graph["nodes"]) == 2
assert len(graph["edges"]) == 1

# From both the starting material and the sample
graph = client.get("/item-graph/parent").json
assert len(graph["nodes"]) == 2
assert len(graph["edges"]) == 1

# Now add a few more samples in a chain and check that only the relevant ones are shown
child = Sample(
item_id="child",
synthesis_constituents=[
{"item": {"item_id": "parent", "type": "samples"}, "quantity": None}
],
)

creation = client.post(
"/new-sample/",
json={"new_sample_data": json.loads(child.json())},
)

grandchild = Sample(
item_id="grandchild",
synthesis_constituents=[
{"item": {"item_id": "child", "type": "samples"}, "quantity": None}
],
)

creation = client.post(
"/new-sample/",
json={"new_sample_data": json.loads(grandchild.json())},
)

great_grandchild = Sample(
item_id="great-grandchild",
synthesis_constituents=[
{"item": {"item_id": "grandchild", "type": "samples"}, "quantity": None}
],
)

creation = client.post(
"/new-sample/",
json={"new_sample_data": json.loads(great_grandchild.json())},
)

graph = client.get("/item-graph").json
assert len(graph["nodes"]) == 5
assert len(graph["edges"]) == 4

# Check for bug where this behaviour was inconsistent between admin and non-admin users
graph = admin_client.get("/item-graph/great-grandchild").json
assert len(graph["nodes"]) == 2
assert len(graph["edges"]) == 1

# Add an admin only item and check that the non-admin user still sees the same graph
admin_great_great_grandchild = Sample(
item_id="admin-great-great-grandchild",
synthesis_constituents=[
{"item": {"item_id": "great-grandchild", "type": "samples"}, "quantity": None}
],
)

creation = admin_client.post(
"/new-sample/", json={"new_sample_data": json.loads(admin_great_great_grandchild.json())}
)

graph = admin_client.get("/item-graph/great-grandchild").json
assert len(graph["nodes"]) == 3
assert len(graph["edges"]) == 2

# Current broken behaviour: non-admin users see the full graph, but not the things they don't have permission for
graph = client.get("/item-graph/great-grandchild").json
assert len(graph["nodes"]) == 2
assert len(graph["edges"]) == 1
Loading