Skip to content

Commit ba5b4bd

Browse files
committed
Adds transaction logging to state to ensure we only remove the items
functions write Beforehand we had trouble with state manipulation. If we wanted to do a default write for a single-step action, it would not have any way to know what that wrote, versus what was in the state. This adds a simple state log that has a single "flush" operation -- it just keeps track of all operations since the last "flush" call, and returns those. This way, all we have to do is flush before the operation, flush after, and use the "after" results to filter writes so we know which default to apply. This also cleans up a bit of the immutability guarentees -- we were doing a deepcopy on every state update, which has the potential to slow applications down.
1 parent 7a3e145 commit ba5b4bd

File tree

4 files changed

+207
-22
lines changed

4 files changed

+207
-22
lines changed

burr/core/application.py

+49-10
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
)
3939
from burr.core.graph import Graph, GraphBuilder
4040
from burr.core.persistence import BaseStateLoader, BaseStateSaver
41-
from burr.core.state import State
41+
from burr.core.state import State, StateDelta
4242
from burr.core.validation import BASE_ERROR_MESSAGE
4343
from burr.lifecycle.base import LifecycleAdapter
4444
from burr.lifecycle.internal import LifecycleAdapterSet
@@ -83,15 +83,44 @@ def _adjust_single_step_output(output: Union[State, Tuple[dict, State]], action_
8383
_raise_fn_return_validation_error(output, action_name)
8484

8585

86-
def _apply_defaults(state: State, defaults: Dict[str, Any]) -> State:
86+
def _apply_defaults(
87+
state: State,
88+
defaults: Dict[str, Any],
89+
op_list_to_restrict_writes: Optional[List[StateDelta]] = None,
90+
) -> State:
91+
"""Applies default values to the state. This is useful for the cases in which one applies a default value.
92+
93+
:param state: The state object to apply to.
94+
:param defaults: Default values (key/value) to use
95+
:param op_list_to_restrict_writes: The list of operations to restrict writes to, optional.
96+
If this is specified, then it will only apply the defaults to the keys that were written by ops in the op list.
97+
This allows us to track what it has written, and use that to apply defaults.
98+
:return: The state object with the defaults applied.
99+
"""
87100
state_update = {}
88101
state_to_use = state
102+
op_list_writes = None
103+
# In this case we want to restrict to the written sets
104+
if op_list_to_restrict_writes is not None:
105+
op_list_writes = set()
106+
for op in op_list_to_restrict_writes:
107+
op_list_writes.update(op.writes())
108+
89109
# We really don't need to short-circuit but I want to avoid the update function
90110
# So we might as well
91111
if len(defaults) > 0:
92112
for key, value in defaults.items():
93-
if key not in state:
94-
state_update[key] = value
113+
# if we're tracking the op list
114+
# Then we only want to apply deafults
115+
# to keys that have *not* been written to
116+
# This is more restrictive than the next condition
117+
if op_list_writes is not None:
118+
if key not in op_list_writes:
119+
state_update[key] = value
120+
# Otherwise we just apply the defaults to the state itself
121+
else:
122+
if key not in state:
123+
state_update[key] = value
95124
state_to_use = state.update(**state_update)
96125
return state_to_use
97126

@@ -244,12 +273,13 @@ def _run_single_step_action(
244273
:return: The result of running the action, and the new state
245274
"""
246275
# TODO -- guard all reads/writes with a subset of the state
276+
state.flush_op_list()
247277
action.validate_inputs(inputs)
248278
state = _apply_defaults(state, action.default_reads)
249279
result, new_state = _adjust_single_step_output(
250280
action.run_and_update(state, **inputs), action.name
251281
)
252-
new_state = _apply_defaults(new_state, action.default_writes)
282+
new_state = _apply_defaults(new_state, action.default_writes, state.flush_op_list())
253283
_validate_result(result, action.name)
254284
out = result, _state_update(state, new_state)
255285
_validate_result(result, action.name)
@@ -262,6 +292,7 @@ def _run_single_step_streaming_action(
262292
) -> Generator[Tuple[dict, Optional[State]], None, None]:
263293
"""Runs a single step streaming action. This API is internal-facing.
264294
This normalizes + validates the output."""
295+
state.flush_op_list()
265296
action.validate_inputs(inputs)
266297
state = _apply_defaults(state, action.default_reads)
267298
generator = action.stream_run_and_update(state, **inputs)
@@ -284,7 +315,9 @@ def _run_single_step_streaming_action(
284315
f"statement must be a tuple of (result, state_update). For example, yield dict(foo='bar'), state.update(foo='bar')"
285316
)
286317
_validate_result(result, action.name)
287-
state_update = _apply_defaults(state_update, action.default_writes)
318+
state_update = _apply_defaults(
319+
state_update, action.default_writes, state_update.flush_op_list()
320+
)
288321
_validate_reducer_writes(action, state_update, action.name)
289322
yield result, state_update
290323

@@ -293,13 +326,14 @@ async def _arun_single_step_action(
293326
action: SingleStepAction, state: State, inputs: Optional[Dict[str, Any]]
294327
) -> Tuple[dict, State]:
295328
"""Runs a single step action in async. See the synchronous version for more details."""
329+
state.flush_op_list()
296330
state_to_use = state
297331
state_to_use = _apply_defaults(state_to_use, action.default_reads)
298332
action.validate_inputs(inputs)
299333
result, new_state = _adjust_single_step_output(
300334
await action.run_and_update(state_to_use, **inputs), action.name
301335
)
302-
new_state = _apply_defaults(new_state, action.default_writes)
336+
new_state = _apply_defaults(new_state, action.default_writes, state.flush_op_list())
303337
_validate_result(result, action.name)
304338
_validate_reducer_writes(action, new_state, action.name)
305339
return result, _state_update(state, new_state)
@@ -309,6 +343,7 @@ async def _arun_single_step_streaming_action(
309343
action: SingleStepStreamingAction, state: State, inputs: Optional[Dict[str, Any]]
310344
) -> AsyncGenerator[Tuple[dict, Optional[State]], None]:
311345
"""Runs a single step streaming action in async. See the synchronous version for more details."""
346+
state.flush_op_list()
312347
action.validate_inputs(inputs)
313348
state = _apply_defaults(state, action.default_reads)
314349
generator = action.stream_run_and_update(state, **inputs)
@@ -331,7 +366,7 @@ async def _arun_single_step_streaming_action(
331366
f"statement must be a tuple of (result, state_update). For example, yield dict(foo='bar'), state.update(foo='bar')"
332367
)
333368
_validate_result(result, action.name)
334-
state_update = _apply_defaults(state_update, action.default_writes)
369+
state_update = _apply_defaults(state_update, action.default_writes, state.flush_op_list())
335370
_validate_reducer_writes(action, state_update, action.name)
336371
# TODO -- add guard against zero-length stream
337372
yield result, state_update
@@ -347,6 +382,7 @@ def _run_multi_step_streaming_action(
347382
348383
This peeks ahead by one so we know when this is done (and when to validate).
349384
"""
385+
state.flush_op_list()
350386
action.validate_inputs(inputs)
351387
state = _apply_defaults(state, action.default_reads)
352388
generator = action.stream_run(state, **inputs)
@@ -361,7 +397,7 @@ def _run_multi_step_streaming_action(
361397
yield next_result, None
362398
_validate_result(result, action.name)
363399
state_update = _run_reducer(action, state, result, action.name)
364-
state_update = _apply_defaults(state_update, action.default_writes)
400+
state_update = _apply_defaults(state_update, action.default_writes, state.flush_op_list())
365401
_validate_reducer_writes(action, state_update, action.name)
366402
yield result, state_update
367403

@@ -370,6 +406,7 @@ async def _arun_multi_step_streaming_action(
370406
action: AsyncStreamingAction, state: State, inputs: Optional[Dict[str, Any]]
371407
) -> AsyncGenerator[Tuple[dict, Optional[State]], None]:
372408
"""Runs a multi-step streaming action in async. See the synchronous version for more details."""
409+
state.flush_op_list()
373410
action.validate_inputs(inputs)
374411
state = _apply_defaults(state, action.default_reads)
375412
generator = action.stream_run(state, **inputs)
@@ -384,7 +421,7 @@ async def _arun_multi_step_streaming_action(
384421
yield next_result, None
385422
_validate_result(result, action.name)
386423
state_update = _run_reducer(action, state, result, action.name)
387-
state_update = _apply_defaults(state_update, action.default_writes)
424+
state_update = _apply_defaults(state_update, action.default_writes, state.flush_op_list())
388425
_validate_reducer_writes(action, state_update, action.name)
389426
yield result, state_update
390427

@@ -537,6 +574,7 @@ def _step(
537574
) -> Optional[Tuple[Action, dict, State]]:
538575
"""Internal-facing version of step. This is the same as step, but with an additional
539576
parameter to hide hook execution so async can leverage it."""
577+
self._state.flush_op_list() # Just to be sure, this is internal but we don't want to carry too many around
540578
with self.context:
541579
next_action = self.get_next_action()
542580
if next_action is None:
@@ -668,6 +706,7 @@ async def astep(self, inputs: Dict[str, Any] = None) -> Optional[Tuple[Action, d
668706

669707
async def _astep(self, inputs: Optional[Dict[str, Any]], _run_hooks: bool = True):
670708
# we want to increment regardless of failure
709+
self.state.flush_op_list()
671710
with self.context:
672711
next_action = self.get_next_action()
673712
if next_action is None:

burr/core/state.py

+50-8
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import importlib
55
import inspect
66
import logging
7-
from typing import Any, Callable, Dict, Iterator, Mapping, Union
7+
from typing import Any, Callable, Dict, Iterator, List, Mapping, Union
88

99
from burr.core import serde
1010

@@ -95,6 +95,11 @@ def writes(self) -> list[str]:
9595
"""Returns the keys that this state delta writes"""
9696
pass
9797

98+
@abc.abstractmethod
99+
def deletes(self) -> list[str]:
100+
"""Returns the keys that this state delta deletes"""
101+
pass
102+
98103
@abc.abstractmethod
99104
def apply_mutate(self, inputs: dict):
100105
"""Applies the state delta to the inputs"""
@@ -117,6 +122,9 @@ def reads(self) -> list[str]:
117122
def writes(self) -> list[str]:
118123
return list(self.values.keys())
119124

125+
def deletes(self) -> list[str]:
126+
return []
127+
120128
def apply_mutate(self, inputs: dict):
121129
inputs.update(self.values)
122130

@@ -137,13 +145,21 @@ def reads(self) -> list[str]:
137145
def writes(self) -> list[str]:
138146
return list(self.values.keys())
139147

148+
def deletes(self) -> list[str]:
149+
return []
150+
140151
def apply_mutate(self, inputs: dict):
141152
for key, value in self.values.items():
142153
if key not in inputs:
143154
inputs[key] = []
144155
if not isinstance(inputs[key], list):
145156
raise ValueError(f"Cannot append to non-list value {key}={inputs[self.key]}")
146-
inputs[key].append(value)
157+
inputs[key] = [
158+
*inputs[key],
159+
value,
160+
] # Not as efficient but safer, so we don't mutate the original list
161+
# we're doing this to avoid a copy.deepcopy() call, so it is already more efficient than it was before
162+
# That said, if one modifies prior values in the list, it is on them, and undefined behavior
147163

148164
def validate(self, input_state: Dict[str, Any]):
149165
incorrect_types = {}
@@ -171,6 +187,9 @@ def reads(self) -> list[str]:
171187
def writes(self) -> list[str]:
172188
return list(self.values.keys())
173189

190+
def deletes(self) -> list[str]:
191+
return []
192+
174193
def validate(self, input_state: Dict[str, Any]):
175194
incorrect_types = {}
176195
for write_key in self.writes():
@@ -201,11 +220,14 @@ def name(cls) -> str:
201220
return "delete"
202221

203222
def reads(self) -> list[str]:
204-
return list(self.keys)
223+
return []
205224

206225
def writes(self) -> list[str]:
207226
return []
208227

228+
def deletes(self) -> list[str]:
229+
return list(self.keys)
230+
209231
def apply_mutate(self, inputs: dict):
210232
for key in self.keys:
211233
inputs.pop(key, None)
@@ -214,19 +236,36 @@ def apply_mutate(self, inputs: dict):
214236
class State(Mapping):
215237
"""An immutable state object. This is the only way to interact with state in Burr."""
216238

217-
def __init__(self, initial_values: Dict[str, Any] = None):
239+
def __init__(self, initial_values: Dict[str, Any] = None, _op_list: list[StateDelta] = None):
218240
if initial_values is None:
219241
initial_values = dict()
220242
self._state = initial_values
243+
self._op_list = _op_list if _op_list is not None else []
244+
self._internal_sequence_id = 0
245+
246+
def flush_op_list(self) -> List[StateDelta]:
247+
"""Flushes the operation list, returning it and clearing it. This is an internal method,
248+
do not use, as it may change."""
249+
op_list = self._op_list
250+
self._op_list = []
251+
return op_list
252+
253+
@property
254+
def op_list(self) -> list[StateDelta]:
255+
"""The list of operations since this was last flushed.
256+
Also an internal property -- do not use, the implementation might change."""
257+
return self._op_list
221258

222259
def apply_operation(self, operation: StateDelta) -> "State":
223260
"""Applies a given operation to the state, returning a new state"""
224-
new_state = copy.deepcopy(self._state) # TODO -- restrict to just the read keys
261+
new_state = copy.copy(self._state) # TODO -- restrict to just the read keys
225262
operation.validate(new_state)
226263
operation.apply_mutate(
227264
new_state
228265
) # todo -- validate that the write keys are the only different ones
229-
return State(new_state)
266+
self._op_list.append(operation)
267+
# we want to carry this on for now
268+
return State(new_state, _op_list=self._op_list)
230269

231270
def get_all(self) -> Dict[str, Any]:
232271
"""Returns the entire state, realize as a dictionary. This is a copy."""
@@ -327,11 +366,14 @@ def wipe(self, delete: list[str] = None, keep: list[str] = None):
327366
def merge(self, other: "State") -> "State":
328367
"""Merges two states together, overwriting the values in self
329368
with those in other."""
330-
return State({**self.get_all(), **other.get_all()})
369+
return State({**self.get_all(), **other.get_all()}, _op_list=self._op_list)
331370

332371
def subset(self, *keys: str, ignore_missing: bool = True) -> "State":
333372
"""Returns a subset of the state, with only the given keys"""
334-
return State({key: self[key] for key in keys if key in self or not ignore_missing})
373+
return State(
374+
{key: self[key] for key in keys if key in self or not ignore_missing},
375+
_op_list=self._op_list,
376+
)
335377

336378
def __getitem__(self, __k: str) -> Any:
337379
return self._state[__k]

docs/reference/state.rst

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
=================
1+
=====
22
State
3-
=================
3+
=====
44

55
Use the state API to manipulate the state of the application.
66

77
.. autoclass:: burr.core.state.State
88
:members:
9+
:exclude-members: op_list, flush_op_list
910

1011
.. automethod:: __init__
1112

0 commit comments

Comments
 (0)