Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions caveclient/annotationengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
import pandas as pd

from .auth import AuthClient
from .base import BaseEncoder, ClientBase, _api_endpoints, handle_response
from .base import BaseEncoder, ClientBase, _api_endpoints, _check_version_compatibility, handle_response
from .endpoints import annotation_api_versions, annotation_common
from .tools import stage
from datetime import datetime

SERVER_KEY = "ae_server_address"

Expand Down Expand Up @@ -96,7 +97,12 @@ def __init__(
def aligned_volume_name(self) -> str:
    """Name of the aligned volume this client is configured to use.

    NOTE(review): read-only accessor for ``_aligned_volume_name``; presumably
    decorated as a ``@property`` (decorator not visible in this view — confirm).
    """
    return self._aligned_volume_name

def get_tables(self, aligned_volume_name: Optional[str] = None) -> list[str]:
@_check_version_compatibility(
kwarg_use_constraints={
"timestamp": ">=4.33.0",
}
)
def get_tables(self, aligned_volume_name: Optional[str] = None, timestamp: Optional[datetime] = None) -> list[str]:
"""Gets a list of table names for a aligned_volume_name

Parameters
Expand All @@ -105,6 +111,9 @@ def get_tables(self, aligned_volume_name: Optional[str] = None) -> list[str]:
Name of the aligned_volume, by default None.
If None, uses the one specified in the client.
Will be set correctly if you are using the framework_client
timestamp: datetime.datetime or None, optional
If set, gets the tables as of that timestamp. By default None.
If None, gets the current tables as of now.

Returns
-------
Expand All @@ -116,7 +125,10 @@ def get_tables(self, aligned_volume_name: Optional[str] = None) -> list[str]:
endpoint_mapping = self.default_url_mapping
endpoint_mapping["aligned_volume_name"] = aligned_volume_name
url = self._endpoints["tables"].format_map(endpoint_mapping)
response = self.session.get(url)
query_d = {}
if timestamp is not None:
query_d["timestamp"] = datetime.strftime(timestamp, "%Y-%m-%dT%H:%M:%S.%f")
response = self.session.get(url, params = query_d)
return handle_response(response)
Comment on lines +128 to 132

def get_annotation_count(
Expand Down
4 changes: 4 additions & 0 deletions caveclient/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,5 +335,9 @@
+ "/{datastack_name}/bulk/gen_skeletons/{skeleton_version}/{root_ids}",
"gen_bulk_skeletons_via_skvn_rids_as_post": skeleton_v1
+ "/{datastack_name}/bulk/gen_skeletons",
"get_cached_skeletons_bulk_as_post": skeleton_v1
+ "/{datastack_name}/bulk/get_cached_skeletons/{skeleton_version}/{output_format}",
"get_skeleton_token_as_post": skeleton_v1
+ "/{datastack_name}/bulk/get_skeleton_token/{skeleton_version}",
}
skeletonservice_api_versions = {1: skeletonservice_endpoints_v1}
248 changes: 248 additions & 0 deletions caveclient/skeletonservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import io
import json
import logging
import urllib.parse
from datetime import datetime, timedelta, timezone
from io import BytesIO, StringIO
from timeit import default_timer
from typing import List, Literal, Optional, Union
Expand All @@ -27,6 +29,7 @@

MAX_SKELETONS_EXISTS_QUERY_SIZE = 1000
MAX_BULK_SYNCHRONOUS_SKELETONS = 10
MAX_BULK_CACHED_SKELETONS = 500 # mirrors server-side MAX_BULK_CACHED_SKELETONS
MAX_BULK_ASYNCHRONOUS_SKELETONS = 10000
BULK_SKELETONS_BATCH_SIZE = 1000

Expand Down Expand Up @@ -84,6 +87,7 @@ def __init__(
)

self._datastack_name = datastack_name
self._gcs_token_cache: dict = {}

def _test_l2cache_exception(self):
raise NoL2CacheException(
Expand Down Expand Up @@ -956,3 +960,247 @@ def generate_bulk_skeletons_async(
)

return estimated_async_time_secs_upper_bound_sum

@_check_version_compatibility(method_constraint=">=0.22.51")
def fetch_skeletons(
    self,
    root_ids: List,
    datastack_name: Optional[str] = None,
    skeleton_version: Optional[int] = 4,
    output_format: Literal["dict", "swc"] = "dict",
    method: Literal["server", "gcs"] = "server",
    generate_missing_skeletons: bool = False,
    verbose_level: Optional[int] = 0,
) -> dict:
    """Retrieve already-cached skeletons in bulk, up to 500 at a time.

    Unlike :meth:`get_bulk_skeletons`, this method:

    - Accepts up to 500 root IDs per call (vs the 10-skeleton limit of ``get_bulk_skeletons``)
    - Skips per-RID chunkedgraph validation, so it never blocks on network calls
    - Never generates skeletons inline; only returns what is already in the cache

    Root IDs not found in the cache are simply absent from the returned dict.

    Parameters
    ----------
    root_ids : List
        Root IDs to retrieve. Truncated to 500 if longer.
    datastack_name : str, optional
        Datastack name. Defaults to the client's configured datastack.
    skeleton_version : int, optional
        Skeleton version. Default is 4 (latest).
    output_format : "dict" or "swc"
        Output format. Default is ``"dict"``. ``method="gcs"`` only supports ``"dict"``.
    method : "server" or "gcs"
        How to retrieve skeletons.

        ``"server"`` (default) — POST root IDs to the server; server decodes and returns skeletons.

        ``"gcs"`` — Obtain a short-lived downscoped GCS token (cached client-side), then
        download and parse H5 files directly from the storage bucket, bypassing the service
        for data transfer. Significantly faster for large batches. Only supports
        ``output_format="dict"``.
    generate_missing_skeletons : bool
        If ``True``, root IDs not found in the cache are queued for async background
        generation. They will still be absent from the returned dict. Default ``False``.
    verbose_level : int, optional
        Verbosity level for server-side logging.

    Returns
    -------
    dict
        Mapping of root_id (str) → skeleton object. Only root IDs successfully retrieved
        appear in the dict; missing ones are absent.

    Raises
    ------
    ValueError
        If ``method`` is unrecognized, no datastack name is available,
        ``skeleton_version`` is None, or the skeleton version is unknown
        to the server.
    """
    # Validate the dispatch method first so a typo fails fast, before any
    # network round-trip (get_versions() below hits the server).
    if method not in ("server", "gcs"):
        raise ValueError(f"method must be 'server' or 'gcs', got '{method}'")

    if datastack_name is None:
        datastack_name = self._datastack_name
    # Explicit raises instead of asserts: asserts are stripped under
    # `python -O`, which would let invalid input through silently.
    if datastack_name is None:
        raise ValueError(
            "No datastack_name provided and the client has no default datastack."
        )
    if skeleton_version is None:
        raise ValueError("skeleton_version must not be None.")

    skeleton_versions = self.get_versions()
    if skeleton_version not in skeleton_versions:
        raise ValueError(
            f"Unknown skeleton version: {skeleton_version}. Valid options: {skeleton_versions}"
        )

    # Dispatch to the chosen transport; both helpers share one signature.
    fetch_impl = (
        self._fetch_skeletons_via_server
        if method == "server"
        else self._fetch_skeletons_via_gcs
    )
    return fetch_impl(
        root_ids,
        datastack_name,
        skeleton_version,
        output_format,
        generate_missing_skeletons,
        verbose_level,
    )

def _fetch_skeletons_via_server(
self,
root_ids: List,
datastack_name: str,
skeleton_version: int,
output_format: str,
generate_missing_skeletons: bool,
verbose_level: int,
) -> dict:
if len(root_ids) > MAX_BULK_CACHED_SKELETONS:
logging.warning(
f"The number of root_ids exceeds MAX_BULK_CACHED_SKELETONS ({MAX_BULK_CACHED_SKELETONS}). "
f"Only the first {MAX_BULK_CACHED_SKELETONS} will be requested."
)
root_ids = root_ids[:MAX_BULK_CACHED_SKELETONS]

if output_format == "dict":
server_format = "flatdict"
elif output_format == "swc":
server_format = "swccompressed"
else:
raise ValueError(f"output_format must be 'dict' or 'swc', got '{output_format}'")

endpoint_mapping = self.default_url_mapping
endpoint_mapping["datastack_name"] = datastack_name
endpoint_mapping["skeleton_version"] = skeleton_version
endpoint_mapping["output_format"] = server_format
url = self._endpoints["get_cached_skeletons_bulk_as_post"].format_map(endpoint_mapping)

data = {
"root_ids": root_ids,
"generate_missing": generate_missing_skeletons,
"verbose_level": verbose_level,
}
response = self.session.post(url, json=data)
raw = handle_response(response)

skeletons = {}
for rid, encoded in raw.items():
try:
if output_format == "dict":
skeletons[rid] = SkeletonClient.decompressBytesToDict(
io.BytesIO(binascii.unhexlify(encoded)).getvalue()
)
elif output_format == "swc":
sk_csv = io.BytesIO(binascii.unhexlify(encoded)).getvalue().decode()
skeletons[rid] = pd.read_csv(
StringIO(sk_csv),
sep=" ",
names=["id", "type", "x", "y", "z", "radius", "parent"],
)
except Exception as e:
logging.error(f"Error decoding skeleton for root_id {rid}: {e}")

return skeletons

def _fetch_skeletons_via_gcs(
self,
root_ids: List,
datastack_name: str,
skeleton_version: int,
output_format: str,
generate_missing_skeletons: bool,
verbose_level: int,
) -> dict:
if output_format != "dict":
raise ValueError(
"method='gcs' only supports output_format='dict'. "
"Use method='server' for SWC output."
)

token_resp = self._get_or_refresh_gcs_token(datastack_name, skeleton_version, verbose_level)
token = token_resp["token"]
bucket = token_resp["bucket"]
path_template = token_resp["path_template"]

gcs_headers = {"Authorization": f"Bearer {token}"}
skeletons = {}
missing = []

for rid in root_ids:
obj_path = path_template.format(rid=rid)
encoded_path = urllib.parse.quote(obj_path, safe="")
url = (
f"https://storage.googleapis.com/download/storage/v1/b/"
f"{bucket}/o/{encoded_path}?alt=media"
)
try:
resp = self.session.get(url, headers=gcs_headers)
if resp.status_code == 404:
missing.append(rid)
continue
resp.raise_for_status()
skeletons[str(rid)] = SkeletonClient._parse_h5gz_to_dict(resp.content)
except Exception as e:
logging.error(f"Error downloading skeleton for root_id {rid}: {e}")

if generate_missing_skeletons and missing:
try:
self.generate_bulk_skeletons_async(
missing,
datastack_name=datastack_name,
skeleton_version=skeleton_version,
)
except Exception as e:
logging.warning(f"Failed to queue missing skeletons for async generation: {e}")

return skeletons

def _get_or_refresh_gcs_token(
self,
datastack_name: str,
skeleton_version: int,
verbose_level: int = 0,
) -> dict:
cache_key = (datastack_name, skeleton_version)
cached = self._gcs_token_cache.get(cache_key)
if cached is not None:
expiry_str = cached.get("expiry")
if expiry_str:
expiry = datetime.fromisoformat(expiry_str)
if expiry.tzinfo is None:
expiry = expiry.replace(tzinfo=timezone.utc)
if datetime.now(tz=timezone.utc) + timedelta(minutes=5) < expiry:
return cached

endpoint_mapping = self.default_url_mapping
endpoint_mapping["datastack_name"] = datastack_name
endpoint_mapping["skeleton_version"] = skeleton_version
url = self._endpoints["get_skeleton_token_as_post"].format_map(endpoint_mapping)

data = {"verbose_level": verbose_level}
response = self.session.post(url, json=data)
token_resp = handle_response(response)
self._gcs_token_cache[cache_key] = token_resp
return token_resp

@staticmethod
def _parse_h5gz_to_dict(h5gz_bytes: bytes) -> dict:
    """Decode a gzip-compressed HDF5 skeleton blob into a plain dict.

    Always includes ``vertices`` and ``edges`` (as numpy arrays); copies the
    optional ``mesh_to_skel_map`` and ``lvl2_ids`` datasets when present,
    expands each JSON-encoded entry under ``vertex_properties`` into its own
    key, and parses ``meta`` from JSON.

    NOTE(review): assumes the H5 layout written by the SkeletonService cache
    (vertex_properties values and meta are JSON strings) — confirm against
    the server-side writer.
    """
    # Local import keeps h5py an optional cost for callers who never use
    # the GCS path.
    import h5py

    with h5py.File(io.BytesIO(gzip.decompress(h5gz_bytes)), "r") as h5:
        skeleton = {
            "vertices": np.array(h5["vertices"][()]),
            "edges": np.array(h5["edges"][()]),
        }
        for optional_key in ("mesh_to_skel_map", "lvl2_ids"):
            if optional_key in h5:
                skeleton[optional_key] = np.array(h5[optional_key][()])
        if "vertex_properties" in h5:
            vp_group = h5["vertex_properties"]
            for prop_name in vp_group.keys():
                skeleton[prop_name] = np.array(json.loads(vp_group[prop_name][()]))
        if "meta" in h5:
            skeleton["meta"] = json.loads(h5["meta"][()].tobytes())
    return skeleton
2 changes: 1 addition & 1 deletion docs/api/skeleton.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ title: client.skeleton
options:
heading_level: 2
show_bases: false
members: ['server_version', 'get_skeleton', 'get_cache_contents', 'skeletons_exist', 'get_bulk_skeletons', 'generate_bulk_skeletons_async']
members: ['server_version', 'get_skeleton', 'get_cache_contents', 'skeletons_exist', 'get_bulk_skeletons', 'generate_bulk_skeletons_async', 'fetch_skeletons']
Comment on lines 8 to +9
7 changes: 7 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
---
title: Changelog
---
## 8.1.0 (March 13, 2026)
- Added `fetch_skeletons()`: unified bulk retrieval of already-cached skeletons with a 500 root ID limit (vs. the 10-skeleton limit of `get_bulk_skeletons()`). Returns a plain `{root_id: skeleton}` dict; missing root IDs are simply absent. Requires server-side SkeletonService >= v0.22.51.
- `method="server"` (default): root IDs are POSTed to the server, which decodes and returns skeletons. Supports both `"dict"` and `"swc"` output formats.
- `method="gcs"`: the client obtains a short-lived downscoped GCS access token (cached client-side, auto-refreshed) and downloads skeleton H5 files directly from the storage bucket, bypassing the service for data transfer. Supports `"dict"` output only. Significantly faster for large batches.
- `generate_missing_skeletons=True`: in either mode, root IDs absent from the cache are queued for asynchronous background generation.
- Added `h5py` as a required dependency (needed for H5 parsing in `method="gcs"`).

## 8.0.0 (November 2, 2025)
- Improved mangling of types from sql queries. Previously, the server-side method to read data from Postgres into pandas was via csv streaming, which caused pandas to infer types. There were cases where this inference was wrong or incomplete. For example, if you had a string column but all the entries for your query happened to be numbers (e.g. ["1", "2"]), the result would return those as numbers, not strings; and if your query changed so there was a mix of numbers and strings, those same rows which were numbers would then be returned as strings (e.g. ["1", "apple"]). Also, boolean columns were being returned as strings "t" or "f".

Expand Down
Loading
Loading