diff --git a/README.md b/README.md index 5be53e2..215a761 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,10 @@ export UPSTREAM_API_URLS=https://stac.eoapi.dev,https://stac.maap-project.org Run the server: ```bash -uv run python -m uvicorn stac_fastapi.collection_discovery.app:app --host 0.0.0.0 --port 8000 +uv run python -m uvicorn stac_fastapi.collection_discovery.app:app \ + --host 0.0.0.0 \ + --port 8000 \ + --reload ``` ### Run the server with Docker diff --git a/src/stac_fastapi/collection_discovery/core.py b/src/stac_fastapi/collection_discovery/core.py index 43a30e2..30f1a56 100644 --- a/src/stac_fastapi/collection_discovery/core.py +++ b/src/stac_fastapi/collection_discovery/core.py @@ -6,7 +6,7 @@ from urllib.parse import unquote, urlencode, urljoin import attr -from fastapi import Query, Request +from fastapi import HTTPException, Query, Request from httpx import AsyncClient from pydantic import BaseModel from stac_pydantic.links import Relations @@ -44,6 +44,33 @@ def _robust_urljoin(base: str, path: str) -> str: HTTPX_TIMEOUT = 15.0 +def _resolve_apis( + apis: list[str] | None, request: Request, allow_empty: bool = False +) -> list[str]: + """Resolve the list of APIs from parameter or app settings. + + Args: + apis: User-provided list of API URLs, or None to use settings + request: FastAPI request object containing app state + allow_empty: If True, return empty list when no APIs found; if False, raise error + + Returns: + List of API URLs + + Raises: + HTTPException: When no APIs are found and allow_empty is False + """ + if not apis: + apis = request.app.state.settings.upstream_api_urls + if not apis and not allow_empty: + raise HTTPException( + status_code=400, + detail="No APIs specified. Provide 'apis' parameter or configure " + "upstream_api_urls in this application.", + ) + return apis or [] + + class UpstreamApiStatus(BaseModel): """Status information for an upstream API.""" @@ -113,13 +140,15 @@ def _build_search_params( } def _get_search_state( - self, token: str | None, apis: list[str], param_str: str + self, token: str | None, apis: list[str] | None, param_str: str ) -> dict[str, Any]: """Get or create search state based on token.""" if token: search_state = self._decode_token(token) logger.info("Continuing collection search with token pagination") else: + if not apis: + raise ValueError("No apis specified") search_state = { "current": { api: _robust_urljoin(api, f"collections?{param_str}") for api in apis @@ -209,10 +238,10 @@ async def all_collections( **kwargs, ) -> Collections: """Collection search for multiple upstream APIs""" - if not apis: - apis = request.app.state.settings.upstream_api_urls - if not apis: - raise ValueError("no apis specified!") + # When using token pagination, apis are encoded in the token + # Only validate apis parameter when not using token + if not token: + apis = _resolve_apis(apis, request) params = self._build_search_params( bbox, datetime, limit, fields, sortby, filter_expr, filter_lang, q @@ -231,10 +260,11 @@ async def all_collections( } async def fetch_api_data( - client, api: str, url: str + client: AsyncClient, api: str, url: str ) -> tuple[str, dict[str, Any]]: """Fetch data from a single API endpoint.""" api_request = await client.get(url) + api_request.raise_for_status() json_response = api_request.json() return api, json_response @@ -313,10 +343,7 @@ async def landing_page( ) # Add upstream APIs as child links - if not apis: - apis = request.app.state.settings.upstream_api_urls - if not apis: - raise ValueError("no apis specified!") + apis = _resolve_apis(apis, request) # include the configured APIs in the description landing_page["description"] = ( @@ -414,10 +441,7 @@ async def conformance_classes( local_conformance_set = set(local_conformance_classes) - if not apis: - apis = request.app.state.settings.upstream_api_urls - if not apis: - raise ValueError("no apis specified!") + apis = _resolve_apis(apis, request) semaphore = asyncio.Semaphore(10) @@ -511,10 +535,7 @@ async def health_check( ] = None, ) -> HealthCheckResponse: """PgSTAC HealthCheck.""" - if not apis: - apis = request.app.state.settings.upstream_api_urls - if not apis: - raise ValueError("no apis specified!") + apis = _resolve_apis(apis, request) upstream_apis: dict[str, UpstreamApiStatus] = {} semaphore = asyncio.Semaphore(10) diff --git a/tests/unit/test_core.py b/tests/unit/test_core.py index 2765e4d..e153e02 100644 --- a/tests/unit/test_core.py +++ b/tests/unit/test_core.py @@ -170,6 +170,8 @@ async def test_all_collections_api_error( self, collection_search_client, mock_request, sample_collections_response ): """Test handling of API errors.""" + from httpx import HTTPStatusError + # One API returns error, one succeeds respx.get("https://api1.example.com/collections").mock( return_value=Response(500, json={"error": "Internal server error"}) @@ -179,7 +181,7 @@ async def test_all_collections_api_error( ) # Should raise exception due to failed API call - with pytest.raises(KeyError): + with pytest.raises(HTTPStatusError): await collection_search_client.all_collections(request=mock_request) @pytest.mark.asyncio @@ -243,16 +245,21 @@ async def test_all_collections_with_single_api( @pytest.mark.asyncio async def test_all_collections_empty_apis_parameter(self, collection_search_client): - """Test collection search with empty apis parameter raises ValueError.""" + """Test collection search with empty apis parameter raises HTTPException.""" from unittest.mock import Mock + from fastapi import HTTPException + # Create a mock request with empty upstream_api_urls in settings mock_request = Mock() mock_request.app.state.settings.upstream_api_urls = [] - with pytest.raises(ValueError, match="no apis specified!"): + with pytest.raises(HTTPException) as exc_info: await collection_search_client.all_collections(request=mock_request, apis=[]) + assert exc_info.value.status_code == 400 + assert "No APIs specified" in exc_info.value.detail + @pytest.mark.asyncio async def test_all_collections_no_apis_fallback_to_settings( self, collection_search_client @@ -260,13 +267,18 @@ async def test_all_collections_no_apis_fallback_to_settings( """Test that when no apis parameter provided, it falls back to settings.""" from unittest.mock import Mock + from fastapi import HTTPException + # Create a mock request with empty upstream_api_urls in settings mock_request = Mock() mock_request.app.state.settings.upstream_api_urls = [] - with pytest.raises(ValueError, match="no apis specified!"): + with pytest.raises(HTTPException) as exc_info: await collection_search_client.all_collections(request=mock_request) + assert exc_info.value.status_code == 400 + assert "No APIs specified" in exc_info.value.detail + @pytest.mark.asyncio @respx.mock async def test_all_collections_apis_parameter_with_pagination( @@ -319,6 +331,110 @@ async def test_all_collections_apis_parameter_with_search_params( assert len(result["collections"]) == 2 assert result["numberReturned"] == 2 + @pytest.mark.asyncio + @respx.mock + async def test_apis_parameter_preserved_in_pagination( + self, collection_search_client, mock_request + ): + """Test that apis parameter is preserved when following next link.""" + # Mock response with next link for first page + first_page_response = { + "collections": [ + { + "type": "Collection", + "id": "api3-page1-collection-1", + "title": "Collection Page 1", + "description": "First page collection", + "extent": { + "spatial": {"bbox": [[-180, -90, 180, 90]]}, + "temporal": { + "interval": [["2020-01-01T00:00:00Z", "2021-01-01T00:00:00Z"]] + }, + }, + "license": "MIT", + "links": [], + }, + ], + "links": [ + {"rel": "self", "href": "https://api3.example.com/collections"}, + { + "rel": "next", + "href": "https://api3.example.com/collections?token=page2", + }, + ], + } + + # Mock response for second page + second_page_response = { + "collections": [ + { + "type": "Collection", + "id": "api3-page2-collection-1", + "title": "Collection Page 2", + "description": "Second page collection", + "extent": { + "spatial": {"bbox": [[-180, -90, 180, 90]]}, + "temporal": { + "interval": [["2020-01-01T00:00:00Z", "2021-01-01T00:00:00Z"]] + }, + }, + "license": "MIT", + "links": [], + }, + ], + "links": [ + { + "rel": "self", + "href": "https://api3.example.com/collections?token=page2", + }, + { + "rel": "previous", + "href": "https://api3.example.com/collections", + }, + ], + } + + respx.route(url="https://api3.example.com/collections?token=page2").mock( + return_value=Response(200, json=second_page_response) + ) + respx.route(url="https://api3.example.com/collections").mock( + return_value=Response(200, json=first_page_response) + ) + + # First request with custom apis parameter + custom_apis = ["https://api3.example.com"] + first_result = await collection_search_client.all_collections( + request=mock_request, apis=custom_apis + ) + + # Extract the next link and token from first page + next_link = next( + (link for link in first_result["links"] if link["rel"] == "next"), None + ) + assert next_link is not None + assert "token=" in next_link["href"] + + # Extract token from the next link + token = next_link["href"].split("token=")[1] + + # Decode the token to verify apis are preserved + decoded_token = collection_search_client._decode_token(token) + + # The token should contain the current state with the API URL + assert "current" in decoded_token + assert "https://api3.example.com" in decoded_token["current"] + + # Now follow the next link (simulate second request) + # This should use the token which should have the apis preserved + second_result = await collection_search_client.all_collections( + request=mock_request, token=token + ) + + # Should have collections from page 2 + assert len(second_result["collections"]) == 1 + collection_ids = [c["id"] for c in second_result["collections"]] + assert "api3-page2-collection-1" in collection_ids + @pytest.mark.asyncio async def test_not_implemented_methods(self, collection_search_client): """Test that certain methods raise NotImplementedError."""