@@ -227,7 +227,8 @@ def execute(
227
227
any = any_pickle (input )
228
228
req .input .CopyFrom (any )
229
229
if state is not None :
230
- req .poll_result .coroutine_state = state
230
+ any = any_pickle (state )
231
+ req .poll_result .typed_coroutine_state .CopyFrom (any )
231
232
if calls is not None :
232
233
for c in calls :
233
234
req .poll_result .results .append (c )
@@ -247,10 +248,6 @@ def proto_call(self, call: call_pb.Call) -> call_pb.CallResult:
247
248
resp = self .client .run (req )
248
249
self .assertIsInstance (resp , function_pb .RunResponse )
249
250
250
- # Assert the response is terminal. Good enough until the test client can
251
- # orchestrate coroutines.
252
- self .assertTrue (len (resp .poll .coroutine_state ) == 0 )
253
-
254
251
resp .exit .result .correlation_id = call .correlation_id
255
252
return resp .exit .result
256
253
@@ -317,9 +314,10 @@ async def my_function(input: Input) -> Output:
317
314
return Output .value ("not reached" )
318
315
319
316
resp = self .execute (my_function , input = "cool stuff" )
320
- self .assertEqual (b"42" , resp .poll .coroutine_state )
317
+ state = any_unpickle (resp .poll .typed_coroutine_state )
318
+ self .assertEqual (b"42" , state )
321
319
322
- resp = self .execute (my_function , state = resp . poll . coroutine_state )
320
+ resp = self .execute (my_function , state = state )
323
321
self .assertEqual ("ValueError" , resp .exit .result .error .type )
324
322
self .assertEqual (
325
323
"This input is for a resumed coroutine" , resp .exit .result .error .message
@@ -360,32 +358,29 @@ async def coroutine3(input: Input) -> Output:
360
358
if input .is_first_call :
361
359
counter = input .input
362
360
else :
363
- ( counter ,) = struct . unpack ( "@i" , input .coroutine_state )
361
+ counter = input .coroutine_state
364
362
counter -= 1
365
363
if counter <= 0 :
366
364
return Output .value ("done" )
367
- coroutine_state = struct .pack ("@i" , counter )
368
- return Output .poll (coroutine_state = coroutine_state )
365
+ return Output .poll (coroutine_state = counter )
369
366
370
367
# first call
371
368
resp = self .execute (coroutine3 , input = 4 )
372
- state = resp .poll .coroutine_state
373
- self .assertTrue ( len ( state ) > 0 )
369
+ state = any_unpickle ( resp .poll .typed_coroutine_state )
370
+ self .assertEqual ( state , 3 )
374
371
375
372
# resume, state = 3
376
373
resp = self .execute (coroutine3 , state = state )
377
- state = resp .poll .coroutine_state
378
- self .assertTrue ( len ( state ) > 0 )
374
+ state = any_unpickle ( resp .poll .typed_coroutine_state )
375
+ self .assertEqual ( state , 2 )
379
376
380
377
# resume, state = 2
381
378
resp = self .execute (coroutine3 , state = state )
382
- state = resp .poll .coroutine_state
383
- self .assertTrue ( len ( state ) > 0 )
379
+ state = any_unpickle ( resp .poll .typed_coroutine_state )
380
+ self .assertEqual ( state , 1 )
384
381
385
382
# resume, state = 1
386
383
resp = self .execute (coroutine3 , state = state )
387
- state = resp .poll .coroutine_state
388
- self .assertTrue (len (state ) == 0 )
389
384
out = response_output (resp )
390
385
self .assertEqual (out , "done" )
391
386
@@ -399,18 +394,18 @@ async def coroutine_main(input: Input) -> Output:
399
394
if input .is_first_call :
400
395
text : str = input .input
401
396
return Output .poll (
402
- coroutine_state = text . encode () ,
397
+ coroutine_state = text ,
403
398
calls = [coro_compute_len ._build_primitive_call (text )],
404
399
)
405
- text = input .coroutine_state . decode ()
400
+ text = input .coroutine_state
406
401
length = input .call_results [0 ].output
407
402
return Output .value (f"length={ length } text='{ text } '" )
408
403
409
404
resp = self .execute (coroutine_main , input = "cool stuff" )
410
405
411
406
# main saved some state
412
- state = resp .poll .coroutine_state
413
- self .assertTrue ( len ( state ) > 0 )
407
+ state = any_unpickle ( resp .poll .typed_coroutine_state )
408
+ self .assertEqual ( state , "cool stuff" )
414
409
# main asks for 1 call to compute_len
415
410
self .assertEqual (len (resp .poll .calls ), 1 )
416
411
call = resp .poll .calls [0 ]
@@ -426,7 +421,6 @@ async def coroutine_main(input: Input) -> Output:
426
421
# resume main with the result
427
422
resp = self .execute (coroutine_main , state = state , calls = [resp2 ])
428
423
# validate the final result
429
- self .assertTrue (len (resp .poll .coroutine_state ) == 0 )
430
424
out = response_output (resp )
431
425
self .assertEqual ("length=10 text='cool stuff'" , out )
432
426
@@ -440,7 +434,7 @@ async def coroutine_main(input: Input) -> Output:
440
434
if input .is_first_call :
441
435
text : str = input .input
442
436
return Output .poll (
443
- coroutine_state = text . encode () ,
437
+ coroutine_state = text ,
444
438
calls = [coro_compute_len ._build_primitive_call (text )],
445
439
)
446
440
error = input .call_results [0 ].error
@@ -452,8 +446,8 @@ async def coroutine_main(input: Input) -> Output:
452
446
resp = self .execute (coroutine_main , input = "cool stuff" )
453
447
454
448
# main saved some state
455
- state = resp .poll .coroutine_state
456
- self .assertTrue ( len ( state ) > 0 )
449
+ state = any_unpickle ( resp .poll .typed_coroutine_state )
450
+ self .assertEqual ( state , "cool stuff" )
457
451
# main asks for 1 call to compute_len
458
452
self .assertEqual (len (resp .poll .calls ), 1 )
459
453
call = resp .poll .calls [0 ]
@@ -466,7 +460,6 @@ async def coroutine_main(input: Input) -> Output:
466
460
# resume main with the result
467
461
resp = self .execute (coroutine_main , state = state , calls = [resp2 ])
468
462
# validate the final result
469
- self .assertTrue (len (resp .poll .coroutine_state ) == 0 )
470
463
out = response_output (resp )
471
464
self .assertEqual (out , "msg=Dead type='type'" )
472
465
0 commit comments