Skip to content

Commit ef2bedc

Browse files
committed
Fix a few more tests
1 parent b7acacf commit ef2bedc

File tree

2 files changed

+23
-36
lines changed

2 files changed

+23
-36
lines changed

tests/test_aiohttp.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from dispatch.experimental.durable.registry import clear_functions
2121
from dispatch.function import Arguments, Error, Function, Input, Output, Registry
2222
from dispatch.proto import _any_unpickle as any_unpickle
23+
from dispatch.proto import _pb_any_pickle as any_pickle
2324
from dispatch.sdk.v1 import call_pb2 as call_pb
2425
from dispatch.sdk.v1 import function_pb2 as function_pb
2526
from dispatch.signature import parse_verification_key, public_key_from_pem
@@ -109,22 +110,15 @@ async def my_function(input: Input) -> Output:
109110
http_client = dispatch.test.httpx.Client(httpx.Client(base_url=self.endpoint))
110111
client = EndpointClient(http_client)
111112

112-
pickled = pickle.dumps("Hello World!")
113-
input_any = google.protobuf.any_pb2.Any()
114-
input_any.Pack(google.protobuf.wrappers_pb2.BytesValue(value=pickled))
115-
116113
req = function_pb.RunRequest(
117114
function=my_function.name,
118-
input=input_any,
115+
input=any_pickle("Hello World!"),
119116
)
120117

121118
resp = client.run(req)
122119

123120
self.assertIsInstance(resp, function_pb.RunResponse)
124121

125-
resp.exit.result.output.Unpack(
126-
output_bytes := google.protobuf.wrappers_pb2.BytesValue()
127-
)
128-
output = pickle.loads(output_bytes.value)
122+
output = any_unpickle(resp.exit.result.output)
129123

130124
self.assertEqual(output, "You told me: 'Hello World!' (12 characters)")

tests/test_fastapi.py

