
Commit a6156ec

feat(#184): safer tokens
* feat(#184): note about remote source script
* feat(#184): parameters only in token
* feat(#184): remove redundant query obfuscation
* feat(#184): don't persist SQL clauses
* feat(#184): update docs to match changes
1 parent 18c253e commit a6156ec

8 files changed: +186 -219 lines changed

README.md

Lines changed: 3 additions & 0 deletions
@@ -31,6 +31,9 @@ scripts/run-with-local-http.sh
 
 This project includes a convenience script to index and serve a remote STAC catalog. This script will fully index the remote STAC catalog each time it is run. This may not be the most efficient way to meet your needs, but it does help demonstrate some of this project's capabilities.
 
+> [!NOTE]
+> This script should not be used in a production environment. It is intended for local testing only.
+
 This script can optionally be called with a comma-separated list of STAC item JSON fixers, invoking the behaviour described [here](./docs/index-config.md#fixes).
 
 ```sh

docs/suitability.md

Lines changed: 0 additions & 41 deletions
@@ -65,44 +65,3 @@ Since 1.1.0 STAC-GeoParquet does not _require_ each collection to exist in a dif
 #### stac-fastapi-geoparquet
 
 The [stac-fastapi-geoparquet](https://pypi.org/project/stac-fastapi-geoparquet/) project aims to augment STAC-GeoParquet with a STAC API interface; however, this project does not currently appear to offer a production-ready solution.
-
-## Other Considerations
-
-### Paging Tokens
-
-Paging tokens included with search and collection items responses that span multiple pages work differently in this project compared to some other projects. The approach is considered safe, reasonable, and justified, and may be of interest to some users evaluating the suitability of this project for a use case.
-
-The API is ignorant of paging state to help support a serverless deployment, and no state or token information is stored within the API or associated storage. However, the paging process requires some knowledge of state _somewhere_. State is contained within the paging tokens that are received and submitted by the client.
-
-Paging tokens are included in the `next` and `prev` links in a multi-page response. Each token is a [JSON Web Token (JWT)](https://jwt.io/introduction) that provides all information required by the API to progress or regress across pages.
-
-#### SQL
-
-The payload of the JWT includes a parameterised SQL query and the parameters that will be used to fetch the next or previous page of results. It also includes the ID of the most recent data load, which the API uses to determine if it is paging across a data change, and which supports the behaviour described in the main README's [pagination section](../README.md#pagination).
-
-With a paging token the client provides the API with the SQL query it should execute. In many scenarios this might raise security concerns. See [SQL Safety](#sql-safety) for how such concerns are addressed and negated.
-
-#### SQL Safety
-
-The JWT is secured using the HS256 algorithm and a private key that must be provided at deployment time, and is therefore considered immutable. HMAC-signed JWTs are generally considered safe for verifying payload integrity, and should therefore prevent SQL query tampering. Integration tests verify that the payload cannot be modified by a client between page requests ([example](https://github.com/sparkgeo/STAC-API-Serverless/blob/212f1a97f091efe19bd6f9edb6084b7f3d508d20/tests/with_environment/integration_tests/test_get_search.py#L201)). If it becomes possible - such as through credential theft - for a malicious actor to modify and re-sign a paging token JWT, this is still not considered a significant concern. In a standard deployment the API reads Parquet index files with [read-only](https://github.com/sparkgeo/STAC-API-Serverless/blob/212f1a97f091efe19bd6f9edb6084b7f3d508d20/iac/cdk_deployment/cdk_deployment_stack.py#L68) access to an S3 bucket, which should prevent tampering with the index via SQL. If a malicious user is somehow able to modify files within the API container via a modified SQL query, any changes will be destroyed by the next Lambda invocation.
-
-The JWT private key can be rotated at any time to reduce the risk of credential theft. This change will interrupt any clients actively paging across results, at which point affected clients can reissue their queries.
-
-#### SQL Visibility
-
-This approach exposes the API's SQL query content to clients; however, no privileged information can be exposed in this way. The content of the SQL query is composed entirely of:
-1. information that can be gleaned from a review of this repository, and
-2. parameters provided by the client.
-
-The API uses placeholders to represent the location of the parquet index files it queries (e.g. an S3 URI) and replaces these immediately prior to SQL execution, so clients have no additional visibility of a deployment's storage infrastructure via a paging token.
-
-The following image shows the content of a sample paging token returned in response to a search query:
-
-```sh
-curl -X 'POST' \
-  'https://host/search' \
-  -H 'Content-Type: application/json' \
-  -d '{"collections": ["joplin"]}'
-```
-
-![sample paging token jwt](./images/sample-paging-token-jwt.png "Sample Paging Token JWT")
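
The passage removed above describes paging tokens as HS256-signed JWTs whose payload is visible to clients but cannot be altered without the signing key. A minimal sketch of that behaviour, assuming the PyJWT library; the library choice, secret key, and payload values are invented for illustration and are not taken from this repository:

```python
# A minimal sketch of HS256-signed paging tokens, assuming the PyJWT library.
# The secret key and payload values below are invented for illustration.
import jwt

SECRET_KEY = "example-deployment-secret"  # hypothetical; supplied at deployment time

# Issue a token whose payload carries only paging state (no SQL).
token = jwt.encode(
    {"limit": 10, "offset": 10, "last_load_id": "abc123"},
    SECRET_KEY,
    algorithm="HS256",
)

# The payload is visible to any client, with or without the key ...
unverified = jwt.decode(token, options={"verify_signature": False})
assert unverified["offset"] == 10

# ... but only a token signed with the deployment key verifies successfully.
payload = jwt.decode(token, SECRET_KEY, algorithms=["HS256"])

# A payload altered and re-signed without the key is rejected.
forged = jwt.encode({**payload, "offset": 999999}, "attacker-key", algorithm="HS256")
try:
    jwt.decode(forged, SECRET_KEY, algorithms=["HS256"])
except jwt.InvalidSignatureError:
    print("tampered token rejected")
```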

src/stac_fastapi/indexed/db.py

Lines changed: 8 additions & 31 deletions
@@ -1,4 +1,3 @@
-import re
 from datetime import UTC, datetime
 from logging import Logger, getLogger
 from os import environ
@@ -16,11 +15,6 @@
 
 _logger: Final[Logger] = getLogger(__name__)
 _query_timing_precision: Final[int] = 3
-_query_object_identifier_prefix: Final[str] = "src:"
-_query_object_identifier_suffix: Final[str] = ":src"
-_query_object_identifier_template: Final[str] = (
-    f"{_query_object_identifier_prefix}{{}}{_query_object_identifier_suffix}"
-)
 
 _root_db_connection: DuckDBPyConnection = None
 _parquet_uris: Dict[str, str] = {}
@@ -76,18 +70,19 @@ async def disconnect_from_db() -> None:
         _logger.error(e)
 
 
-# SQL queries include placeholder strings that are replaced with Parquet URIs prior to query execution.
-# This improves query performance relative to creating views in DuckDB from Parquet files and querying those.
-# Placeholders are used until the point of query execution so that API search pagination tokens,
-# which are JWT-encoded SQL queries and visible to the client, do not leak implementation detail around
-# parquet URI locations.
 def format_query_object_name(object_name: str) -> str:
-    return _query_object_identifier_template.format(object_name)
+    if object_name in _parquet_uris:
+        return "'{}'".format(_parquet_uris[object_name])
+    raise Exception(
+        "Attempt to use non-existent query object name '{bad_name}'. Available object names: '{availables}'".format(
+            bad_name=object_name,
+            availables="', '".join(list(_parquet_uris.keys())),
+        )
+    )
 
 
 def _execute(statement: str, params: Optional[List[Any]] = None) -> None:
     start = time()
-    statement = _prepare_statement(statement)
     _get_db_connection().execute(statement, params)
     _sql_log_message(statement, time() - start, None, params)
 
@@ -100,7 +95,6 @@ async def fetchone(
     if perform_latest_data_check:
         await _ensure_latest_data()
     start = time()
-    statement = _prepare_statement(statement)
     result = _get_db_connection().execute(statement, params).fetchone()
     _sql_log_message(statement, time() - start, 1 if result is not None else 0, params)
     return result
@@ -114,7 +108,6 @@ async def fetchall(
     if perform_latest_data_check:
         await _ensure_latest_data()
     start = time()
-    statement = _prepare_statement(statement)
     result = _get_db_connection().execute(statement, params).fetchall()
     _sql_log_message(statement, time() - start, len(result), params)
     return result
@@ -130,22 +123,6 @@ def _get_db_connection():
     return _root_db_connection.cursor()
 
 
-def _prepare_statement(statement: str) -> str:
-    query_object_identifier_regex = rf"\b{re.escape(_query_object_identifier_prefix)}([^:]+){re.escape(_query_object_identifier_suffix)}\b"
-    for query_object_name in re.findall(query_object_identifier_regex, statement):
-        if query_object_name not in _parquet_uris:
-            _logger.warning(
-                f"{query_object_name} not in parquet URI map, query will likely fail"
-            )
-            continue
-        statement = re.sub(
-            rf"\b{re.escape(_query_object_identifier_prefix)}{re.escape(query_object_name)}{re.escape(_query_object_identifier_suffix)}\b",
-            f"'{_parquet_uris[query_object_name]}'",
-            statement,
-        )
-    return statement
-
-
 def _sql_log_message(
     statement: str,
     duration: float,
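
With `_prepare_statement` removed, placeholder substitution no longer happens on stored statements: `format_query_object_name` now resolves a registered object name straight to its quoted Parquet URI, and the URI is only interpolated at the point of execution. A rough sketch of that pattern with DuckDB follows; the registry entry, URI, and query are invented for illustration and are not code from this module:

```python
# Sketch of resolving a registered query object name to its quoted Parquet URI
# immediately before execution. Registry entry, URI, and query are invented.
import duckdb

# Stand-in for the module-level registry populated when the API connects.
_parquet_uris = {"items": "s3://example-bucket/index/items.parquet"}


def format_query_object_name(object_name: str) -> str:
    # Mirror of the behaviour added in this commit: known names resolve to a
    # quoted URI, unknown names raise immediately.
    if object_name in _parquet_uris:
        return "'{}'".format(_parquet_uris[object_name])
    raise Exception(f"Unknown query object name '{object_name}'")


connection = duckdb.connect()
# The URI is interpolated only at execution time; user-supplied values still
# travel separately as bound parameters.
statement = f"SELECT id FROM {format_query_object_name('items')} WHERE collection = ?"
rows = connection.execute(statement, ["joplin"]).fetchall()
print(rows)
```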
Lines changed: 37 additions & 38 deletions
@@ -1,43 +1,49 @@
-from dataclasses import asdict, dataclass, field
-from datetime import datetime
-from json import JSONEncoder
-from re import escape, match
-from typing import Any, Dict, Final, List, Optional, Type, cast
+from dataclasses import asdict, dataclass, replace
+from json import JSONEncoder, loads
+from typing import Any, Dict, Final, List, Optional, Self, Type, cast
 
-_datetime_field_prefix: Final[str] = "datetime::"
+from geojson_pydantic.geometries import parse_geometry_obj
+from stac_pydantic.api.extensions.sort import SortExtension
+from stac_pydantic.api.search import Intersection
+from stac_pydantic.shared import BBox
+
+# Increment this value if query structure changes, so that paging tokens from
+# older query structures can be rejected.
+current_query_version: Final[int] = 1
 
 
 @dataclass(kw_only=True)
 class QueryInfo:
-    query: str
-    params: List[Any] = field(default_factory=list)
+    query_version: int
+    ids: Optional[List[str]] = None
+    collections: Optional[List[str]] = None
+    bbox: Optional[BBox] = None
+    intersects: Optional[Intersection] = None
+    datetime: Optional[str] = None
+    filter: Optional[Dict[str, Any]] = None
+    filter_lang: str
+    order: Optional[List[SortExtension]] = None
     limit: int
     offset: Optional[int] = None
     last_load_id: str
 
     def next(self) -> "QueryInfo":
-        return QueryInfo(
-            query=self.query,
-            params=self.params,
-            limit=self.limit,
+        return replace(
+            self,
             offset=(self.offset + self.limit)
             if self.offset is not None
             else self.limit,
-            last_load_id=self.last_load_id,
         )
 
     def previous(self) -> "QueryInfo":
         # Assume logic of validating that a "previous" link is required (i.e. there is currently a non-None offset) is applied elsewhere.
         # Technically we could apply that logic here, but we cannot determine if a "next" link is required in this module, so that would be inconsistent.
         current_offset = cast(int, self.offset)
-        return QueryInfo(
-            query=self.query,
-            params=self.params,
-            limit=self.limit,
+        return replace(
+            self,
             offset=(current_offset - self.limit)
             if current_offset > self.limit
             else None,
-            last_load_id=self.last_load_id,
         )
 
     def to_dict(self) -> Dict[str, Any]:
@@ -47,29 +53,22 @@ def to_dict(self) -> Dict[str, Any]:
     def json_encoder() -> Type:
         return _CustomJSONEncoder
 
-    def json_post_decoder(self) -> "QueryInfo":
-        new_params = []
-        for param in self.params:
-            if isinstance(param, str):
-                datetime_match = match(rf"^{escape(_datetime_field_prefix)}(.+)", param)
-                if datetime_match:
-                    new_params.append(datetime.fromisoformat(datetime_match.group(1)))
-                    continue
-            new_params.append(param)
-        return QueryInfo(
-            query=self.query,
-            params=new_params,
-            limit=self.limit,
-            offset=self.offset,
-            last_load_id=self.last_load_id,
+    def json_post_decoder(self: Self) -> "QueryInfo":
+        return replace(
+            self,
+            intersects=parse_geometry_obj(loads(cast(str, self.intersects)))
+            if self.intersects is not None
+            else None,
+            order=[SortExtension(**loads(cast(str, entry))) for entry in self.order]
+            if self.order is not None
+            else None,
        )
 
 
 class _CustomJSONEncoder(JSONEncoder):
     def default(self, obj: Any) -> Any:
-        if isinstance(obj, datetime):
-            return "{}{}".format(
-                _datetime_field_prefix,
-                obj.isoformat(),
-            )
+        if isinstance(obj, Intersection):
+            return cast(Intersection, obj).model_dump_json()
+        if isinstance(obj, SortExtension):
+            return cast(SortExtension, obj).model_dump_json()
         return JSONEncoder.default(self, obj)
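
The paging token now carries search parameters rather than SQL, so pydantic values such as the intersects geometry and sort entries are serialised to JSON strings by `_CustomJSONEncoder` and rebuilt in `json_post_decoder`. A small round-trip sketch using the same libraries the new imports reference; the sample polygon and sort field are invented for illustration:

```python
# Round-trip sketch for the geometry and sort serialisation used by the new
# paging token payload. The sample polygon and sort field are invented.
from json import loads

from geojson_pydantic.geometries import Polygon, parse_geometry_obj
from stac_pydantic.api.extensions.sort import SortExtension

intersects = Polygon(
    type="Polygon",
    coordinates=[[[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 0.0]]],
)
order = SortExtension(field="datetime", direction="desc")

# Encoding: pydantic values become JSON strings inside the token payload,
# as in _CustomJSONEncoder above.
encoded_intersects = intersects.model_dump_json()
encoded_order = order.model_dump_json()

# Decoding: json_post_decoder rebuilds the models from those strings.
decoded_intersects = parse_geometry_obj(loads(encoded_intersects))
decoded_order = SortExtension(**loads(encoded_order))

assert decoded_intersects.type == "Polygon"
assert decoded_order.field == "datetime"
```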
