Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions src/driftpy/accounts/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,12 +229,17 @@ def decode_swb_price_info(data: bytes):


def decode_prelaunch_price_info(data: bytes):
prelaunch_oracle = DRIFT_CODER.accounts.decode(data)
decoded_account = DRIFT_CODER.accounts.decode(data)

if not hasattr(decoded_account, "amm_last_update_slot"):
raise ValueError(
"Decoded account does not have amm_last_update_slot attribute, not a PrelaunchOracle"
)

return OraclePriceData(
price=prelaunch_oracle.price,
slot=prelaunch_oracle.amm_last_update_slot,
confidence=prelaunch_oracle.confidence,
price=decoded_account.price,
slot=decoded_account.amm_last_update_slot,
confidence=decoded_account.confidence,
has_sufficient_number_of_data_points=True,
twap=None,
twap_confidence=None,
Expand Down
44 changes: 21 additions & 23 deletions src/driftpy/accounts/ws/drift_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
FullOracleWrapper,
)
from driftpy.accounts.ws.account_subscriber import WebsocketAccountSubscriber
from driftpy.accounts.ws.multi_account_subscriber import WebsocketMultiAccountSubscriber
from driftpy.addresses import (
Pubkey,
get_perp_market_public_key,
Expand Down Expand Up @@ -53,7 +54,8 @@ def __init__(
self.state_subscriber = None
self.spot_market_subscribers = {}
self.perp_market_subscribers = {}
self.oracle_subscribers = {}
self.oracle_subscriber = WebsocketMultiAccountSubscriber(program, commitment)
self.oracle_id_to_pubkey: dict[str, Pubkey] = {}
self.spot_market_map = None
self.perp_market_map = None
self.spot_market_oracle_map: dict[int, Pubkey] = {}
Expand Down Expand Up @@ -112,6 +114,7 @@ async def subscribe(self):
for full_oracle_wrapper in self.full_oracle_wrappers:
await self.subscribe_to_oracle(full_oracle_wrapper)

await self.oracle_subscriber.subscribe()
await spot_market_map.subscribe()
await perp_market_map.subscribe()

Expand All @@ -125,6 +128,8 @@ async def subscribe(self):
for full_oracle_wrapper in self.full_oracle_wrappers:
await self.subscribe_to_oracle_info(full_oracle_wrapper)

await self.oracle_subscriber.subscribe()

await self._set_perp_oracle_map()
await self._set_spot_oracle_map()

Expand Down Expand Up @@ -176,18 +181,15 @@ async def subscribe_to_oracle(self, full_oracle_wrapper: FullOracleWrapper):
full_oracle_wrapper.pubkey,
full_oracle_wrapper.oracle_source,
)
if oracle_id in self.oracle_subscribers:
if full_oracle_wrapper.pubkey in self.oracle_subscriber.data_map:
return

oracle_subscriber = WebsocketAccountSubscriber[OraclePriceData](
await self.oracle_subscriber.add_account(
full_oracle_wrapper.pubkey,
self.program,
self.commitment,
get_oracle_decode_fn(full_oracle_wrapper.oracle_source),
initial_data=full_oracle_wrapper.oracle_price_data_and_slot,
)
await oracle_subscriber.subscribe()
self.oracle_subscribers[oracle_id] = oracle_subscriber
self.oracle_id_to_pubkey[oracle_id] = full_oracle_wrapper.pubkey

async def subscribe_to_oracle_info(
self, oracle_info: OracleInfo | FullOracleWrapper
Expand All @@ -202,18 +204,14 @@ async def subscribe_to_oracle_info(
if oracle_info.pubkey == Pubkey.default():
return

if oracle_id in self.oracle_subscribers:
if oracle_info.pubkey in self.oracle_subscriber.data_map:
return

oracle_subscriber = WebsocketAccountSubscriber[OraclePriceData](
await self.oracle_subscriber.add_account(
oracle_info.pubkey,
self.program,
self.commitment,
get_oracle_decode_fn(source),
)

await oracle_subscriber.subscribe()
self.oracle_subscribers[oracle_id] = oracle_subscriber
self.oracle_id_to_pubkey[oracle_id] = oracle_info.pubkey

def is_subscribed(self):
return (
Expand All @@ -232,14 +230,12 @@ async def fetch(self):
for spot_market_subscriber in self.spot_market_subscribers.values():
tasks.append(spot_market_subscriber.fetch())

for oracle_subscriber in self.oracle_subscribers.values():
tasks.append(oracle_subscriber.fetch())
tasks.append(self.oracle_subscriber.fetch())

await asyncio.gather(*tasks)

async def add_oracle(self, oracle_info: OracleInfo):
oracle_id = get_oracle_id(oracle_info.pubkey, oracle_info.source)
if oracle_id in self.oracle_subscribers:
if oracle_info.pubkey in self.oracle_subscriber.data_map:
return True

if oracle_info.pubkey == Pubkey.default():
Expand Down Expand Up @@ -271,7 +267,10 @@ def get_spot_market_and_slot(
def get_oracle_price_data_and_slot(
self, oracle_id: str
) -> Optional[DataAndSlot[OraclePriceData]]:
return self.oracle_subscribers[oracle_id].data_and_slot
pubkey = self.oracle_id_to_pubkey.get(oracle_id)
if pubkey is None:
return None
return self.oracle_subscriber.get_data(pubkey)

async def unsubscribe(self):
if self.is_subscribed():
Expand All @@ -284,8 +283,7 @@ async def unsubscribe(self):
await spot_market_subscriber.unsubscribe()
for perp_market_subscriber in self.perp_market_subscribers.values():
await perp_market_subscriber.unsubscribe()
for oracle_subscriber in self.oracle_subscribers.values():
await oracle_subscriber.unsubscribe()
await self.oracle_subscriber.unsubscribe()

def get_market_accounts_and_slots(self) -> list[DataAndSlot[PerpMarketAccount]]:
if self.perp_market_map:
Expand Down Expand Up @@ -317,7 +315,7 @@ async def _set_perp_oracle_map(self):
market_index = perp_market_account.market_index
oracle = perp_market_account.amm.oracle
oracle_id = get_oracle_id(oracle, perp_market_account.amm.oracle_source)
if oracle_id not in self.oracle_subscribers:
if oracle not in self.oracle_subscriber.data_map:
await self.add_oracle(
OracleInfo(oracle, perp_market_account.amm.oracle_source)
)
Expand All @@ -334,7 +332,7 @@ async def _set_spot_oracle_map(self):
market_index = spot_market_account.market_index
oracle = spot_market_account.oracle
oracle_id = get_oracle_id(oracle, spot_market_account.oracle_source)
if oracle_id not in self.oracle_subscribers:
if oracle not in self.oracle_subscriber.data_map:
await self.add_oracle(
OracleInfo(oracle, spot_market_account.oracle_source)
)
Expand Down
Loading
Loading