38
38
)
39
39
from burr .core .graph import Graph , GraphBuilder
40
40
from burr .core .persistence import BaseStateLoader , BaseStateSaver
41
- from burr .core .state import State
41
+ from burr .core .state import State , StateDelta
42
42
from burr .core .validation import BASE_ERROR_MESSAGE
43
43
from burr .lifecycle .base import LifecycleAdapter
44
44
from burr .lifecycle .internal import LifecycleAdapterSet
@@ -83,15 +83,44 @@ def _adjust_single_step_output(output: Union[State, Tuple[dict, State]], action_
83
83
_raise_fn_return_validation_error (output , action_name )
84
84
85
85
86
- def _apply_defaults (state : State , defaults : Dict [str , Any ]) -> State :
86
+ def _apply_defaults (
87
+ state : State ,
88
+ defaults : Dict [str , Any ],
89
+ op_list_to_restrict_writes : Optional [List [StateDelta ]] = None ,
90
+ ) -> State :
91
+ """Applies default values to the state. This is useful for the cases in which one applies a default value.
92
+
93
+ :param state: The state object to apply to.
94
+ :param defaults: Default values (key/value) to use
95
+ :param op_list_to_restrict_writes: The list of operations to restrict writes to, optional.
96
+ If this is specified, then it will only apply the defaults to the keys that were written by ops in the op list.
97
+ This allows us to track what it has written, and use that to apply defaults.
98
+ :return: The state object with the defaults applied.
99
+ """
87
100
state_update = {}
88
101
state_to_use = state
102
+ op_list_writes = None
103
+ # In this case we want to restrict to the written sets
104
+ if op_list_to_restrict_writes is not None :
105
+ op_list_writes = set ()
106
+ for op in op_list_to_restrict_writes :
107
+ op_list_writes .update (op .writes ())
108
+
89
109
# We really don't need to short-circuit but I want to avoid the update function
90
110
# So we might as well
91
111
if len (defaults ) > 0 :
92
112
for key , value in defaults .items ():
93
- if key not in state :
94
- state_update [key ] = value
113
+ # if we're tracking the op list
114
+ # Then we only want to apply deafults
115
+ # to keys that have *not* been written to
116
+ # This is more restrictive than the next condition
117
+ if op_list_writes is not None :
118
+ if key not in op_list_writes :
119
+ state_update [key ] = value
120
+ # Otherwise we just apply the defaults to the state itself
121
+ else :
122
+ if key not in state :
123
+ state_update [key ] = value
95
124
state_to_use = state .update (** state_update )
96
125
return state_to_use
97
126
@@ -244,12 +273,13 @@ def _run_single_step_action(
244
273
:return: The result of running the action, and the new state
245
274
"""
246
275
# TODO -- guard all reads/writes with a subset of the state
276
+ state .flush_op_list ()
247
277
action .validate_inputs (inputs )
248
278
state = _apply_defaults (state , action .default_reads )
249
279
result , new_state = _adjust_single_step_output (
250
280
action .run_and_update (state , ** inputs ), action .name
251
281
)
252
- new_state = _apply_defaults (new_state , action .default_writes )
282
+ new_state = _apply_defaults (new_state , action .default_writes , state . flush_op_list () )
253
283
_validate_result (result , action .name )
254
284
out = result , _state_update (state , new_state )
255
285
_validate_result (result , action .name )
@@ -262,6 +292,7 @@ def _run_single_step_streaming_action(
262
292
) -> Generator [Tuple [dict , Optional [State ]], None , None ]:
263
293
"""Runs a single step streaming action. This API is internal-facing.
264
294
This normalizes + validates the output."""
295
+ state .flush_op_list ()
265
296
action .validate_inputs (inputs )
266
297
state = _apply_defaults (state , action .default_reads )
267
298
generator = action .stream_run_and_update (state , ** inputs )
@@ -284,7 +315,9 @@ def _run_single_step_streaming_action(
284
315
f"statement must be a tuple of (result, state_update). For example, yield dict(foo='bar'), state.update(foo='bar')"
285
316
)
286
317
_validate_result (result , action .name )
287
- state_update = _apply_defaults (state_update , action .default_writes )
318
+ state_update = _apply_defaults (
319
+ state_update , action .default_writes , state_update .flush_op_list ()
320
+ )
288
321
_validate_reducer_writes (action , state_update , action .name )
289
322
yield result , state_update
290
323
@@ -293,13 +326,14 @@ async def _arun_single_step_action(
293
326
action : SingleStepAction , state : State , inputs : Optional [Dict [str , Any ]]
294
327
) -> Tuple [dict , State ]:
295
328
"""Runs a single step action in async. See the synchronous version for more details."""
329
+ state .flush_op_list ()
296
330
state_to_use = state
297
331
state_to_use = _apply_defaults (state_to_use , action .default_reads )
298
332
action .validate_inputs (inputs )
299
333
result , new_state = _adjust_single_step_output (
300
334
await action .run_and_update (state_to_use , ** inputs ), action .name
301
335
)
302
- new_state = _apply_defaults (new_state , action .default_writes )
336
+ new_state = _apply_defaults (new_state , action .default_writes , state . flush_op_list () )
303
337
_validate_result (result , action .name )
304
338
_validate_reducer_writes (action , new_state , action .name )
305
339
return result , _state_update (state , new_state )
@@ -309,6 +343,7 @@ async def _arun_single_step_streaming_action(
309
343
action : SingleStepStreamingAction , state : State , inputs : Optional [Dict [str , Any ]]
310
344
) -> AsyncGenerator [Tuple [dict , Optional [State ]], None ]:
311
345
"""Runs a single step streaming action in async. See the synchronous version for more details."""
346
+ state .flush_op_list ()
312
347
action .validate_inputs (inputs )
313
348
state = _apply_defaults (state , action .default_reads )
314
349
generator = action .stream_run_and_update (state , ** inputs )
@@ -331,7 +366,7 @@ async def _arun_single_step_streaming_action(
331
366
f"statement must be a tuple of (result, state_update). For example, yield dict(foo='bar'), state.update(foo='bar')"
332
367
)
333
368
_validate_result (result , action .name )
334
- state_update = _apply_defaults (state_update , action .default_writes )
369
+ state_update = _apply_defaults (state_update , action .default_writes , state . flush_op_list () )
335
370
_validate_reducer_writes (action , state_update , action .name )
336
371
# TODO -- add guard against zero-length stream
337
372
yield result , state_update
@@ -347,6 +382,7 @@ def _run_multi_step_streaming_action(
347
382
348
383
This peeks ahead by one so we know when this is done (and when to validate).
349
384
"""
385
+ state .flush_op_list ()
350
386
action .validate_inputs (inputs )
351
387
state = _apply_defaults (state , action .default_reads )
352
388
generator = action .stream_run (state , ** inputs )
@@ -361,7 +397,7 @@ def _run_multi_step_streaming_action(
361
397
yield next_result , None
362
398
_validate_result (result , action .name )
363
399
state_update = _run_reducer (action , state , result , action .name )
364
- state_update = _apply_defaults (state_update , action .default_writes )
400
+ state_update = _apply_defaults (state_update , action .default_writes , state . flush_op_list () )
365
401
_validate_reducer_writes (action , state_update , action .name )
366
402
yield result , state_update
367
403
@@ -370,6 +406,7 @@ async def _arun_multi_step_streaming_action(
370
406
action : AsyncStreamingAction , state : State , inputs : Optional [Dict [str , Any ]]
371
407
) -> AsyncGenerator [Tuple [dict , Optional [State ]], None ]:
372
408
"""Runs a multi-step streaming action in async. See the synchronous version for more details."""
409
+ state .flush_op_list ()
373
410
action .validate_inputs (inputs )
374
411
state = _apply_defaults (state , action .default_reads )
375
412
generator = action .stream_run (state , ** inputs )
@@ -384,7 +421,7 @@ async def _arun_multi_step_streaming_action(
384
421
yield next_result , None
385
422
_validate_result (result , action .name )
386
423
state_update = _run_reducer (action , state , result , action .name )
387
- state_update = _apply_defaults (state_update , action .default_writes )
424
+ state_update = _apply_defaults (state_update , action .default_writes , state . flush_op_list () )
388
425
_validate_reducer_writes (action , state_update , action .name )
389
426
yield result , state_update
390
427
@@ -537,6 +574,7 @@ def _step(
537
574
) -> Optional [Tuple [Action , dict , State ]]:
538
575
"""Internal-facing version of step. This is the same as step, but with an additional
539
576
parameter to hide hook execution so async can leverage it."""
577
+ self ._state .flush_op_list () # Just to be sure, this is internal but we don't want to carry too many around
540
578
with self .context :
541
579
next_action = self .get_next_action ()
542
580
if next_action is None :
@@ -668,6 +706,7 @@ async def astep(self, inputs: Dict[str, Any] = None) -> Optional[Tuple[Action, d
668
706
669
707
async def _astep (self , inputs : Optional [Dict [str , Any ]], _run_hooks : bool = True ):
670
708
# we want to increment regardless of failure
709
+ self .state .flush_op_list ()
671
710
with self .context :
672
711
next_action = self .get_next_action ()
673
712
if next_action is None :
0 commit comments