Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(core): add batch almanac api and contract registrations for Bureau #551

Merged
merged 18 commits into from
Nov 8, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
feat: add ledger batch registrations
jrriehl committed Oct 14, 2024
commit fd41b955031d48163759ace5782ed96645e50286
12 changes: 9 additions & 3 deletions python/examples/05-send-msg/main.py
Original file line number Diff line number Diff line change
@@ -5,8 +5,10 @@ class Message(Model):
text: str


alice = Agent(name="alice", seed="alice recovery phrase")
bob = Agent(name="bob", seed="bob recovery phrase")
alice = Agent(
name="alice", seed="alice recovery phrase", agentverse="http://localhost:8001"
)
bob = Agent(name="bob", seed="bob recovery phrase", agentverse="http://localhost:8001")


@alice.on_interval(period=2.0)
@@ -20,7 +22,11 @@ async def message_handler(ctx: Context, sender: str, msg: Message):
ctx.logger.info(f"Received message from {sender}: {msg.text}")


bureau = Bureau()
bureau = Bureau(
endpoint="http://localhost:8000/submit",
log_level="DEBUG",
agentverse="http://localhost:8001",
)
bureau.add(alice)
bureau.add(bob)

94 changes: 67 additions & 27 deletions python/src/uagents/network.py
Original file line number Diff line number Diff line change
@@ -2,7 +2,7 @@

import asyncio
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Tuple, Union

