diff --git a/caveclient/skeletonservice.py b/caveclient/skeletonservice.py
index d9fb7844..371b7f52 100644
--- a/caveclient/skeletonservice.py
+++ b/caveclient/skeletonservice.py
@@ -959,6 +959,7 @@ def generate_bulk_skeletons_async(
         return estimated_async_time_secs_upper_bound_sum
 
+    @_check_version_compatibility(method_constraint=">=0.22.51")
     def get_cached_skeletons_bulk(
         self,
         root_ids: List,
@@ -1003,6 +1004,12 @@ def get_cached_skeletons_bulk(
         assert datastack_name is not None
         assert skeleton_version is not None
 
+        skeleton_versions = self.get_versions()
+        if skeleton_version not in skeleton_versions:
+            raise ValueError(
+                f"Unknown skeleton version: {skeleton_version}. Valid options: {skeleton_versions}"
+            )
+
         if len(root_ids) > MAX_BULK_CACHED_SKELETONS:
             logging.warning(
                 f"The number of root_ids exceeds MAX_BULK_CACHED_SKELETONS ({MAX_BULK_CACHED_SKELETONS}). "
@@ -1022,7 +1029,6 @@ def get_cached_skeletons_bulk(
         endpoint_mapping["skeleton_version"] = skeleton_version
         endpoint_mapping["output_format"] = server_format
         url = self._endpoints["get_cached_skeletons_bulk_as_post"].format_map(endpoint_mapping)
-        url += f"?verbose_level={verbose_level}"
 
         data = {
             "root_ids": root_ids,
@@ -1055,6 +1061,7 @@ def get_cached_skeletons_bulk(
         return {"skeletons": skeletons, "missing": missing, "async_queued": async_queued}
 
+    @_check_version_compatibility(method_constraint=">=0.22.51")
     def get_skeleton_access_token(
         self,
         root_ids: List,
@@ -1109,15 +1116,21 @@ def get_skeleton_access_token(
         assert datastack_name is not None
         assert skeleton_version is not None
 
+        skeleton_versions = self.get_versions()
+        if skeleton_version not in skeleton_versions:
+            raise ValueError(
+                f"Unknown skeleton version: {skeleton_version}. Valid options: {skeleton_versions}"
+            )
+
         endpoint_mapping = self.default_url_mapping
         endpoint_mapping["datastack_name"] = datastack_name
         endpoint_mapping["skeleton_version"] = skeleton_version
         url = self._endpoints["get_skeleton_token_as_post"].format_map(endpoint_mapping)
-        url += f"?verbose_level={verbose_level}"
 
         data = {
             "root_ids": root_ids,
             "expiration_minutes": expiration_minutes,
+            "verbose_level": verbose_level,
         }
         response = self.session.post(url, json=data)
         return handle_response(response)
@@ -1156,33 +1169,36 @@ def download_skeletons_with_token(
         skeletons = {}
         for rid, obj_path in object_paths.items():
-            encoded_path = urllib.parse.quote(obj_path, safe="")
-            url = f"https://storage.googleapis.com/download/storage/v1/b/{bucket}/o/{encoded_path}?alt=media"
-
-            # Use a direct get with the GCS token (overrides any session auth headers)
-            resp = self.session.get(url, headers=gcs_headers)
-            resp.raise_for_status()
-
-            # Stored files are gzip-compressed H5
-            h5_bytes = gzip.decompress(resp.content)
-
-            with h5py.File(io.BytesIO(h5_bytes), "r") as f:
-                sk = {
-                    "vertices": np.array(f["vertices"][()]),
-                    "edges": np.array(f["edges"][()]),
-                }
-                if "mesh_to_skel_map" in f:
-                    sk["mesh_to_skel_map"] = np.array(f["mesh_to_skel_map"][()])
-                if "lvl2_ids" in f:
-                    sk["lvl2_ids"] = np.array(f["lvl2_ids"][()])
-                if "vertex_properties" in f:
-                    for vp_key in f["vertex_properties"].keys():
-                        sk[vp_key] = np.array(
-                            json.loads(f["vertex_properties"][vp_key][()])
-                        )
-                if "meta" in f:
-                    sk["meta"] = json.loads(f["meta"][()].tobytes())
-
-            skeletons[rid] = sk
+            try:
+                encoded_path = urllib.parse.quote(obj_path, safe="")
+                url = f"https://storage.googleapis.com/download/storage/v1/b/{bucket}/o/{encoded_path}?alt=media"
+
+                # Use a direct get with the GCS token (overrides any session auth headers)
+                resp = self.session.get(url, headers=gcs_headers)
+                resp.raise_for_status()
+
+                # Stored files are gzip-compressed H5
+                h5_bytes = gzip.decompress(resp.content)
+
+                with h5py.File(io.BytesIO(h5_bytes), "r") as f:
+                    sk = {
+                        "vertices": np.array(f["vertices"][()]),
+                        "edges": np.array(f["edges"][()]),
+                    }
+                    if "mesh_to_skel_map" in f:
+                        sk["mesh_to_skel_map"] = np.array(f["mesh_to_skel_map"][()])
+                    if "lvl2_ids" in f:
+                        sk["lvl2_ids"] = np.array(f["lvl2_ids"][()])
+                    if "vertex_properties" in f:
+                        for vp_key in f["vertex_properties"].keys():
+                            sk[vp_key] = np.array(
+                                json.loads(f["vertex_properties"][vp_key][()])
+                            )
+                    if "meta" in f:
+                        sk["meta"] = json.loads(f["meta"][()].tobytes())
+
+                skeletons[rid] = sk
+            except Exception as e:
+                logging.error(f"Error downloading skeleton for root_id {rid}: {e}")
 
         return skeletons
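
For context, the intended two-step flow (request a token, then pull the gzipped H5 files straight from GCS) looks roughly like the sketch below. This is illustrative only, not part of the patch; client is assumed to be an already-configured CAVEclient instance, and the root ID is a placeholder:

    # Request a short-lived GCS token plus the object path for each cached skeleton.
    token_response = client.skeleton.get_skeleton_access_token(
        [864691135014128278],  # placeholder root IDs
        expiration_minutes=60,
    )
    # token_response carries "token", "bucket", "object_paths", and "missing",
    # matching the keys exercised by the tests below.

    # Download and decode the skeletons directly from GCS with that token.
    skeletons = client.skeleton.download_skeletons_with_token(token_response)
    for rid, sk in skeletons.items():
        print(rid, sk["vertices"].shape, sk["edges"].shape)
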
diff --git a/pyproject.toml b/pyproject.toml
index 349cf944..36f2b6c6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -21,7 +21,7 @@ maintainers = [{ name = "CAVE Developers" }]
 name = "caveclient"
 readme = "README.md"
 requires-python = ">=3.9"
-version = "8.0.0"
+version = "8.1.0"
 
 [project.urls]
 Documentation = "https://caveconnectome.github.io/CAVEclient/"
@@ -73,7 +73,7 @@ test = [
 allow_dirty = false
 commit = true
 commit_args = ""
-current_version = "8.0.0"
+current_version = "8.1.0"
 ignore_missing_version = false
 message = "Bump version: {current_version} → {new_version}"
 parse = "(?P<major>\\d+)\\.(?P<minor>\\d+)\\.(?P<patch>\\d+)"
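
Similarly, the bulk cache endpoint guarded by the new version check is typically called as below (an illustrative sketch, not part of the patch; client is assumed to be a configured CAVEclient, and the root IDs are placeholders):

    # Fetch whatever is already cached; anything else is reported back as
    # "missing" or queued for asynchronous generation ("async_queued").
    result = client.skeleton.get_cached_skeletons_bulk(
        [864691135014128278, 864691135014128279],  # placeholder root IDs
        output_format="dict",
    )
    cached = result["skeletons"]  # maps root ID -> skeleton
    to_retry = result["missing"] + result["async_queued"]
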
diff --git a/tests/test_skeletons.py b/tests/test_skeletons.py
index eae78003..407c5512 100644
--- a/tests/test_skeletons.py
+++ b/tests/test_skeletons.py
@@ -591,3 +591,261 @@ def test_generate_bulk_skeletons_async__invalid_skeleton_version(
                     e.args[0]
                     == f"Unknown skeleton version: {skeleton_version}. Valid options: [-1, 0, 1, 2, 3, 4]"
                 )
+
+    @responses.activate
+    def test_get_cached_skeletons_bulk__dict(self, myclient, mocker):
+        sk = {
+            "meta": {"root_id": 0, "skeleton_version": 4},
+            "edges": [[1, 0]],
+            "mesh_to_skel_map": [0, 1],
+            "root": 0,
+            "vertices": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
+        }
+
+        json_content = {
+            "skeletons": {
+                "0": binascii.hexlify(
+                    SkeletonClient.compressDictToBytes(sk)
+                ).decode("ascii"),
+            },
+            "missing": [1],
+            "async_queued": [],
+        }
+
+        bulk_mapping = copy.deepcopy(sk_mapping)
+        bulk_mapping["output_format"] = "flatdict"
+        metadata_url = self.sk_endpoints.get(
+            "get_cached_skeletons_bulk_as_post"
+        ).format_map(bulk_mapping)
+        responses.add(responses.POST, url=metadata_url, json=json_content, status=200)
+
+        metadata_url = self.sk_endpoints.get("get_versions").format_map(sk_mapping)
+        responses.add(
+            responses.GET, url=metadata_url, json=[-1, 0, 1, 2, 3, 4], status=200
+        )
+
+        result = myclient.skeleton.get_cached_skeletons_bulk([0, 1])
+        assert "0" in result["skeletons"]
+        assert result["missing"] == [1]
+        assert result["async_queued"] == []
+
+    @responses.activate
+    def test_get_cached_skeletons_bulk__swc(self, myclient, mocker):
+        sk_df = pd.DataFrame(
+            [[0, 0, 0, 0, 0, 1, -1]],
+            columns=["id", "type", "x", "y", "z", "radius", "parent"],
+        )
+        sk_csv_str = sk_df.to_csv(index=False, header=False, sep=" ")
+        encoded = binascii.hexlify(sk_csv_str.encode()).decode("ascii")
+
+        json_content = {
+            "skeletons": {"0": encoded},
+            "missing": [],
+            "async_queued": [],
+        }
+
+        bulk_mapping = copy.deepcopy(sk_mapping)
+        bulk_mapping["output_format"] = "swccompressed"
+        metadata_url = self.sk_endpoints.get(
+            "get_cached_skeletons_bulk_as_post"
+        ).format_map(bulk_mapping)
+        responses.add(responses.POST, url=metadata_url, json=json_content, status=200)
+
+        metadata_url = self.sk_endpoints.get("get_versions").format_map(sk_mapping)
+        responses.add(
+            responses.GET, url=metadata_url, json=[-1, 0, 1, 2, 3, 4], status=200
+        )
+
+        result = myclient.skeleton.get_cached_skeletons_bulk(
+            [0], output_format="swc"
+        )
+        assert "0" in result["skeletons"]
+        assert isinstance(result["skeletons"]["0"], pd.DataFrame)
+
+    @responses.activate
+    def test_get_cached_skeletons_bulk__truncation(self, myclient, mocker):
+        json_content = {
+            "skeletons": {},
+            "missing": [],
+            "async_queued": [],
+        }
+
+        bulk_mapping = copy.deepcopy(sk_mapping)
+        bulk_mapping["output_format"] = "flatdict"
+        metadata_url = self.sk_endpoints.get(
+            "get_cached_skeletons_bulk_as_post"
+        ).format_map(bulk_mapping)
+        responses.add(responses.POST, url=metadata_url, json=json_content, status=200)
+
+        metadata_url = self.sk_endpoints.get("get_versions").format_map(sk_mapping)
+        responses.add(
+            responses.GET, url=metadata_url, json=[-1, 0, 1, 2, 3, 4], status=200
+        )
+
+        # Should not raise; the oversized list is truncated, with only a warning logged
+        result = myclient.skeleton.get_cached_skeletons_bulk(list(range(600)))
+        assert result == {"skeletons": {}, "missing": [], "async_queued": []}
+
+    @responses.activate
+    def test_get_cached_skeletons_bulk__invalid_output_format(self, myclient, mocker):
+        metadata_url = self.sk_endpoints.get("get_versions").format_map(sk_mapping)
+        responses.add(
+            responses.GET, url=metadata_url, json=[-1, 0, 1, 2, 3, 4], status=200
+        )
+
+        for output_format in ["", "asdf", "flatdict", "json"]:
+            try:
+                myclient.skeleton.get_cached_skeletons_bulk(
+                    [0], output_format=output_format
+                )
+                assert False
+            except ValueError as e:
+                assert "output_format must be 'dict' or 'swc'" in e.args[0]
+
+    @responses.activate
+    def test_get_cached_skeletons_bulk__invalid_skeleton_version(
+        self, myclient, mocker
+    ):
+        metadata_url = self.sk_endpoints.get("get_versions").format_map(sk_mapping)
+        responses.add(
+            responses.GET, url=metadata_url, json=[-1, 0, 1, 2, 3, 4], status=200
+        )
+
+        for skeleton_version in [-2, 999]:
+            try:
+                myclient.skeleton.get_cached_skeletons_bulk(
+                    [0], skeleton_version=skeleton_version
+                )
+                assert False
+            except ValueError as e:
+                assert (
+                    e.args[0]
+                    == f"Unknown skeleton version: {skeleton_version}. Valid options: [-1, 0, 1, 2, 3, 4]"
+                )
+
+    @responses.activate
+    def test_get_skeleton_access_token(self, myclient, mocker):
+        token_mapping = copy.deepcopy(sk_mapping)
+        metadata_url = self.sk_endpoints.get(
+            "get_skeleton_token_as_post"
+        ).format_map(token_mapping)
+
+        token_response = {
+            "token": "ya29.test_token",
+            "token_type": "Bearer",
+            "expiry": "2026-03-14T13:00:00Z",
+            "bucket": "test-bucket",
+            "object_paths": {"0": "skeletons/v4/0.h5"},
+            "missing": [1],
+        }
+        responses.add(
+            responses.POST, url=metadata_url, json=token_response, status=200
+        )
+
+        metadata_url = self.sk_endpoints.get("get_versions").format_map(sk_mapping)
+        responses.add(
+            responses.GET, url=metadata_url, json=[-1, 0, 1, 2, 3, 4], status=200
+        )
+
+        result = myclient.skeleton.get_skeleton_access_token([0, 1])
+        assert result["token"] == "ya29.test_token"
+        assert result["bucket"] == "test-bucket"
+        assert result["missing"] == [1]
+        assert "0" in result["object_paths"]
+
+    @responses.activate
+    def test_get_skeleton_access_token__invalid_skeleton_version(
+        self, myclient, mocker
+    ):
+        metadata_url = self.sk_endpoints.get("get_versions").format_map(sk_mapping)
+        responses.add(
+            responses.GET, url=metadata_url, json=[-1, 0, 1, 2, 3, 4], status=200
+        )
+
+        for skeleton_version in [-2, 999]:
+            try:
+                myclient.skeleton.get_skeleton_access_token(
+                    [0], skeleton_version=skeleton_version
+                )
+                assert False
+            except ValueError as e:
+                assert (
+                    e.args[0]
+                    == f"Unknown skeleton version: {skeleton_version}. Valid options: [-1, 0, 1, 2, 3, 4]"
+                )
+
+    @responses.activate
+    def test_download_skeletons_with_token(self, myclient, mocker):
+        import gzip
+        import io
+
+        import h5py
+
+        # Create a minimal gzip-compressed H5 file in memory
+        h5_buf = io.BytesIO()
+        with h5py.File(h5_buf, "w") as f:
+            f.create_dataset("vertices", data=np.array([[1.0, 2.0, 3.0]]))
+            f.create_dataset("edges", data=np.array([[0, 0]]))
+        h5_bytes = h5_buf.getvalue()
+        gz_bytes = gzip.compress(h5_bytes)
+
+        bucket = "test-bucket"
+        obj_path = "skeletons/v4/0.h5"
+        encoded_path = "skeletons%2Fv4%2F0.h5"
+        gcs_url = f"https://storage.googleapis.com/download/storage/v1/b/{bucket}/o/{encoded_path}?alt=media"
+
+        responses.add(responses.GET, url=gcs_url, body=gz_bytes, status=200)
+
+        token_response = {
+            "token": "ya29.test_token",
+            "bucket": bucket,
+            "object_paths": {"0": obj_path},
+        }
+
+        result = myclient.skeleton.download_skeletons_with_token(token_response)
+        assert "0" in result
+        assert "vertices" in result["0"]
+        assert "edges" in result["0"]
+        assert np.array_equal(result["0"]["vertices"], np.array([[1.0, 2.0, 3.0]]))
+
+    @responses.activate
+    def test_download_skeletons_with_token__error_handling(self, myclient, mocker):
+        """Test that a failed download for one skeleton doesn't prevent others."""
+        import gzip
+        import io
+
+        import h5py
+
+        # Create a valid gzip-compressed H5 file
+        h5_buf = io.BytesIO()
+        with h5py.File(h5_buf, "w") as f:
+            f.create_dataset("vertices", data=np.array([[1.0, 2.0, 3.0]]))
+            f.create_dataset("edges", data=np.array([[0, 0]]))
+        h5_bytes = h5_buf.getvalue()
+        gz_bytes = gzip.compress(h5_bytes)
+
+        bucket = "test-bucket"
+
+        # First skeleton will fail (404)
+        obj_path_0 = "skeletons/v4/0.h5"
+        encoded_path_0 = "skeletons%2Fv4%2F0.h5"
+        gcs_url_0 = f"https://storage.googleapis.com/download/storage/v1/b/{bucket}/o/{encoded_path_0}?alt=media"
+        responses.add(responses.GET, url=gcs_url_0, status=404)
+
+        # Second skeleton will succeed
+        obj_path_1 = "skeletons/v4/1.h5"
+        encoded_path_1 = "skeletons%2Fv4%2F1.h5"
"skeletons%2Fv4%2F1.h5" + gcs_url_1 = f"https://storage.googleapis.com/download/storage/v1/b/{bucket}/o/{encoded_path_1}?alt=media" + responses.add(responses.GET, url=gcs_url_1, body=gz_bytes, status=200) + + token_response = { + "token": "ya29.test_token", + "bucket": bucket, + "object_paths": {"0": obj_path_0, "1": obj_path_1}, + } + + result = myclient.skeleton.download_skeletons_with_token(token_response) + # Skeleton 0 should be missing (failed), skeleton 1 should succeed + assert "0" not in result + assert "1" in result + assert np.array_equal(result["1"]["vertices"], np.array([[1.0, 2.0, 3.0]]))