11"""Light-weight state machine driving Azazel defensive posture changes."""
22from __future__ import annotations
33
4+ import time
5+ from collections import deque
46from dataclasses import dataclass , field
5- from typing import Callable , Dict , List , Optional
7+ from pathlib import Path
8+ from typing import Any , Callable , Deque , Dict , List , Optional
9+
10+ import yaml
11+
12+
13+ CONFIG_PATH = Path (__file__ ).resolve ().parents [1 ] / "configs" / "azazel.yaml"
614
715
816@dataclass (frozen = True )
@@ -33,18 +41,27 @@ class Transition:
3341
3442@dataclass
3543class StateMachine :
36- """Simple but testable state machine implementation ."""
44+ """Mode-aware state machine with YAML-backed presets ."""
3745
3846 initial_state : State
3947 transitions : List [Transition ] = field (default_factory = list )
48+ config_path : str | Path | None = None
49+ window_size : int = 5
50+ clock : Callable [[], float ] = field (default = time .monotonic , repr = False )
4051 current_state : State = field (init = False )
4152
4253 def __post_init__ (self ) -> None :
4354 self .current_state = self .initial_state
4455 self ._transition_map : Dict [str , List [Transition ]] = {}
4556 for transition in self .transitions :
4657 self .add_transition (transition )
58+ self ._config_cache : Dict [str , Any ] | None = None
59+ self ._score_window : Deque [int ] = deque (maxlen = max (self .window_size , 1 ))
60+ self ._unlock_until : Dict [str , float ] = {}
4761
62+ # ------------------------------------------------------------------
63+ # Transition helpers
64+ # ------------------------------------------------------------------
4865 def add_transition (self , transition : Transition ) -> None :
4966 """Register a new transition."""
5067
@@ -58,6 +75,7 @@ def dispatch(self, event: Event) -> State:
5875 if transition .condition (event ):
5976 previous = self .current_state
6077 self .current_state = transition .target
78+ self ._handle_transition (previous , self .current_state )
6179 if transition .action :
6280 transition .action (previous , self .current_state , event )
6381 return self .current_state
@@ -67,6 +85,8 @@ def reset(self) -> None:
6785 """Reset the state machine to its initial state."""
6886
6987 self .current_state = self .initial_state
88+ self ._score_window .clear ()
89+ self ._unlock_until .clear ()
7090
7191 def summary (self ) -> Dict [str , str ]:
7292 """Return a serializable summary of the state machine."""
@@ -75,3 +95,129 @@ def summary(self) -> Dict[str, str]:
7595 "state" : self .current_state .name ,
7696 "description" : self .current_state .description ,
7797 }
98+
99+ # ------------------------------------------------------------------
100+ # Configuration helpers
101+ # ------------------------------------------------------------------
102+ def _resolve_config_path (self ) -> Path :
103+ if self .config_path is not None :
104+ return Path (self .config_path )
105+ return CONFIG_PATH
106+
107+ def _load_config (self ) -> Dict [str , Any ]:
108+ if self ._config_cache is None :
109+ path = self ._resolve_config_path ()
110+ data = yaml .safe_load (path .read_text ())
111+ if not isinstance (data , dict ):
112+ raise ValueError ("Configuration root must be a mapping" )
113+ self ._config_cache = data
114+ return self ._config_cache
115+
116+ def reload_config (self ) -> None :
117+ """Force re-reading of the YAML configuration."""
118+
119+ self ._config_cache = None
120+
121+ def get_thresholds (self ) -> Dict [str , Any ]:
122+ """Return shield/lockdown thresholds and unlock windows."""
123+
124+ config = self ._load_config ()
125+ thresholds = config .get ("thresholds" , {})
126+ unlock = thresholds .get ("unlock_wait_secs" , {})
127+ return {
128+ "t1" : int (thresholds .get ("t1_shield" , 0 ) or 0 ),
129+ "t2" : int (thresholds .get ("t2_lockdown" , 0 ) or 0 ),
130+ "unlock_wait_secs" : {
131+ "shield" : int (unlock .get ("shield" , 0 ) or 0 ),
132+ "portal" : int (unlock .get ("portal" , 0 ) or 0 ),
133+ },
134+ }
135+
136+ def get_actions_preset (self ) -> Dict [str , Any ]:
137+ """Return the action plan preset for the current mode."""
138+
139+ config = self ._load_config ()
140+ actions = config .get ("actions" , {})
141+ preset = actions .get (self .current_state .name , {})
142+ shape = preset .get ("shape_kbps" )
143+ return {
144+ "delay_ms" : int (preset .get ("delay_ms" , 0 ) or 0 ),
145+ "shape_kbps" : int (shape ) if shape not in (None , "" , False ) else None ,
146+ "block" : bool (preset .get ("block" , False )),
147+ }
148+
149+ # ------------------------------------------------------------------
150+ # Score window evaluation
151+ # ------------------------------------------------------------------
152+ def evaluate_window (self , severity : int ) -> Dict [str , Any ]:
153+ """Append a severity score and compute moving average decisions."""
154+
155+ self ._score_window .append (max (int (severity ), 0 ))
156+ average = sum (self ._score_window ) / len (self ._score_window )
157+ thresholds = self .get_thresholds ()
158+ desired_mode = "portal"
159+ if average >= thresholds ["t2" ]:
160+ desired_mode = "lockdown"
161+ elif average >= thresholds ["t1" ]:
162+ desired_mode = "shield"
163+ return {"average" : average , "desired_mode" : desired_mode }
164+
165+ def apply_score (self , severity : int ) -> Dict [str , Any ]:
166+ """Evaluate the score window and transition to the appropriate mode."""
167+
168+ evaluation = self .evaluate_window (severity )
169+ desired_mode = evaluation ["desired_mode" ]
170+ now = self .clock ()
171+ target_mode = desired_mode
172+ if desired_mode == "portal" :
173+ target_mode = self ._target_for_portal (now )
174+ elif desired_mode == "shield" :
175+ target_mode = self ._target_for_shield (now )
176+
177+ if target_mode != self .current_state .name :
178+ self .dispatch (Event (name = target_mode , severity = severity ))
179+
180+ evaluation .update ({
181+ "target_mode" : target_mode ,
182+ "applied_mode" : self .current_state .name ,
183+ })
184+ return evaluation
185+
186+ # ------------------------------------------------------------------
187+ # Internal helpers
188+ # ------------------------------------------------------------------
189+ def _handle_transition (self , previous : State , current : State ) -> None :
190+ thresholds = self .get_thresholds ()
191+ unlocks = thresholds .get ("unlock_wait_secs" , {})
192+ now = self .clock ()
193+ if current .name == "lockdown" :
194+ wait_shield = unlocks .get ("shield" , 0 )
195+ if wait_shield :
196+ self ._unlock_until ["shield" ] = now + wait_shield
197+ elif current .name == "shield" :
198+ wait_portal = unlocks .get ("portal" , 0 )
199+ if wait_portal :
200+ self ._unlock_until ["portal" ] = now + wait_portal
201+ self ._unlock_until .pop ("shield" , None )
202+ elif current .name == "portal" :
203+ self ._unlock_until .clear ()
204+
205+ def _target_for_shield (self , now : float ) -> str :
206+ if self .current_state .name == "lockdown" :
207+ unlock_at = self ._unlock_until .get ("shield" , 0.0 )
208+ if now < unlock_at :
209+ return "lockdown"
210+ return "shield"
211+
212+ def _target_for_portal (self , now : float ) -> str :
213+ if self .current_state .name == "lockdown" :
214+ unlock_at = self ._unlock_until .get ("shield" , 0.0 )
215+ if now < unlock_at :
216+ return "lockdown"
217+ # Step-down path: lockdown -> shield before portal.
218+ return "shield"
219+ if self .current_state .name == "shield" :
220+ unlock_at = self ._unlock_until .get ("portal" , 0.0 )
221+ if now < unlock_at :
222+ return "shield"
223+ return "portal"
0 commit comments