from cosmpy.aerial.client import (
DEFAULT_QUERY_INTERVAL_SECS,
@@ -199,68 +199,108 @@ def is_registered(self, address: str) -> bool:

return bool(response.get("record"))

def get_expiry(self, address: str) -> int:
def registration_needs_update(
self,
address: str,
endpoints: List[AgentEndpoint],
protocols: List[str],
min_seconds_left: int,
) -> bool:
"""
Check if an agent's registration needs to be updated.

Args:
address (str): The agent's address.
endpoints (List[AgentEndpoint]): The agent's endpoints.
protocols (List[str]): The agent's protocols.
min_time_left (int): The minimum time left before the agent's registration expires

Returns:
bool: True if the agent's registration needs to be updated or will expire sooner
than the specified minimum time, False otherwise.
"""
seconds_to_expiry, registered_endpoints, registered_protocols = (
self.query_agent_record(address)
)
return (
not self.is_registered(address)
or seconds_to_expiry < min_seconds_left
or endpoints != registered_endpoints
or protocols != registered_protocols
)

def query_agent_record(
self, address: str
) -> Tuple[int, List[AgentEndpoint], List[str]]:
"""
Get the expiry height of an agent's registration.
Get the records associated with an agent's registration.

Args:
address (str): The agent's address.

Returns:
int: The expiry height of the agent's registration.
Tuple[int, List[AgentEndpoint], List[str]]: The expiry height of the agent's
registration, the agent's endpoints, and the agent's protocols.
"""
query_msg = {"query_records": {"agent_address": address}}
response = self.query_contract(query_msg)

if not response.get("record"):
return []

if not response.get("record"):
contract_state = self.query_contract({"query_contract_state": {}})
expiry = contract_state.get("state", {}).get("expiry_height", 0)
return expiry * AVERAGE_BLOCK_INTERVAL

expiry = response["record"][0].get("expiry", 0)
height = response.get("height", 0)
expiry_block = response["record"][0].get("expiry", 0)
current_block = response.get("height", 0)

return (expiry - height) * AVERAGE_BLOCK_INTERVAL
seconds_to_expiry = (expiry_block - current_block) * AVERAGE_BLOCK_INTERVAL

def get_endpoints(self, address: str) -> List[AgentEndpoint]:
endpoints = []
for endpoint in response["record"][0]["record"]["service"]["endpoints"]:
endpoints.append(AgentEndpoint.model_validate(endpoint))

protocols = response["record"][0]["record"]["service"]["protocols"]

return seconds_to_expiry, endpoints, protocols

def get_expiry(self, address: str) -> int:
"""
Get the endpoints associated with an agent's registration.
Get the approximate seconds to expiry of an agent's registration.

Args:
address (str): The agent's address.

Returns:
List[AgentEndpoint]: The endpoints associated with the agent's registration.
int: The approximate seconds to expiry of the agent's registration.
"""
query_msg = {"query_records": {"agent_address": address}}
response = self.query_contract(query_msg)
return self.query_agent_record(address)[0]

if not response.get("record"):
return []
def get_endpoints(self, address: str) -> List[AgentEndpoint]:
"""
Get the endpoints associated with an agent's registration.

endpoints = []
for endpoint in response["record"][0]["record"]["service"]["endpoints"]:
endpoints.append(AgentEndpoint.model_validate(endpoint))
Args:
address (str): The agent's address.

return endpoints
Returns:
List[AgentEndpoint]: The agent's registered endpoints.
"""
return self.query_agent_record(address)[1]

def get_protocols(self, address: str):
def get_protocols(self, address: str) -> List[str]:
"""
Get the protocols associated with an agent's registration.

Args:
address (str): The agent's address.

Returns:
Any: The protocols associated with the agent's registration.
List[str]: The agent's registered protocols.
"""
query_msg = {"query_records": {"agent_address": address}}
response = self.query_contract(query_msg)

if not response.get("record"):
return None

return response["record"][0]["record"]["service"]["protocols"]
return self.query_agent_record(address)[2]

def get_registration_msg(
self,
61 changes: 60 additions & 1 deletion python/src/uagents/registration.py
Original file line number Diff line number Diff line change
@@ -3,7 +3,7 @@
import json
import logging
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional

import aiohttp
from cosmpy.aerial.client import LedgerClient
@@ -88,6 +88,13 @@ def _build_digest(self) -> bytes:
return sha256.digest()


class AgentLedgerRegistrationDetails(BaseModel):
agent_address: str
protocols: List[str]
endpoints: List[AgentEndpoint]
signer: Callable[[], str]


class AlmanacApiRegistrationPolicy(AgentRegistrationPolicy):
def __init__(
self,
@@ -294,6 +301,58 @@ def _sign_registration(self, agent_address: str) -> str:
)


class BatchLedgerRegistrationPolicy(BatchRegistrationPolicy):
def __init__(
self,
ledger: LedgerClient,
wallet: LocalWallet,
almanac_contract: AlmanacContract,
testnet: bool,
logger: Optional[logging.Logger] = None,
):
self._ledger = ledger
self._wallet = wallet
self._almanac_contract = almanac_contract
self._testnet = testnet
self._logger = logger or logging.getLogger(__name__)
self._agents: List[AgentLedgerRegistrationDetails] = []

def add_agent(self, agent: Any):
agent_details = AgentLedgerRegistrationDetails(
agent_address=agent.address,
protocols=list(agent.protocols.keys()),
endpoints=agent._endpoints,
sign_registration=agent._identity.sign_registration,
)
self._agents.append(agent_details)

def _get_balance(self) -> int:
return self._ledger.query_bank_balance(Address(self._wallet.address()))

async def register(self):
self._logger.info("Registering agents on Almanac contract...")
for agent in self._agents:
if self._almanac_contract.registration_needs_update(
agent.agent_address,
agent.protocols,
agent.endpoints,
REGISTRATION_UPDATE_INTERVAL_SECONDS,
):
signature = agent.sign_registration(agent.agent_address)
await self._almanac_contract.register(
self._ledger,
self._wallet,
agent.agent_address,
agent.protocols,
agent.endpoints,
signature,
)
else:
self._logger.info(
f"Agent {agent.agent_address} registration is up to date!"
)


class DefaultRegistrationPolicy(AgentRegistrationPolicy):
def __init__(
self,