Skip to content

Commit f1d0e2e

Browse files
WIP: fix tests blocked reading http request body
Signed-off-by: Achille Roussel <[email protected]>
1 parent 4c0ff4f commit f1d0e2e

File tree

6 files changed

+110
-8
lines changed

6 files changed

+110
-8
lines changed

.dockerignore

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
Dockerfile
2+
__pycache__
3+
*.md
4+
*.yaml
5+
*.yml
6+
dist/*

Dockerfile

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
FROM python:3.12
2+
WORKDIR /usr/src/dispatch-py
3+
4+
COPY pyproject.toml .
5+
RUN python -m pip install -e .[dev]
6+
7+
COPY . .
8+
RUN python -m pip install -e .[dev]
9+
10+
ENTRYPOINT ["python"]

src/dispatch/__init__.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import os
66
from concurrent import futures
7-
from http.server import HTTPServer
7+
from http.server import ThreadingHTTPServer
88
from typing import Any, Callable, Coroutine, Optional, TypeVar, overload
99
from urllib.parse import urlsplit
1010

@@ -79,7 +79,8 @@ def run(port: str = os.environ.get("DISPATCH_ENDPOINT_ADDR", "localhost:8000")):
7979
DISPATCH_ENDPOINT_ADDR environment variable, or 'localhost:8000' if it
8080
wasn't set.
8181
"""
82+
print(f"Starting Dispatch server on {port}")
8283
parsed_url = urlsplit("//" + port)
8384
server_address = (parsed_url.hostname or "", parsed_url.port or 0)
84-
server = HTTPServer(server_address, Dispatch(_default_registry()))
85+
server = ThreadingHTTPServer(server_address, Dispatch(_default_registry()))
8586
server.serve_forever()

src/dispatch/http.py

+24-4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""Integration of Dispatch functions with http."""
22

3+
from datetime import datetime
4+
35
import logging
46
import os
57
from datetime import timedelta
@@ -61,10 +63,12 @@ def __init__(
6163
registry: Registry,
6264
verification_key: Optional[Ed25519PublicKey] = None,
6365
):
64-
super().__init__(request, client_address, server)
6566
self.registry = registry
6667
self.verification_key = verification_key
6768
self.error_content_type = "application/json"
69+
print(datetime.now(), "INITIALIZING FUNCTION SERVICE")
70+
super().__init__(request, client_address, server)
71+
print(datetime.now(), "DONE HANDLING REQUEST")
6872

6973
def send_error_response_invalid_argument(self, message: str):
7074
self.send_error_response(400, "invalid_argument", message)
@@ -82,17 +86,33 @@ def send_error_response_internal(self, message: str):
8286
self.send_error_response(500, "internal", message)
8387

8488
def send_error_response(self, status: int, code: str, message: str):
89+
body = f'{{"code":"{code}","message":"{message}"}}'.encode()
8590
self.send_response(status)
8691
self.send_header("Content-Type", self.error_content_type)
92+
self.send_header("Content-Length", str(len(body)))
8793
self.end_headers()
88-
self.wfile.write(f'{{"code":"{code}","message":"{message}"}}'.encode())
94+
print(datetime.now(), "SENDING ERROR RESPONSE")
95+
self.wfile.write(body)
96+
print(datetime.now(), f"SERVER IS DONE {len(body)}")
8997

9098
def do_POST(self):
9199
if self.path != "/dispatch.sdk.v1.FunctionService/Run":
92100
self.send_error_response_not_found("path not found")
93101
return
94102

95-
data: bytes = self.rfile.read()
103+
content_length = int(self.headers.get("Content-Length", 0))
104+
if content_length == 0:
105+
self.send_error_response_invalid_argument("content length is required")
106+
return
107+
if content_length < 0:
108+
self.send_error_response_invalid_argument("content length is negative")
109+
return
110+
if content_length > 16_000_000:
111+
self.send_error_response_invalid_argument("content length is too large")
112+
return
113+
114+
data: bytes = self.rfile.read(content_length)
115+
print(datetime.now(), f"RECEIVED POST REQUEST: {self.path} {len(data)} {self.request_version} {self.headers}")
96116
logger.debug("handling run request with %d byte body", len(data))
97117

98118
if self.verification_key is not None:
@@ -130,7 +150,7 @@ def do_POST(self):
130150
)
131151
return
132152

133-
logger.info("running function '%s'", req.function)
153+
print(datetime.now(), "running function '%s'", req.function)
134154
try:
135155
output = func._primitive_call(Input(req))
136156
except Exception:

tests/test_fastapi.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@
1414
from fastapi.testclient import TestClient
1515

1616
from dispatch.experimental.durable.registry import clear_functions
17-
from dispatch.fastapi import Dispatch, parse_verification_key
17+
from dispatch.fastapi import Dispatch
1818
from dispatch.function import Arguments, Error, Function, Input, Output
1919
from dispatch.proto import _any_unpickle as any_unpickle
2020
from dispatch.sdk.v1 import call_pb2 as call_pb
2121
from dispatch.sdk.v1 import function_pb2 as function_pb
22-
from dispatch.signature import public_key_from_pem
22+
from dispatch.signature import parse_verification_key, public_key_from_pem
2323
from dispatch.status import Status
2424
from dispatch.test import EndpointClient
2525

tests/test_http.py

+65
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import base64
2+
import os
3+
import pickle
4+
import struct
5+
import threading
6+
import unittest
7+
from typing import Any
8+
from unittest import mock
9+
10+
import fastapi
11+
import google.protobuf.any_pb2
12+
import google.protobuf.wrappers_pb2
13+
import httpx
14+
from http.server import HTTPServer
15+
from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PublicKey
16+
17+
from dispatch.experimental.durable.registry import clear_functions
18+
from dispatch.http import Dispatch
19+
from dispatch.function import Arguments, Error, Function, Input, Output, Registry
20+
from dispatch.proto import _any_unpickle as any_unpickle
21+
from dispatch.sdk.v1 import call_pb2 as call_pb
22+
from dispatch.sdk.v1 import function_pb2 as function_pb
23+
from dispatch.signature import parse_verification_key, public_key_from_pem
24+
from dispatch.status import Status
25+
from dispatch.test import EndpointClient
26+
27+
public_key_pem = "-----BEGIN PUBLIC KEY-----\nMCowBQYDK2VwAyEAJrQLj5P/89iXES9+vFgrIy29clF9CC/oPPsw3c5D0bs=\n-----END PUBLIC KEY-----"
28+
public_key_pem2 = "-----BEGIN PUBLIC KEY-----\\nMCowBQYDK2VwAyEAJrQLj5P/89iXES9+vFgrIy29clF9CC/oPPsw3c5D0bs=\\n-----END PUBLIC KEY-----"
29+
public_key = public_key_from_pem(public_key_pem)
30+
public_key_bytes = public_key.public_bytes_raw()
31+
public_key_b64 = base64.b64encode(public_key_bytes)
32+
33+
34+
def create_dispatch_instance(endpoint: str):
35+
return Dispatch(
36+
Registry(
37+
endpoint=endpoint,
38+
api_key="0000000000000000",
39+
api_url="http://127.0.0.1:10000",
40+
),
41+
)
42+
43+
44+
class TestHTTP(unittest.TestCase):
45+
def setUp(self):
46+
self.server_address = ('127.0.0.1', 9999)
47+
self.endpoint = f"http://{self.server_address[0]}:{self.server_address[1]}"
48+
self.client = httpx.Client(timeout=1.0)
49+
self.server = HTTPServer(self.server_address, create_dispatch_instance(self.endpoint))
50+
self.thread = threading.Thread(target=self.server.serve_forever)
51+
self.thread.start()
52+
53+
def tearDown(self):
54+
self.server.shutdown()
55+
self.thread.join(timeout=1.0)
56+
self.client.close()
57+
self.server.server_close()
58+
59+
def test_Dispatch_defaults(self):
60+
print("POST REQUEST", f"{self.endpoint}/dispatch.sdk.v1.FunctionService/Run")
61+
resp = self.client.post(f"{self.endpoint}/dispatch.sdk.v1.FunctionService/Run")
62+
print(resp.status_code)
63+
print("CLIENT RESPONSE!", resp.headers)
64+
#body = resp.read()
65+
#self.assertEqual(resp.status_code, 400)

0 commit comments

Comments
 (0)