Lines changed: 20 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,8 @@ def execute(
227227
any = any_pickle(input)
228228
req.input.CopyFrom(any)
229229
if state is not None:
230-
req.poll_result.coroutine_state = state
230+
any = any_pickle(state)
231+
req.poll_result.typed_coroutine_state.CopyFrom(any)
231232
if calls is not None:
232233
for c in calls:
233234
req.poll_result.results.append(c)
@@ -247,10 +248,6 @@ def proto_call(self, call: call_pb.Call) -> call_pb.CallResult:
247248
resp = self.client.run(req)
248249
self.assertIsInstance(resp, function_pb.RunResponse)
249250

250-
# Assert the response is terminal. Good enough until the test client can
251-
# orchestrate coroutines.
252-
self.assertTrue(len(resp.poll.coroutine_state) == 0)
253-
254251
resp.exit.result.correlation_id = call.correlation_id
255252
return resp.exit.result
256253

@@ -317,9 +314,10 @@ async def my_function(input: Input) -> Output:
317314
return Output.value("not reached")
318315

319316
resp = self.execute(my_function, input="cool stuff")
320-
self.assertEqual(b"42", resp.poll.coroutine_state)
317+
state = any_unpickle(resp.poll.typed_coroutine_state)
318+
self.assertEqual(b"42", state)
321319

322-
resp = self.execute(my_function, state=resp.poll.coroutine_state)
320+
resp = self.execute(my_function, state=state)
323321
self.assertEqual("ValueError", resp.exit.result.error.type)
324322
self.assertEqual(
325323
"This input is for a resumed coroutine", resp.exit.result.error.message
@@ -360,32 +358,29 @@ async def coroutine3(input: Input) -> Output:
360358
if input.is_first_call:
361359
counter = input.input
362360
else:
363-
(counter,) = struct.unpack("@i", input.coroutine_state)
361+
counter = input.coroutine_state
364362
counter -= 1
365363
if counter <= 0:
366364
return Output.value("done")
367-
coroutine_state = struct.pack("@i", counter)
368-
return Output.poll(coroutine_state=coroutine_state)
365+
return Output.poll(coroutine_state=counter)
369366

370367
# first call
371368
resp = self.execute(coroutine3, input=4)
372-
state = resp.poll.coroutine_state
373-
self.assertTrue(len(state) > 0)
369+
state = any_unpickle(resp.poll.typed_coroutine_state)
370+
self.assertEqual(state, 3)
374371

375372
# resume, state = 3
376373
resp = self.execute(coroutine3, state=state)
377-
state = resp.poll.coroutine_state
378-
self.assertTrue(len(state) > 0)
374+
state = any_unpickle(resp.poll.typed_coroutine_state)
375+
self.assertEqual(state, 2)
379376

380377
# resume, state = 2
381378
resp = self.execute(coroutine3, state=state)
382-
state = resp.poll.coroutine_state
383-
self.assertTrue(len(state) > 0)
379+
state = any_unpickle(resp.poll.typed_coroutine_state)
380+
self.assertEqual(state, 1)
384381

385382
# resume, state = 1
386383
resp = self.execute(coroutine3, state=state)
387-
state = resp.poll.coroutine_state
388-
self.assertTrue(len(state) == 0)
389384
out = response_output(resp)
390385
self.assertEqual(out, "done")
391386

@@ -399,18 +394,18 @@ async def coroutine_main(input: Input) -> Output:
399394
if input.is_first_call:
400395
text: str = input.input
401396
return Output.poll(
402-
coroutine_state=text.encode(),
397+
coroutine_state=text,
403398
calls=[coro_compute_len._build_primitive_call(text)],
404399
)
405-
text = input.coroutine_state.decode()
400+
text = input.coroutine_state
406401
length = input.call_results[0].output
407402
return Output.value(f"length={length} text='{text}'")
408403

409404
resp = self.execute(coroutine_main, input="cool stuff")
410405

411406
# main saved some state
412-
state = resp.poll.coroutine_state
413-
self.assertTrue(len(state) > 0)
407+
state = any_unpickle(resp.poll.typed_coroutine_state)
408+
self.assertEqual(state, "cool stuff")
414409
# main asks for 1 call to compute_len
415410
self.assertEqual(len(resp.poll.calls), 1)
416411
call = resp.poll.calls[0]
@@ -426,7 +421,6 @@ async def coroutine_main(input: Input) -> Output:
426421
# resume main with the result
427422
resp = self.execute(coroutine_main, state=state, calls=[resp2])
428423
# validate the final result
429-
self.assertTrue(len(resp.poll.coroutine_state) == 0)
430424
out = response_output(resp)
431425
self.assertEqual("length=10 text='cool stuff'", out)
432426

@@ -440,7 +434,7 @@ async def coroutine_main(input: Input) -> Output:
440434
if input.is_first_call:
441435
text: str = input.input
442436
return Output.poll(
443-
coroutine_state=text.encode(),
437+
coroutine_state=text,
444438
calls=[coro_compute_len._build_primitive_call(text)],
445439
)
446440
error = input.call_results[0].error
@@ -452,8 +446,8 @@ async def coroutine_main(input: Input) -> Output:
452446
resp = self.execute(coroutine_main, input="cool stuff")
453447

454448
# main saved some state
455-
state = resp.poll.coroutine_state
456-
self.assertTrue(len(state) > 0)
449+
state = any_unpickle(resp.poll.typed_coroutine_state)
450+
self.assertEqual(state, "cool stuff")
457451
# main asks for 1 call to compute_len
458452
self.assertEqual(len(resp.poll.calls), 1)
459453
call = resp.poll.calls[0]
@@ -466,7 +460,6 @@ async def coroutine_main(input: Input) -> Output:
466460
# resume main with the result
467461
resp = self.execute(coroutine_main, state=state, calls=[resp2])
468462
# validate the final result
469-
self.assertTrue(len(resp.poll.coroutine_state) == 0)
470463
out = response_output(resp)
471464
self.assertEqual(out, "msg=Dead type='type'")
472465

0 commit comments

Comments
 (0)