diff --git a/src/driftpy/accounts/oracle.py b/src/driftpy/accounts/oracle.py index abea895e..903fcb0a 100644 --- a/src/driftpy/accounts/oracle.py +++ b/src/driftpy/accounts/oracle.py @@ -229,12 +229,17 @@ def decode_swb_price_info(data: bytes): def decode_prelaunch_price_info(data: bytes): - prelaunch_oracle = DRIFT_CODER.accounts.decode(data) + decoded_account = DRIFT_CODER.accounts.decode(data) + + if not hasattr(decoded_account, "amm_last_update_slot"): + raise ValueError( + "Decoded account does not have amm_last_update_slot attribute, not a PrelaunchOracle" + ) return OraclePriceData( - price=prelaunch_oracle.price, - slot=prelaunch_oracle.amm_last_update_slot, - confidence=prelaunch_oracle.confidence, + price=decoded_account.price, + slot=decoded_account.amm_last_update_slot, + confidence=decoded_account.confidence, has_sufficient_number_of_data_points=True, twap=None, twap_confidence=None, diff --git a/src/driftpy/accounts/ws/drift_client.py b/src/driftpy/accounts/ws/drift_client.py index 91647a8e..5bcfd8d2 100644 --- a/src/driftpy/accounts/ws/drift_client.py +++ b/src/driftpy/accounts/ws/drift_client.py @@ -11,6 +11,7 @@ FullOracleWrapper, ) from driftpy.accounts.ws.account_subscriber import WebsocketAccountSubscriber +from driftpy.accounts.ws.multi_account_subscriber import WebsocketMultiAccountSubscriber from driftpy.addresses import ( Pubkey, get_perp_market_public_key, @@ -53,7 +54,8 @@ def __init__( self.state_subscriber = None self.spot_market_subscribers = {} self.perp_market_subscribers = {} - self.oracle_subscribers = {} + self.oracle_subscriber = WebsocketMultiAccountSubscriber(program, commitment) + self.oracle_id_to_pubkey: dict[str, Pubkey] = {} self.spot_market_map = None self.perp_market_map = None self.spot_market_oracle_map: dict[int, Pubkey] = {} @@ -112,6 +114,7 @@ async def subscribe(self): for full_oracle_wrapper in self.full_oracle_wrappers: await self.subscribe_to_oracle(full_oracle_wrapper) + await self.oracle_subscriber.subscribe() await spot_market_map.subscribe() await perp_market_map.subscribe() @@ -125,6 +128,8 @@ async def subscribe(self): for full_oracle_wrapper in self.full_oracle_wrappers: await self.subscribe_to_oracle_info(full_oracle_wrapper) + await self.oracle_subscriber.subscribe() + await self._set_perp_oracle_map() await self._set_spot_oracle_map() @@ -176,18 +181,15 @@ async def subscribe_to_oracle(self, full_oracle_wrapper: FullOracleWrapper): full_oracle_wrapper.pubkey, full_oracle_wrapper.oracle_source, ) - if oracle_id in self.oracle_subscribers: + if full_oracle_wrapper.pubkey in self.oracle_subscriber.data_map: return - oracle_subscriber = WebsocketAccountSubscriber[OraclePriceData]( + await self.oracle_subscriber.add_account( full_oracle_wrapper.pubkey, - self.program, - self.commitment, get_oracle_decode_fn(full_oracle_wrapper.oracle_source), initial_data=full_oracle_wrapper.oracle_price_data_and_slot, ) - await oracle_subscriber.subscribe() - self.oracle_subscribers[oracle_id] = oracle_subscriber + self.oracle_id_to_pubkey[oracle_id] = full_oracle_wrapper.pubkey async def subscribe_to_oracle_info( self, oracle_info: OracleInfo | FullOracleWrapper @@ -202,18 +204,14 @@ async def subscribe_to_oracle_info( if oracle_info.pubkey == Pubkey.default(): return - if oracle_id in self.oracle_subscribers: + if oracle_info.pubkey in self.oracle_subscriber.data_map: return - oracle_subscriber = WebsocketAccountSubscriber[OraclePriceData]( + await self.oracle_subscriber.add_account( oracle_info.pubkey, - self.program, - self.commitment, get_oracle_decode_fn(source), ) - - await oracle_subscriber.subscribe() - self.oracle_subscribers[oracle_id] = oracle_subscriber + self.oracle_id_to_pubkey[oracle_id] = oracle_info.pubkey def is_subscribed(self): return ( @@ -232,14 +230,12 @@ async def fetch(self): for spot_market_subscriber in self.spot_market_subscribers.values(): tasks.append(spot_market_subscriber.fetch()) - for oracle_subscriber in self.oracle_subscribers.values(): - tasks.append(oracle_subscriber.fetch()) + tasks.append(self.oracle_subscriber.fetch()) await asyncio.gather(*tasks) async def add_oracle(self, oracle_info: OracleInfo): - oracle_id = get_oracle_id(oracle_info.pubkey, oracle_info.source) - if oracle_id in self.oracle_subscribers: + if oracle_info.pubkey in self.oracle_subscriber.data_map: return True if oracle_info.pubkey == Pubkey.default(): @@ -271,7 +267,10 @@ def get_spot_market_and_slot( def get_oracle_price_data_and_slot( self, oracle_id: str ) -> Optional[DataAndSlot[OraclePriceData]]: - return self.oracle_subscribers[oracle_id].data_and_slot + pubkey = self.oracle_id_to_pubkey.get(oracle_id) + if pubkey is None: + return None + return self.oracle_subscriber.get_data(pubkey) async def unsubscribe(self): if self.is_subscribed(): @@ -284,8 +283,7 @@ async def unsubscribe(self): await spot_market_subscriber.unsubscribe() for perp_market_subscriber in self.perp_market_subscribers.values(): await perp_market_subscriber.unsubscribe() - for oracle_subscriber in self.oracle_subscribers.values(): - await oracle_subscriber.unsubscribe() + await self.oracle_subscriber.unsubscribe() def get_market_accounts_and_slots(self) -> list[DataAndSlot[PerpMarketAccount]]: if self.perp_market_map: @@ -317,7 +315,7 @@ async def _set_perp_oracle_map(self): market_index = perp_market_account.market_index oracle = perp_market_account.amm.oracle oracle_id = get_oracle_id(oracle, perp_market_account.amm.oracle_source) - if oracle_id not in self.oracle_subscribers: + if oracle not in self.oracle_subscriber.data_map: await self.add_oracle( OracleInfo(oracle, perp_market_account.amm.oracle_source) ) @@ -334,7 +332,7 @@ async def _set_spot_oracle_map(self): market_index = spot_market_account.market_index oracle = spot_market_account.oracle oracle_id = get_oracle_id(oracle, spot_market_account.oracle_source) - if oracle_id not in self.oracle_subscribers: + if oracle not in self.oracle_subscriber.data_map: await self.add_oracle( OracleInfo(oracle, spot_market_account.oracle_source) ) diff --git a/src/driftpy/accounts/ws/multi_account_subscriber.py b/src/driftpy/accounts/ws/multi_account_subscriber.py new file mode 100644 index 00000000..5d9873ea --- /dev/null +++ b/src/driftpy/accounts/ws/multi_account_subscriber.py @@ -0,0 +1,278 @@ +import asyncio +from typing import Any, Callable, Dict, Optional, cast + +import websockets +import websockets.exceptions # force eager imports +from anchorpy.program.core import Program +from solana.rpc.commitment import Commitment +from solana.rpc.websocket_api import SolanaWsClientProtocol, connect +from solders.pubkey import Pubkey + +from driftpy.accounts import DataAndSlot, get_account_data_and_slot +from driftpy.types import get_ws_url + + +class WebsocketMultiAccountSubscriber: + def __init__( + self, + program: Program, + commitment: Commitment = Commitment("confirmed"), + ): + self.program = program + self.commitment = commitment + self.ws: Optional[SolanaWsClientProtocol] = None + self.task: Optional[asyncio.Task] = None + + self.subscription_map: Dict[int, Pubkey] = {} + self.pubkey_to_subscription: Dict[Pubkey, int] = {} + self.decode_map: Dict[Pubkey, Callable[[bytes], Any]] = {} + self.data_map: Dict[Pubkey, Optional[DataAndSlot]] = {} + self.initial_data_map: Dict[Pubkey, Optional[DataAndSlot]] = {} + self.pending_subscriptions: list[Pubkey] = [] + + self._lock = asyncio.Lock() + + async def add_account( + self, + pubkey: Pubkey, + decode: Optional[Callable[[bytes], Any]] = None, + initial_data: Optional[DataAndSlot] = None, + ): + decode_fn = decode if decode is not None else self.program.coder.accounts.decode + + async with self._lock: + if pubkey in self.pubkey_to_subscription: + return + if pubkey in self.data_map and initial_data is None: + initial_data = self.data_map[pubkey] + + if initial_data is None: + try: + initial_data = await get_account_data_and_slot( + pubkey, self.program, self.commitment, decode_fn + ) + except Exception as e: + print(f"Error fetching initial data for {pubkey}: {e}") + return + + async with self._lock: + if pubkey in self.pubkey_to_subscription: + return + + self.decode_map[pubkey] = decode_fn + self.initial_data_map[pubkey] = initial_data + self.data_map[pubkey] = initial_data + + if self.ws is not None: + try: + # Enqueue before sending to maintain order + async with self._lock: + self.pending_subscriptions.append(pubkey) + + await self.ws.account_subscribe( + pubkey, + commitment=self.commitment, + encoding="base64", + ) + except Exception as e: + print(f"Error subscribing to account {pubkey}: {e}") + async with self._lock: + if ( + self.pending_subscriptions + and self.pending_subscriptions[-1] == pubkey + ): + self.pending_subscriptions.pop() + elif pubkey in self.pending_subscriptions: + self.pending_subscriptions.remove(pubkey) + + async def remove_account(self, pubkey: Pubkey): + async with self._lock: + if pubkey not in self.pubkey_to_subscription: + return + + subscription_id = self.pubkey_to_subscription[pubkey] + + if self.ws is not None: + try: + await self.ws.account_unsubscribe(subscription_id) + except Exception: + pass + + del self.subscription_map[subscription_id] + del self.pubkey_to_subscription[pubkey] + del self.decode_map[pubkey] + del self.data_map[pubkey] + if pubkey in self.initial_data_map: + del self.initial_data_map[pubkey] + + async def subscribe(self): + if self.task is not None: + return + + self.task = asyncio.create_task(self._subscribe_ws()) + + async def _subscribe_ws(self): + endpoint = self.program.provider.connection._provider.endpoint_uri + ws_endpoint = get_ws_url(endpoint) + + async for ws in connect(ws_endpoint): + try: + self.ws = cast(SolanaWsClientProtocol, ws) + + async with self._lock: + initial_accounts = [] + for pubkey in list(self.data_map.keys()): + if pubkey not in self.pubkey_to_subscription: + initial_accounts.append(pubkey) + + self.pending_subscriptions.extend(initial_accounts) + + for pubkey in initial_accounts: + try: + await ws.account_subscribe( + pubkey, + commitment=self.commitment, + encoding="base64", + ) + except Exception as e: + print(f"Error subscribing to account {pubkey}: {e}") + async with self._lock: + if pubkey in self.pending_subscriptions: + self.pending_subscriptions.remove(pubkey) + + async for msg in ws: + try: + if len(msg) == 0: + print("No message received") + continue + + result = msg[0].result + + if isinstance(result, int): + async with self._lock: + if self.pending_subscriptions: + pubkey = self.pending_subscriptions.pop(0) + subscription_id = result + self.subscription_map[subscription_id] = pubkey + self.pubkey_to_subscription[pubkey] = ( + subscription_id + ) + else: + print( + "No pending subscriptions but got a confirmation. " + "This implies a race condition or mismatch." + ) + continue + + if hasattr(result, "value") and result.value is not None: + subscription_id = None + if hasattr(msg[0], "subscription"): + subscription_id = msg[0].subscription + + if ( + subscription_id is None + or subscription_id not in self.subscription_map + ): + print( + f"Subscription ID {subscription_id} not found in subscription map" + ) + continue + + pubkey = self.subscription_map[subscription_id] + decode_fn = self.decode_map.get(pubkey) + + if decode_fn is None: + print(f"No decode function found for pubkey {pubkey}") + continue + + try: + slot = int(result.context.slot) + account_bytes = cast(bytes, result.value.data) + decoded_data = decode_fn(account_bytes) + new_data = DataAndSlot(slot, decoded_data) + self._update_data(pubkey, new_data) + except Exception: + # this is RPC noise? + continue + except Exception as e: + print(f"Error processing websocket message: {e}") + continue + + except websockets.exceptions.ConnectionClosed: + self.ws = None + async with self._lock: + self.subscription_map.clear() + self.pubkey_to_subscription.clear() + continue + except Exception as e: + print(f"Error in websocket connection: {e}") + self.ws = None + async with self._lock: + self.subscription_map.clear() + self.pubkey_to_subscription.clear() + await asyncio.sleep(1) + continue + + def _update_data(self, pubkey: Pubkey, new_data: Optional[DataAndSlot]): + if new_data is None: + return + + current_data = self.data_map.get(pubkey) + if current_data is None or new_data.slot >= current_data.slot: + self.data_map[pubkey] = new_data + + def get_data(self, pubkey: Pubkey) -> Optional[DataAndSlot]: + return self.data_map.get(pubkey) + + async def fetch(self, pubkey: Optional[Pubkey] = None): + if pubkey is not None: + decode_fn = self.decode_map.get(pubkey) + if decode_fn is None: + return + new_data = await get_account_data_and_slot( + pubkey, self.program, self.commitment, decode_fn + ) + self._update_data(pubkey, new_data) + else: + tasks = [] + for pubkey, decode_fn in self.decode_map.items(): + tasks.append( + get_account_data_and_slot( + pubkey, self.program, self.commitment, decode_fn + ) + ) + results = await asyncio.gather(*tasks, return_exceptions=True) + for pubkey, result in zip(self.decode_map.keys(), results): + if isinstance(result, Exception): + continue + self._update_data(pubkey, result) + + def is_subscribed(self): + return self.ws is not None and self.task is not None + + async def unsubscribe(self): + if self.task: + self.task.cancel() + try: + await self.task + except asyncio.CancelledError: + pass + self.task = None + + if self.ws: + async with self._lock: + for subscription_id in list(self.subscription_map.keys()): + try: + await self.ws.account_unsubscribe(subscription_id) + except Exception: + pass + await self.ws.close() + self.ws = None + + async with self._lock: + self.subscription_map.clear() + self.pubkey_to_subscription.clear() + self.decode_map.clear() + self.data_map.clear() + self.initial_data_map.clear() + self.pending_subscriptions.clear() diff --git a/src/driftpy/constants/spot_markets.py b/src/driftpy/constants/spot_markets.py index 84b71657..64da3957 100644 --- a/src/driftpy/constants/spot_markets.py +++ b/src/driftpy/constants/spot_markets.py @@ -19,7 +19,6 @@ class SpotMarketConfig: WRAPPED_SOL_MINT = Pubkey.from_string("So11111111111111111111111111111111111111112") - devnet_spot_market_configs: list[SpotMarketConfig] = [ SpotMarketConfig( symbol="USDC", @@ -115,8 +114,8 @@ class SpotMarketConfig: SpotMarketConfig( symbol="mSOL", market_index=2, - oracle=Pubkey.from_string("FAq7hqjn7FWGXKDwJHzsXGgBcydGTcK4kziJpAGWXjDb"), - oracle_source=OracleSource.PythPull(), # type: ignore + oracle=Pubkey.from_string("FY2JMi1vYz1uayVT2GJ96ysZgpagjhdPRG2upNPtSZsC"), + oracle_source=OracleSource.PythLazer(), # type: ignore mint=Pubkey.from_string("mSoLzYCxHdYgdzU16g5QSh3i5K3z3KZK7ytfqcJm7So"), decimals=9, ), @@ -147,8 +146,8 @@ class SpotMarketConfig: SpotMarketConfig( symbol="jitoSOL", market_index=6, - oracle=Pubkey.from_string("9QE1P5EfzthYDgoQ9oPeTByCEKaRJeZbVVqKJfgU9iau"), - oracle_source=OracleSource.PythPull(), # type: ignore + oracle=Pubkey.from_string("2cHCtAkMnttMh3bNKSCgSKSP5D4yN3p8bfnMdS3VZsDf"), + oracle_source=OracleSource.PythLazer(), # type: ignore mint=Pubkey.from_string("J1toso1uCk3RLmjorhTtrVwY9HJ7X8V9yYac6Y7kGCPn"), decimals=9, ), @@ -291,16 +290,16 @@ class SpotMarketConfig: SpotMarketConfig( symbol="sUSDe", market_index=24, - oracle=Pubkey.from_string("BRuNuzLAPHHGSSVAJPKMcmJMdgDfrekvnSxkxPDGdeqp"), - oracle_source=OracleSource.PythPull(), # type: ignore + oracle=Pubkey.from_string("CX7JCXtUTiC43ZA4uzoH7iQBD15jtVwdBNCnjKHt1BrQ"), + oracle_source=OracleSource.PythLazer(), # type: ignore mint=Pubkey.from_string("Eh6XEPhSwoLv5wFApukmnaVSHQ6sAnoD9BmgmwQoN2sN"), decimals=9, ), SpotMarketConfig( symbol="BNSOL", market_index=25, - oracle=Pubkey.from_string("8DmXTfhhtb9kTcpTVfb6Ygx8WhZ8wexGqcpxfn23zooe"), - oracle_source=OracleSource.PythPull(), # type: ignore + oracle=Pubkey.from_string("2LxMbHBHsw74aE3XgfthmUNkdDfUGcSEy3G3D3t642fd"), + oracle_source=OracleSource.PythLazer(), # type: ignore mint=Pubkey.from_string("BNso1VUJnh4zcfpZa6986Ea66P6TCp59hvtNJ8b1X85"), decimals=9, ), @@ -323,8 +322,8 @@ class SpotMarketConfig: SpotMarketConfig( symbol="USDS", market_index=28, - oracle=Pubkey.from_string("7pT9mxKXyvfaZKeKy1oe2oV2K1RFtF7tPEJHUY3h2vVV"), - oracle_source=OracleSource.PythStableCoinPull(), # type: ignore + oracle=Pubkey.from_string("5Km85n3s9Zs5wEoXYWuHbpoDzst4EBkS5f1XuQJGG1DL"), + oracle_source=OracleSource.PythLazerStableCoin(), # type: ignore mint=Pubkey.from_string("USDSwr9ApdHk5bvJKMjzff41FfuX8bSxdKcR81vTwcA"), decimals=6, ), @@ -379,8 +378,8 @@ class SpotMarketConfig: SpotMarketConfig( symbol="AI16Z", market_index=35, - oracle=Pubkey.from_string("3BGheQVvYtBNpBKSUXSTjpyKQc3dh8iiwT91Aiq7KYCU"), - oracle_source=OracleSource.PythLazer(), # type: ignore + oracle=Pubkey.from_string("BHqLyA9ov1VPNzt8eb5bt75X2Vk1EVKw1d9Qa78Gk5tR"), + oracle_source=OracleSource.SwitchboardOnDemand(), # type: ignore mint=Pubkey.from_string("HeLp6NuQkmYB4pYWo2zYs22mESHXPQYzXbB8n4V98jwC"), decimals=9, ), @@ -403,8 +402,8 @@ class SpotMarketConfig: SpotMarketConfig( symbol="AUSD", market_index=38, - oracle=Pubkey.from_string("8FZhpiM8n3mpgvENWLcEvHsKB1bBhYBAyL4Ypr4gptLZ"), - oracle_source=OracleSource.PythStableCoinPull(), # type: ignore + oracle=Pubkey.from_string("9JYpqJfLXgrW8Wqzfd93GvJF73m2jJFjNqpQv3wQtehZ"), + oracle_source=OracleSource.PythLazerStableCoin(), # type: ignore mint=Pubkey.from_string("AUSD1jCcCyPLybk1YnvPWsHQSrZ46dxwoMniN4N2UEB9"), decimals=6, ), @@ -459,8 +458,8 @@ class SpotMarketConfig: SpotMarketConfig( symbol="zBTC", market_index=45, - oracle=Pubkey.from_string("CN9QvvbGQzMnN8vJaSek2so4vFnTqgJDFrdJB8Y4tQfB"), - oracle_source=OracleSource.PythPull(), # type: ignore + oracle=Pubkey.from_string("3xcpvBUVV8ALVV4Wod733Vyic3fe8iJAeXDpRdk19Z3p"), + oracle_source=OracleSource.PythLazer(), # type: ignore mint=Pubkey.from_string("zBTCug3er3tLyffELcvDNrKkCymbPWysGcWihESYfLg"), decimals=8, ), @@ -584,4 +583,12 @@ class SpotMarketConfig: mint=Pubkey.from_string("METvsvVRapdj9cFLzq4Tr43xK4tAjQfwX76z3n6mWQL"), decimals=6, ), + SpotMarketConfig( + symbol="CASH", + market_index=61, + oracle=Pubkey.from_string("AK6coxSjfAnuDT4ZUSP3UpeQe2G1tKcALnsdd835eg7T"), + oracle_source=OracleSource.PythLazerStableCoin(), # type: ignore + mint=Pubkey.from_string("CASHx9KJUStyftLFWGvEVf59SGeG9sh5FfcnZMVPCASH"), + decimals=6, + ), ]