Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 46 additions & 30 deletions caveclient/skeletonservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -959,6 +959,7 @@ def generate_bulk_skeletons_async(

return estimated_async_time_secs_upper_bound_sum

@_check_version_compatibility(method_constraint=">=0.22.51")
def get_cached_skeletons_bulk(
self,
root_ids: List,
Expand Down Expand Up @@ -1003,6 +1004,12 @@ def get_cached_skeletons_bulk(
assert datastack_name is not None
assert skeleton_version is not None

skeleton_versions = self.get_versions()
if skeleton_version not in skeleton_versions:
raise ValueError(
f"Unknown skeleton version: {skeleton_version}. Valid options: {skeleton_versions}"
)

if len(root_ids) > MAX_BULK_CACHED_SKELETONS:
logging.warning(
f"The number of root_ids exceeds MAX_BULK_CACHED_SKELETONS ({MAX_BULK_CACHED_SKELETONS}). "
Expand All @@ -1022,7 +1029,6 @@ def get_cached_skeletons_bulk(
endpoint_mapping["skeleton_version"] = skeleton_version
endpoint_mapping["output_format"] = server_format
url = self._endpoints["get_cached_skeletons_bulk_as_post"].format_map(endpoint_mapping)
url += f"?verbose_level={verbose_level}"

data = {
"root_ids": root_ids,
Expand Down Expand Up @@ -1055,6 +1061,7 @@ def get_cached_skeletons_bulk(

return {"skeletons": skeletons, "missing": missing, "async_queued": async_queued}

@_check_version_compatibility(method_constraint=">=0.22.51")
def get_skeleton_access_token(
self,
root_ids: List,
Expand Down Expand Up @@ -1109,15 +1116,21 @@ def get_skeleton_access_token(
assert datastack_name is not None
assert skeleton_version is not None

skeleton_versions = self.get_versions()
if skeleton_version not in skeleton_versions:
raise ValueError(
f"Unknown skeleton version: {skeleton_version}. Valid options: {skeleton_versions}"
)

endpoint_mapping = self.default_url_mapping
endpoint_mapping["datastack_name"] = datastack_name
endpoint_mapping["skeleton_version"] = skeleton_version
url = self._endpoints["get_skeleton_token_as_post"].format_map(endpoint_mapping)
url += f"?verbose_level={verbose_level}"

data = {
"root_ids": root_ids,
"expiration_minutes": expiration_minutes,
"verbose_level": verbose_level,
}
response = self.session.post(url, json=data)
return handle_response(response)
Expand Down Expand Up @@ -1156,33 +1169,36 @@ def download_skeletons_with_token(
skeletons = {}

for rid, obj_path in object_paths.items():
encoded_path = urllib.parse.quote(obj_path, safe="")
url = f"https://storage.googleapis.com/download/storage/v1/b/{bucket}/o/{encoded_path}?alt=media"

# Use a direct get with the GCS token (overrides any session auth headers)
resp = self.session.get(url, headers=gcs_headers)
resp.raise_for_status()

# Stored files are gzip-compressed H5
h5_bytes = gzip.decompress(resp.content)

with h5py.File(io.BytesIO(h5_bytes), "r") as f:
sk = {
"vertices": np.array(f["vertices"][()]),
"edges": np.array(f["edges"][()]),
}
if "mesh_to_skel_map" in f:
sk["mesh_to_skel_map"] = np.array(f["mesh_to_skel_map"][()])
if "lvl2_ids" in f:
sk["lvl2_ids"] = np.array(f["lvl2_ids"][()])
if "vertex_properties" in f:
for vp_key in f["vertex_properties"].keys():
sk[vp_key] = np.array(
json.loads(f["vertex_properties"][vp_key][()])
)
if "meta" in f:
sk["meta"] = json.loads(f["meta"][()].tobytes())

skeletons[rid] = sk
try:
encoded_path = urllib.parse.quote(obj_path, safe="")
url = f"https://storage.googleapis.com/download/storage/v1/b/{bucket}/o/{encoded_path}?alt=media"

# Use a direct get with the GCS token (overrides any session auth headers)
resp = self.session.get(url, headers=gcs_headers)
resp.raise_for_status()

# Stored files are gzip-compressed H5
h5_bytes = gzip.decompress(resp.content)

with h5py.File(io.BytesIO(h5_bytes), "r") as f:
sk = {
"vertices": np.array(f["vertices"][()]),
"edges": np.array(f["edges"][()]),
}
if "mesh_to_skel_map" in f:
sk["mesh_to_skel_map"] = np.array(f["mesh_to_skel_map"][()])
if "lvl2_ids" in f:
sk["lvl2_ids"] = np.array(f["lvl2_ids"][()])
if "vertex_properties" in f:
for vp_key in f["vertex_properties"].keys():
sk[vp_key] = np.array(
json.loads(f["vertex_properties"][vp_key][()])
)
if "meta" in f:
sk["meta"] = json.loads(f["meta"][()].tobytes())

skeletons[rid] = sk
except Exception as e:
logging.error(f"Error downloading skeleton for root_id {rid}: {e}")

return skeletons
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ maintainers = [{ name = "CAVE Developers" }]
name = "caveclient"
readme = "README.md"
requires-python = ">=3.9"
version = "8.0.0"
version = "8.1.0"

[project.urls]
Documentation = "https://caveconnectome.github.io/CAVEclient/"
Expand Down Expand Up @@ -73,7 +73,7 @@ test = [
allow_dirty = false
commit = true
commit_args = ""
current_version = "8.0.0"
current_version = "8.1.0"
ignore_missing_version = false
message = "Bump version: {current_version} → {new_version}"
parse = "(?P<major>\\d+)\\.(?P<minor>\\d+)\\.(?P<patch>\\d+)"
Expand Down
Loading