diff --git a/bluesky_adaptive/adjudicators/README-ADJUDICATORS.md b/bluesky_adaptive/adjudicators/README-ADJUDICATORS.md new file mode 100644 index 0000000..336f99d --- /dev/null +++ b/bluesky_adaptive/adjudicators/README-ADJUDICATORS.md @@ -0,0 +1,42 @@ +# Adjudicators + +The purpose of an adjudicator is to provide another layer of indirection between the agents and the RunEngine Manager. +This is not required, as agents can send plans directly to the queue. +Alternatively, many agents can send plans to an adjudicator that acts as a meta-agent, filtering and deciding which plans from many agents make it to the queue. +In this way, the adjudicator acts as an extra experiment manager. +Feedback is not provided directly to the agents (i.e. no two way communication), so this is in effect, much like how high level human management communicates with low level employees. + +Each adjudicator is required to implement `make_judgments`, which accepts no args or kwargs, and should return a list of tuples that contain the RE manager API, the agent name, and the Suggestion. +These tuples will be validated by Pydantic models, or can be `Judgment` objects. +This enables an agent to suggest many plans at once, to multiple beamlines! +Adjustable properties can be incorporated by the server, allowing for web and caproto control. + +`make_judgments` can be called promptly after every new document, or only on user command. + + +## Use Case: Avoiding redundancy +One challenge of having many agents who can write to the queue is they don't know what other agents are suggesting. This can cause multiple agents to have the same idea about the next experiment, and lead an autonomous experiment to run the same plans redundantly. For example, if I had two Bayesian optimization agents that were minimizing their surrogate model uncertainty, they may have a similar idea for the next best area to measure. 
+An adjudicator can ensure that only one measurement gets scheduled, but both agents will still receive the data. + +## Use Case: Meta-analysis of many similar agents +You may want to filter down the number of plans coming from multiple agents that are using the same underlying technique. +This mechanism for increasing diversity could be applied to a suite of exploitative optimizers, or maybe complementary decomposition approaches (NMF/PCA/Kmeans) that are suggesting regions near their primary components. +An adjudicator that is conducting analysis of many agents will take careful thought and should be tuned to the set of agents it is attending to. + +## Pydantic Message API Enables multi-experiment, multi-beamline suggestions +```python +suggestion = Suggestion(ask_uid="123", plan_name="test_plan", plan_args=[1, 3], plan_kwargs={"md": {}}) +msg = AdjudicatorMsg( + agent_name="aardvark", + suggestions_uid="456", + suggestions={ + "pdf": [ + suggestion, + suggestion, + ], + "bmm": [ + suggestion, + ], + }, +) +``` diff --git a/bluesky_adaptive/adjudicators/__init__.py b/bluesky_adaptive/adjudicators/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/bluesky_adaptive/adjudicators/base.py b/bluesky_adaptive/adjudicators/base.py new file mode 100644 index 0000000..33f5d96 --- /dev/null +++ b/bluesky_adaptive/adjudicators/base.py @@ -0,0 +1,251 @@ +import logging +from abc import ABC, abstractmethod +from collections import deque +from copy import deepcopy +from threading import Lock, Thread +from typing import Callable, Sequence, Tuple + +from bluesky_kafka import BlueskyConsumer +from bluesky_queueserver_api import BPlan +from bluesky_queueserver_api.api_threads import API_Threads_Mixin + +from bluesky_adaptive.adjudicators.msg import DEFAULT_NAME, AdjudicatorMsg, Judgment, Suggestion +from bluesky_adaptive.agents.base import Agent as BaseAgent + +logger = logging.getLogger(__name__) + + +class DequeSet: + def __init__(self, maxlen=100): + self._set 
= set() + self._dequeue = deque() + self._maxlen = maxlen + + def __contains__(self, d): + return d in self._set + + def append(self, d): + if d in self: + logger.debug(f"Attempt to add redundant point to DequeSet ignored: {d}") + return + self._set.add(d) + self._dequeue.append(d) + while len(self._dequeue) >= self._maxlen: + discarded = self._dequeue.popleft() + self._set.remove(discarded) + + +class AdjudicatorBase(BlueskyConsumer, ABC): + """ + An agent adjudicator that listens to published suggestions by agents. + This Base approach (as per `process_document`) only retains the most recent suggestions by any named agents. + Other mechanisms for tracking can be provided as in example sub-classes. + + Parameters + ---------- + topics : list of str + List of existing_topics as strings such as ["topic-1", "topic-2"] + bootstrap_servers : str + Comma-delimited list of Kafka server addresses as a string + such as ``'broker1:9092,broker2:9092,127.0.0.1:9092'`` + group_id : str + Required string identifier for the consumer's Kafka Consumer group. + """ + + _register_method = BaseAgent._register_method + _register_property = BaseAgent._register_property + + def __init__(self, topics: list[str], bootstrap_servers: str, group_id: str, *args, **kwargs): + super().__init__(topics, bootstrap_servers, group_id, *args, **kwargs) + self._lock = Lock() + self._thread = None + self._current_suggestions = {} # agent_name: AdjudicatorMsg + self._ask_uids = DequeSet() + self._prompt = True + + try: + self.server_registrations() + except RuntimeError as e: + logger.warning(f"Agent server unable to make registrations. 
Continuing regardless of\n {e}") + + def start(self, *args, **kwargs): + self._thread = Thread( + target=BlueskyConsumer.start, + name="adjudicator-loop", + daemon=True, + args=[self] + list(args), + kwargs=kwargs, + ) + self._thread.start() + + def process_document(self, topic, name, doc): + if name != DEFAULT_NAME: + return True + with self._lock: + logger.info(f"{doc['agent_name']=}, {doc['suggestions_uid']=}") + self._current_suggestions[doc["agent_name"]] = AdjudicatorMsg(**doc) + + if self.prompt_judgment: + self._make_judgments_and_add_to_queue() + + @property + def current_suggestions(self): + """Dictionary of {agent_name:AdjudicatorMsg}, deep copied at each grasp.""" + with self._lock: + ret = deepcopy(self._current_suggestions) + return ret + + @property + def agent_names(self): + with self._lock: + ret = list(self._current_suggestions.keys()) + return ret + + @property + def prompt_judgment(self) -> bool: + return self._prompt + + @prompt_judgment.setter + def prompt_judgment(self, flag: bool): + self._prompt = flag + + def _add_suggestion_to_queue(self, re_manager: API_Threads_Mixin, agent_name: str, suggestion: Suggestion): + if suggestion.ask_uid in self._ask_uids: + logger.debug(f"Ask uid {suggestion.ask_uid} has already been seen. Not adding anything to the queue.") + return + else: + self._ask_uids.append(suggestion.ask_uid) + kwargs = suggestion.plan_kwargs + kwargs.setdefault("md", {}) + kwargs["md"]["agent_ask_uid"] = suggestion.ask_uid + kwargs["md"]["agent_name"] = agent_name + plan = BPlan(suggestion.plan_name, *suggestion.plan_args, **kwargs) + r = re_manager.item_add(plan, pos="back") + logger.debug(f"Sent http-server request by adjudicator\n." f"Received reponse: {r}") + + def server_registrations(self) -> None: + """ + Method to generate all server registrations during agent initialization. + This method can be used in subclasses, to override or extend the default registrations. 
+ """ + self._register_method("make_judgements", "_make_judgments_and_add_to_queue") + self._register_property("prompt_judgment") + self._register_property("current_suggestions") + + def _make_judgments_and_add_to_queue(self): + """Internal wrapper for making judgements, validating, and adding to queue.""" + judgments = self.make_judgments() + for judgment in judgments: + if not isinstance(judgment, Judgment): + judgment = Judgment(*judgment) # Validate + self._add_suggestion_to_queue(judgment.re_manager, judgment.agent_name, judgment.suggestion) + + @abstractmethod + def make_judgments(self) -> Sequence[Tuple[API_Threads_Mixin, str, Suggestion]]: + """Instance method to make judgements based on current suggestions. + The returned tuples will be deconstructed to add suggestions to the queue. + """ + ... + + +class AgentByNameAdjudicator(AdjudicatorBase): + """Adjudicator that only allows messages from a set primary agent, and uses a single qserver. + Parameters + ---------- + qservers : dict[str, API_Threads_Mixin] + Dictionary of objects to manage communication with Queue Server. These should be keyed by the beamline TLA + expected in AdjudicatorMsg.suggestions dictionary. 
+ """ + + def __init__(self, *args, qservers: dict[str, API_Threads_Mixin], **kwargs): + self._primary_agent = "" + self._re_managers = qservers + super().__init__(*args, **kwargs) + + @property + def primary_agent(self): + return self._primary_agent + + @primary_agent.setter + def primary_agent(self, name: str): + self._primary_agent = name + + def server_registrations(self) -> None: + self._register_property("primary_agent") + super().server_registrations() + + def make_judgments(self) -> Sequence[Tuple[API_Threads_Mixin, str, Suggestion]]: + judgments = [] + + if self.primary_agent not in self.agent_names: + logger.debug(f"Agent {self.primary_agent} not known to the Adjudicator") + else: + adjudicator_msg = self.current_suggestions[self.primary_agent] + for key, manager in self._re_managers.items(): + suggestions = adjudicator_msg.suggestions.get(key, []) + for suggestion in suggestions: + judgments.append( + Judgment(re_manager=manager, agent_name=self.primary_agent, suggestion=suggestion) + ) + return judgments + + +class NonredundantAdjudicator(AdjudicatorBase): + """Use a hashing function to convert any suggestion into a unique hash. + + Parameters + ---------- + topics : list of str + List of existing_topics as strings such as ["topic-1", "topic-2"] + bootstrap_servers : str + Comma-delimited list of Kafka server addresses as a string + such as ``'broker1:9092,broker2:9092,127.0.0.1:9092'`` + group_id : str + Required string identifier for the consumer's Kafka Consumer group. + qservers : dict[str, API_Threads_Mixin] + Dictionary of objects to manage communication with Queue Server. These should be keyed by the beamline TLA + expected in AdjudicatorMsg.suggestions dictionary. + hash_suggestion : Callable + Function that takes the tla and Suggestion object, and returns a hashable object as :: + + def hash_suggestion(tla: str, suggestion: Suggestion) -> Hashable: ... + + + This hashable object will be used to check redundancy in a set. 
+ + Examples + -------- + >>> def hash_suggestion(tla: str, suggestion: Suggestion): + >>> # Uses only the tla, plan name, and args to define redundancy, avoiding any details in kwargs + >>> return f"{tla} {suggestion.plan_name} {str(suggestion.plan_args)}" + """ + + def __init__( + self, + topics: list[str], + bootstrap_servers: str, + group_id: str, + *args, + qservers: dict[str, API_Threads_Mixin], + hash_suggestion: Callable, + **kwargs, + ): + super().__init__(topics, bootstrap_servers, group_id, *args, **kwargs) + self.hash_suggestion = hash_suggestion + self.suggestion_set = set() + self._re_managers = qservers + + def make_judgments(self) -> Sequence[Tuple[API_Threads_Mixin, str, Suggestion]]: + """Loop over all received adjudicator messages, and their suggested plans by beamline, + seeking redundancy.""" + passing_judgements = [] + for agent_name, adjudicator_msg in self.current_suggestions.items(): + for tla, suggestions in adjudicator_msg.suggestions.items(): + for suggestion in suggestions: + hashable = self.hash_suggestion(tla, suggestion) + if hashable in self.suggestion_set: + continue + else: + passing_judgements.append(Judgment(self._re_managers[tla], agent_name, suggestion)) + self.suggestion_set.add(hashable) + return passing_judgements diff --git a/bluesky_adaptive/adjudicators/msg.py b/bluesky_adaptive/adjudicators/msg.py new file mode 100644 index 0000000..536f3ef --- /dev/null +++ b/bluesky_adaptive/adjudicators/msg.py @@ -0,0 +1,57 @@ +from typing import AnyStr, Dict, List + +from bluesky_queueserver_api.api_threads import API_Threads_Mixin +from pydantic import BaseModel + +DEFAULT_NAME = "agent_suggestions" + + +class Suggestion(BaseModel): + ask_uid: str # UID from the agent ask message + plan_name: str + plan_args: list = [] + plan_kwargs: dict = {} + + +class AdjudicatorMsg(BaseModel): + agent_name: str + suggestions_uid: str + suggestions: Dict[AnyStr, List[Suggestion]] # TLA: list + + +class Judgment(BaseModel): + """Allow for 
positional arguments from user derived make judgements""" + + re_manager: API_Threads_Mixin + agent_name: str + suggestion: Suggestion + + class Config: + arbitrary_types_allowed = True + + def __init__(self, re_manager: API_Threads_Mixin, agent_name: str, suggestion: Suggestion, **kwargs) -> None: + super().__init__(re_manager=re_manager, agent_name=agent_name, suggestion=suggestion, **kwargs) + + +if __name__ == "__main__": + """Example main to show serializing capabilities""" + import msgpack + + suggestion = Suggestion(ask_uid="123", plan_name="test_plan", plan_args=[1, 3], plan_kwargs={"md": {}}) + msg = AdjudicatorMsg( + agent_name="aardvark", + suggestions_uid="456", + suggestions={ + "pdf": [ + suggestion, + suggestion, + ], + "bmm": [ + suggestion, + ], + }, + ) + print(msg) + s = msgpack.dumps(msg.dict()) + new_msg = AdjudicatorMsg(**msgpack.loads(s)) + print(new_msg) diff --git a/bluesky_adaptive/agents/base.py b/bluesky_adaptive/agents/base.py index 401d14b..98983d6 100644 --- a/bluesky_adaptive/agents/base.py +++ b/bluesky_adaptive/agents/base.py @@ -20,7 +20,9 @@ from numpy.typing import ArrayLike from xkcdpass import xkcd_password as xp -from bluesky_adaptive.server import register_variable, start_task +from ..adjudicators.msg import DEFAULT_NAME as ADJUDICATOR_STREAM_NAME +from ..adjudicators.msg import AdjudicatorMsg, Suggestion +from ..server import register_variable, start_task logger = getLogger("bluesky_adaptive.agents") PASSWORD_LIST = xp.generate_wordlist(wordfile=xp.locate_wordfile(), min_length=3, max_length=6) @@ -202,8 +204,8 @@ class Agent(ABC): Bluesky stop documents that will trigger ``tell``. AgentConsumer is a child class of bluesky_kafka.RemoteDispatcher that enables kafka messages to trigger agent directives. - kafka_producer : Publisher - Bluesky Kafka publisher to produce document stream of agent actions. 
+ kafka_producer : Optional[Publisher] + Bluesky Kafka publisher to produce document stream of agent actions for optional Adjudicator. tiled_data_node : tiled.client.node.Node Tiled node to serve as source of data (BlueskyRuns) for the agent. tiled_agent_node : tiled.client.node.Node @@ -234,16 +236,20 @@ class Agent(ABC): Default kwargs for calling the ``report`` method, by default None queue_add_position : Optional[Union[int, Literal["front", "back"]]], optional Starting postion to add to the queue if adding directly to the queue, by default "back". + endstation_key : Optional[str] + Optional string that is needed for Adjudicator functionality. This keys the qserver API instance to + a particular endstation. This way child Agents can maintain multiple queues for different unit operations. + For example, this could be a beamline three letter acronym or other distinct key. """ def __init__( self, *, kafka_consumer: AgentConsumer, - kafka_producer: Publisher, tiled_data_node: tiled.client.node.Node, tiled_agent_node: tiled.client.node.Node, qserver: API_Threads_Mixin, + kafka_producer: Optional[Publisher], agent_run_suffix: Optional[str] = None, metadata: Optional[dict] = None, ask_on_tell: Optional[bool] = True, @@ -251,6 +257,7 @@ def __init__( report_on_tell: Optional[bool] = False, default_report_kwargs: Optional[dict] = None, queue_add_position: Optional[Union[int, Literal["front", "back"]]] = None, + endstation_key: Optional[str] = "", ): logger.debug("Initializing agent.") self.kafka_consumer = kafka_consumer @@ -281,6 +288,7 @@ def __init__( self._compose_run_bundle = None self._compose_descriptor_bundles = dict() self.re_manager = qserver + self.endstation_key = endstation_key self._queue_add_position = "back" if queue_add_position is None else queue_add_position self._direct_to_queue = direct_to_queue self.default_plan_md = dict(agent_name=self.instance_name, agent_class=str(type(self))) @@ -565,57 +573,80 @@ def _check_queue_and_start(self): 
self.re_manager.queue_start() logger.info("Agent is starting an idle queue with exactly 1 item.") - def add_suggestions_to_queue(self, batch_size: int): - """Calls ask, adds suggestions to queue, and writes out events. - This will create one event for each suggestion. + def _ask_and_write_events( + self, batch_size: int, ask_method: Optional[Callable] = None, stream_name: Optional[str] = "ask" + ): + """Private ask method for consistency across calls and changes to docs streams. + + Parameters + ---------- + batch_size : int + Size of batch passed to ask + ask_method : Optional[Callable] + self.ask, or self.subject_ask, or some target ask function. + Defaults to self.ask + stream_name : Optional[str] + Name for ask stream corresponding to `ask_method`. 'ask', 'subject_ask', or other. + Defaults to 'ask' + + Returns + ------- + next_points : list + Next points to be sent to adjudicator or queue + uid : str """ - docs, next_points = self.ask(batch_size) + if ask_method is None: + ask_method = self.ask + docs, next_points = ask_method(batch_size) uid = str(uuid.uuid4()) for batch_idx, (doc, next_point) in enumerate(zip(docs, next_points)): doc["suggestion"] = next_point doc["batch_idx"] = batch_idx doc["batch_size"] = len(next_points) - self._write_event("ask", doc, uid=f"{uid}/{batch_idx}") + self._write_event(stream_name, doc, uid=f"{uid}/{batch_idx}") + return next_points, uid + + def add_suggestions_to_queue(self, batch_size: int): + """Calls ask, adds suggestions to queue, and writes out events. + This will create one event for each suggestion. + """ + next_points, uid = self._ask_and_write_events(batch_size) logger.info(f"Issued ask and adding to the queue. 
{uid}") self._add_to_queue(next_points, uid) self._check_queue_and_start() # TODO: remove this and encourage updated qserver functionality - def _create_suggestion_list(self, points, uid): + def _create_suggestion_list(self, points: Sequence, uid: str, measurement_plan: Optional[Callable] = None): """Create suggestions for adjudicator""" - raise NotImplementedError - """Not implementing yet to lighten PR load. Copied is implementation from MMM. suggestions = [] for point in points: - kwargs = self.measurement_plan_kwargs(point) + plan_name, args, kwargs = ( + self.measurement_plan(point) if measurement_plan is None else measurement_plan(point) + ) kwargs.setdefault("md", {}) kwargs["md"].update(self.default_plan_md) kwargs["md"]["agent_ask_uid"] = uid - args = self.measurement_plan_args(point) suggestions.append( Suggestion( ask_uid=uid, - plan_name=self.measurement_plan_name, + plan_name=plan_name, plan_args=args, plan_kwargs=kwargs, ) ) return suggestions - """ def generate_suggestions_for_adjudicator(self, batch_size: int): - raise NotImplementedError - """ Not implementing yet to lighten PR load. Copied is implementation from MMM. - doc, next_points = self.ask(batch_size) - uid = self._write_event("ask", doc) - logger.info(f"Issued ask and sending to adjudicator. {uid}") + """Calls ask, sends suggestions to adjudicator, and writes out events. + This will create one event for each suggestion.""" + next_points, uid = self._ask_and_write_events(batch_size) + logger.info(f"Issued ask and sending to the adjudicator. 
{uid}") suggestions = self._create_suggestion_list(next_points, uid) msg = AdjudicatorMsg( - agent_name=self.agent_name, + agent_name=self.instance_name, suggestions_uid=str(uuid.uuid4()), - suggestions={self.beamline_tla: suggestions}, + suggestions={self.endstation_key: suggestions}, ) self.kafka_producer(ADJUDICATOR_STREAM_NAME, msg.dict()) - """ def generate_report(self, **kwargs): doc = self.report(**kwargs) @@ -911,7 +942,14 @@ def from_config_kwargs( class MonarchSubjectAgent(Agent, ABC): # Drive a beamline. On stop doc check. By default manual trigger. - def __init__(self, *args, subject_qserver: API_Threads_Mixin, **kwargs): + def __init__( + self, + *args, + subject_qserver: API_Threads_Mixin, + subject_kafka_producer: Optional[Publisher] = None, + subject_endstation_key: Optional[str] = "", + **kwargs, + ): """Abstract base class for a MonarchSubject agent. These agents only consume documents from one (Monarch) source, and can dictate the behavior of a different (Subject) queue. This can be useful in a multimodal measurement where @@ -940,9 +978,17 @@ def __init__(self, *args, subject_qserver: API_Threads_Mixin, **kwargs): ---------- subject_qserver : API_Threads_Mixin Object to manage communication with the Subject Queue Server + subject_kafka_producer : Optional[Publisher] + Bluesky Kafka publisher to produce document stream of agent actions to Adjudicators + subject_endstation_key : Optional[str] + Optional string that is needed for Adjudicator functionality. This keys the qserver API instance to + a particular endstation. This way child Agents can maintain multiple queues for different unit ops. + For example, this could be a beamline three letter acronym or other distinct key. 
""" super().__init__(**kwargs) self.subject_re_manager = subject_qserver + self.subject_kafka_producer = subject_kafka_producer + self.subject_endstation_key = subject_endstation_key @abstractmethod def subject_measurement_plan(self, point: ArrayLike) -> Tuple[str, List, dict]: @@ -1000,13 +1046,7 @@ def subject_ask_condition(self): def add_suggestions_to_subject_queue(self, batch_size: int): """Calls ask, adds suggestions to queue, and writes out event""" - docs, next_points = self.subject_ask(batch_size) - uid = str(uuid.uuid4()) - for batch_idx, (doc, next_point) in enumerate(zip(docs, next_points)): - doc["suggestion"] = next_point - doc["batch_idx"] = batch_idx - doc["batch_size"] = len(next_points) - self._write_event("subject_ask", doc) + next_points, uid = self._ask_and_write_events(batch_size, self.subject_ask, "subject_ask") logger.info("Issued ask to subject and adding to the queue. {uid}") self._add_to_queue(next_points, uid, re_manager=self.subject_re_manager, position="front") @@ -1021,6 +1061,17 @@ def _on_stop_router(self, name, doc): else: raise NotImplementedError + def generate_suggestions_for_adjudicator(self, batch_size: int): + next_points, uid = self._ask_and_write_events(batch_size, self.subject_ask, "subject_ask") + logger.info(f"Issued subject ask and sending to the adjudicator. 
{uid}") + suggestions = self._create_suggestion_list(next_points, uid, self.subject_measurement_plan) + msg = AdjudicatorMsg( + agent_name=self.instance_name, + suggestions_uid=str(uuid.uuid4()), + suggestions={self.subject_endstation_key: suggestions}, + ) + self.subject_kafka_producer(ADJUDICATOR_STREAM_NAME, msg.dict()) + def server_registrations(self) -> None: super().server_registrations() self._register_method("add_suggestions_to_subject_queue") diff --git a/bluesky_adaptive/server/demo/adjudicator_sandbox.py b/bluesky_adaptive/server/demo/adjudicator_sandbox.py new file mode 100644 index 0000000..e2d5d68 --- /dev/null +++ b/bluesky_adaptive/server/demo/adjudicator_sandbox.py @@ -0,0 +1,72 @@ +# BS_AGENT_STARTUP_SCRIPT_PATH=./bluesky_adaptive/server/demo/adjudicator_sandbox.py \ +# uvicorn bluesky_adaptive.server:app +from bluesky_kafka.utils import create_topics, delete_topics +from bluesky_queueserver_api.http import REManagerAPI + +from bluesky_adaptive.adjudicators.base import NonredundantAdjudicator +from bluesky_adaptive.adjudicators.msg import Suggestion +from bluesky_adaptive.server import shutdown_decorator, startup_decorator + +broker_authorization_config = { + "acks": 1, + "enable.idempotence": False, + "request.timeout.ms": 1000, + "bootstrap.servers": "127.0.0.1:9092", +} +tiled_profile = "testing_sandbox" +kafka_bootstrap_servers = "127.0.0.1:9092" +bootstrap_servers = kafka_bootstrap_servers +admin_client_config = broker_authorization_config +topics = ["test.publisher", "test.subscriber"] +adj_topic, sub_topic = topics + + +re_manager = REManagerAPI(http_server_uri=None) +re_manager.set_authorization_key(api_key="SECRET") + + +def _hash_suggestion(tla, suggestion: Suggestion): + return f"{tla} {suggestion.plan_name} {str(suggestion.plan_args)}" + + +adjudicator = NonredundantAdjudicator( + topics=[adj_topic], + bootstrap_servers=kafka_bootstrap_servers, + group_id="test.communication.group", + qservers={"tst": re_manager}, + 
consumer_config={"auto.offset.reset": "earliest"}, + hash_suggestion=_hash_suggestion, +) + + +@startup_decorator +def startup_topics(): + delete_topics( + bootstrap_servers=bootstrap_servers, + topics_to_delete=topics, + admin_client_config=admin_client_config, + ) + create_topics( + bootstrap_servers=bootstrap_servers, + topics_to_create=topics, + admin_client_config=admin_client_config, + ) + + +@startup_decorator +def startup_adjudicator(): + adjudicator.start() + + +@shutdown_decorator +def shutdown_agent(): + return adjudicator.stop() + + +@shutdown_decorator +def shutdown_topics(): + delete_topics( + bootstrap_servers=bootstrap_servers, + topics_to_delete=topics, + admin_client_config=admin_client_config, + ) diff --git a/bluesky_adaptive/tests/test_adjudicators.py b/bluesky_adaptive/tests/test_adjudicators.py new file mode 100644 index 0000000..6a476dd --- /dev/null +++ b/bluesky_adaptive/tests/test_adjudicators.py @@ -0,0 +1,306 @@ +import time as ttime +from typing import Sequence, Tuple, Union + +from bluesky_kafka import BlueskyConsumer, Publisher +from bluesky_queueserver_api.http import REManagerAPI +from databroker.client import BlueskyRun +from event_model import compose_run +from numpy.typing import ArrayLike + +from bluesky_adaptive.adjudicators.base import AgentByNameAdjudicator, NonredundantAdjudicator +from bluesky_adaptive.adjudicators.msg import Suggestion +from bluesky_adaptive.agents.base import Agent, AgentConsumer + +KAFKA_TIMEOUT = 30.0 # seconds + + +class NoTiled: + class V1: + def insert(self, *args, **kwargs): + pass + + v1 = V1 + + +class TestAgent(Agent): + measurement_plan_name = "agent_driven_nap" + + def __init__(self, pub_topic, sub_topic, kafka_bootstrap_servers, broker_authorization_config, qs, **kwargs): + kafka_consumer = AgentConsumer( + topics=[sub_topic], + bootstrap_servers=kafka_bootstrap_servers, + group_id="test.communication.group", + consumer_config={"auto.offset.reset": "latest"}, + ) + kafka_producer = Publisher( 
+ topic=pub_topic, + bootstrap_servers=kafka_bootstrap_servers, + key="", + producer_config=broker_authorization_config, + ) + super().__init__( + kafka_consumer=kafka_consumer, + kafka_producer=kafka_producer, + tiled_agent_node=None, + tiled_data_node=None, + qserver=qs, + **kwargs, + ) + self.count = 0 + self.agent_catalog = NoTiled() + + def no_tiled(*args, **kwargs): + pass + + def measurement_plan(self, point: ArrayLike) -> Tuple[str, list, dict]: + return self.measurement_plan_name, [1.5], dict() + + def unpack_run(self, run: BlueskyRun) -> Tuple[Union[float, ArrayLike], Union[float, ArrayLike]]: + return 0, 0 + + def report(self, report_number: int = 0) -> dict: + return dict(agent_name=self.instance_name, report=f"report_{report_number}") + + def ask(self, batch_size: int = 1) -> Tuple[dict, Sequence]: + return ([dict(agent_name=self.instance_name, report=f"ask_{batch_size}")], [0 for _ in range(batch_size)]) + + def tell(self, x, y) -> dict: + self.count += 1 + return dict(x=x, y=y) + + def start(self): + """Start without kafka consumer start""" + self._compose_run_bundle = compose_run(metadata=self.metadata) + self.agent_catalog.v1.insert("start", self._compose_run_bundle.start_doc) + + +class AccumulateAdjudicator(AgentByNameAdjudicator): + def __init__(self, *args, qservers, **kwargs): + super().__init__(*args, qservers=qservers, **kwargs) + self.consumed_documents = [] + + def process_document(self, topic, name, doc): + self.consumed_documents.append((name, doc)) + return super().process_document(topic, name, doc) + + def until_len(self): + if len(self.consumed_documents) >= 1: + return False + else: + return True + + +def test_accumulate(temporary_topics, kafka_bootstrap_servers, broker_authorization_config): + # Smoke test for the kafka comms and acumulation with `continue_polling` function + with temporary_topics(topics=["test.adjudicator"]) as (topic,): + publisher = Publisher( + topic=topic, + bootstrap_servers=kafka_bootstrap_servers, + 
+        producer_config=broker_authorization_config,
+        key=f"{topic}.key",
+    )
+    re_manager = REManagerAPI(http_server_uri=None)
+    re_manager.set_authorization_key(api_key="SECRET")
+    adjudicator = AccumulateAdjudicator(
+        topics=[topic],
+        bootstrap_servers=kafka_bootstrap_servers,
+        group_id="test.communication.group",
+        qservers={"tst": re_manager},
+        consumer_config={"auto.offset.reset": "earliest"},
+    )
+    adjudicator.start(continue_polling=adjudicator.until_len)
+    publisher("name", {"dfi": "Dfs"})
+    publisher("name", {"dfi": "Dfs"})
+    start_time = ttime.monotonic()
+    while adjudicator.until_len():
+        ttime.sleep(0.5)
+        if ttime.monotonic() - start_time > KAFKA_TIMEOUT:
+            break
+    assert len(adjudicator.consumed_documents) == 1
+
+
+def test_send_to_adjudicator(temporary_topics, kafka_bootstrap_servers, broker_authorization_config):
+    def consume_until_len(kafka_topic, length):
+        consumed_documents = []
+        start_time = ttime.monotonic()
+
+        def process_document(consumer, topic, name, document):
+            consumed_documents.append((name, document))
+
+        consumer = BlueskyConsumer(
+            topics=[kafka_topic],
+            bootstrap_servers=kafka_bootstrap_servers,
+            group_id=f"{kafka_topic}.consumer.group",
+            consumer_config={"auto.offset.reset": "earliest"},
+            process_document=process_document,
+        )
+
+        def until_len():
+            if len(consumed_documents) >= length:
+                return False
+            elif ttime.monotonic() - start_time > KAFKA_TIMEOUT:
+                raise TimeoutError("Kafka Timeout in test environment")
+            else:
+                return True
+
+        consumer.start(continue_polling=until_len)
+        return consumed_documents
+
+    # Test the internal publisher
+    with temporary_topics(topics=["test.adjudicator", "test.data"]) as (adj_topic, bs_topic):
+        agent = TestAgent(adj_topic, bs_topic, kafka_bootstrap_servers, broker_authorization_config, None)
+        agent.kafka_producer("test", {"some": "dict"})
+        cache = consume_until_len(kafka_topic=adj_topic, length=1)
+        assert len(cache) == 1
+
+    # Test agent sending to adjudicator
+    with temporary_topics(topics=["test.adjudicator", "test.data"]) as (adj_topic, bs_topic):
+        agent = TestAgent(adj_topic, bs_topic, kafka_bootstrap_servers, broker_authorization_config, None)
+        agent.start()
+        agent.generate_suggestions_for_adjudicator(1)
+        cache = consume_until_len(kafka_topic=adj_topic, length=1)
+        assert len(cache) == 1
+
+
+def test_adjudicator_receipt(temporary_topics, kafka_bootstrap_servers, broker_authorization_config):
+    # Test agent sending to adjudicator
+    with temporary_topics(topics=["test.adjudicator", "test.data"]) as (adj_topic, bs_topic):
+        agent = TestAgent(adj_topic, bs_topic, kafka_bootstrap_servers, broker_authorization_config, None)
+        agent.start()
+        adjudicator = AccumulateAdjudicator(
+            topics=[adj_topic],
+            bootstrap_servers=kafka_bootstrap_servers,
+            group_id="test.communication.group",
+            qservers={"tst": None},
+            consumer_config={"auto.offset.reset": "earliest"},
+        )
+        adjudicator.start(continue_polling=adjudicator.until_len)
+        agent.generate_suggestions_for_adjudicator(1)
+        start_time = ttime.monotonic()
+        while adjudicator.until_len():
+            ttime.sleep(0.5)
+            if ttime.monotonic() - start_time > KAFKA_TIMEOUT:
+                raise TimeoutError("Adjudicator did not accumulate suggestions")
+        assert len(adjudicator.consumed_documents) == 1
+
+
+def test_adjudicator_by_name(temporary_topics, kafka_bootstrap_servers, broker_authorization_config):
+    with temporary_topics(topics=["test.adjudicator", "test.data"]) as (adj_topic, bs_topic):
+        re_manager = REManagerAPI(http_server_uri=None)
+        re_manager.set_authorization_key(api_key="SECRET")
+        adjudicator = AccumulateAdjudicator(
+            topics=[adj_topic],
+            bootstrap_servers=kafka_bootstrap_servers,
+            group_id="test.communication.group",
+            qservers={"tst": re_manager},
+            consumer_config={"auto.offset.reset": "earliest"},
+        )
+        adjudicator.primary_agent = "good"
+        adjudicator.prompt_judgment = False
+        adjudicator.start()
+
+        good_agent = TestAgent(
+            adj_topic,
+            bs_topic,
+            kafka_bootstrap_servers,
+            broker_authorization_config,
+            re_manager,
+            endstation_key="tst",
+        )
+        good_agent.instance_name = "good"
+        good_agent.start()
+        evil_agent = TestAgent(
+            adj_topic,
+            bs_topic,
+            kafka_bootstrap_servers,
+            broker_authorization_config,
+            re_manager,
+            endstation_key="tst",
+        )
+        evil_agent.instance_name = "evil"
+        evil_agent.start()
+
+        re_manager = good_agent.re_manager
+        status = re_manager.status()
+        if not status["worker_environment_exists"]:
+            re_manager.environment_open()
+        re_manager.queue_clear()
+
+        # Make sure we can put something on the queue from the adjudicator
+        adjudicator._add_suggestion_to_queue(
+            re_manager,
+            "good",
+            Suggestion(ask_uid="test", plan_name="agent_driven_nap", plan_args=[1.5], plan_kwargs={}),
+        )
+        assert re_manager.status()["items_in_queue"] == 1
+
+        # Make sure suggestions are making it to adjudicator
+        good_agent.generate_suggestions_for_adjudicator(1)
+        start_time = ttime.monotonic()
+        while not adjudicator.consumed_documents:
+            ttime.sleep(0.1)
+            if ttime.monotonic() - start_time > KAFKA_TIMEOUT:
+                raise TimeoutError("Adjudicator did not accumulate suggestions")
+        assert adjudicator.current_suggestions
+        assert "good" in adjudicator.agent_names
+
+        # Make sure adjudicator can throw the right suggestions onto the queue
+        good_agent.generate_suggestions_for_adjudicator(1)
+        evil_agent.generate_suggestions_for_adjudicator(1)
+        start_time = ttime.monotonic()
+        while len(adjudicator.current_suggestions) < 2:
+            ttime.sleep(0.1)
+            if ttime.monotonic() - start_time > KAFKA_TIMEOUT:
+                raise TimeoutError("Adjudicator did not accumulate suggestions")
+        judgments = adjudicator.make_judgments()
+        assert adjudicator.primary_agent in adjudicator.current_suggestions.keys()
+        assert len(judgments) == 1
+        assert judgments[0].agent_name == "good"
+        assert judgments[0].re_manager == re_manager
+
+
+def test_nonredundant_adjudicator(temporary_topics, kafka_bootstrap_servers, broker_authorization_config):
+    def _hash_suggestion(tla, suggestion: Suggestion):
+        return f"{tla} {suggestion.plan_name} {str(suggestion.plan_args)}"
+
+    with temporary_topics(topics=["test.adjudicator", "test.data"]) as (adj_topic, bs_topic):
+        re_manager = REManagerAPI(http_server_uri=None)
+        re_manager.set_authorization_key(api_key="SECRET")
+        adjudicator = NonredundantAdjudicator(
+            topics=[adj_topic],
+            bootstrap_servers=kafka_bootstrap_servers,
+            group_id="test.communication.group",
+            qservers={"tst": re_manager},
+            consumer_config={"auto.offset.reset": "earliest"},
+            hash_suggestion=_hash_suggestion,
+        )
+        adjudicator.prompt_judgment = False
+        adjudicator.start()
+        agent = TestAgent(
+            adj_topic,
+            bs_topic,
+            kafka_bootstrap_servers,
+            broker_authorization_config,
+            re_manager,
+            endstation_key="tst",
+        )
+        agent.start()
+        # Assure 5 suggestions that are the same only land as 1 judgement
+        agent.generate_suggestions_for_adjudicator(5)
+        start_time = ttime.monotonic()
+        while not adjudicator.current_suggestions:
+            ttime.sleep(0.1)
+            if ttime.monotonic() - start_time > KAFKA_TIMEOUT:
+                raise TimeoutError("Adjudicator did not accumulate suggestions")
+        judgments = adjudicator.make_judgments()
+        assert len(judgments) == 1
+
+        # Assure that additional suggestions don't pass judgement
+        agent.generate_suggestions_for_adjudicator(1)
+        start_time = ttime.monotonic()
+        while not adjudicator.current_suggestions:
+            ttime.sleep(0.1)
+            if ttime.monotonic() - start_time > KAFKA_TIMEOUT:
+                raise TimeoutError("Adjudicator did not accumulate suggestions")
+        judgments = adjudicator.make_judgments()
+        assert len(judgments) == 0
diff --git a/bluesky_adaptive/tests/test_agents.py b/bluesky_adaptive/tests/test_agents.py
index e3f8220..b57d543 100644
--- a/bluesky_adaptive/tests/test_agents.py
+++ b/bluesky_adaptive/tests/test_agents.py
@@ -59,7 +59,7 @@ def report(self, report_number: int = 0) -> dict:
         return dict(agent_name=self.instance_name, report=f"report_{report_number}")
 
     def ask(self, batch_size: int = 1) -> Tuple[dict, Sequence]:
-        return (dict(agent_name=self.instance_name, report=f"ask_{batch_size}"), [0 for _ in range(batch_size)])
+        return ([dict(agent_name=self.instance_name, report=f"ask_{batch_size}")], [0 for _ in range(batch_size)])
 
     def tell(self, x, y) -> dict:
         self.count += 1
@@ -114,8 +114,8 @@ def test_agent_doc_stream(temporary_topics, kafka_bootstrap_servers, broker_auth
         pub, sub, kafka_bootstrap_servers, broker_authorization_config, tiled_profile
     )
     agent.start()
-    doc, _ = agent.ask(1)
-    ask_uid = agent._write_event("ask", doc)
+    docs, _ = agent.ask(1)
+    ask_uid = agent._write_event("ask", docs[0])
     doc = agent.tell(0, 0)
     _ = agent._write_event("tell", doc)
     doc = agent.report()
@@ -128,7 +128,7 @@ def test_agent_doc_stream(temporary_topics, kafka_bootstrap_servers, broker_auth
     assert "report" in cat[-1].metadata["summary"]["stream_names"]
     assert "ask" in cat[-1].metadata["summary"]["stream_names"]
     assert "tell" in cat[-1].metadata["summary"]["stream_names"]
-    assert isinstance(ask_uid, list)
+    assert isinstance(ask_uid, str)
 
 
 def test_feedback_to_queue(
@@ -353,7 +353,7 @@ def subject_measurement_plan(self, point: ArrayLike):
         return "agent_driven_nap", [0.7], dict()
 
     def subject_ask(self, batch_size: int):
-        return dict(), [0.0 for _ in range(batch_size)]
+        return [dict()], [0.0 for _ in range(batch_size)]
 
     def unpack_run(self, run: BlueskyRun):
         return 0, 0
@@ -376,8 +376,8 @@ def test_monarch_subject(temporary_topics, kafka_bootstrap_servers, broker_autho
     agent.start()
     while True:
         # Awaiting the agent build before artificial ask
-        if agent.builder is not None:
-            if agent.builder._cache.start_doc["uid"] in agent.agent_catalog:
+        if agent._compose_run_bundle is not None:
+            if agent._compose_run_bundle.start_doc["uid"] in agent.agent_catalog:
                 break
         else:
             continue
diff --git a/requirements.txt b/requirements.txt
index 0467977..9befd33 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,3 +8,4 @@ bluesky-queueserver-api
 xkcdpass
 tiled
 numpy
+pydantic