Skip to content

Commit de27e00

Browse files
authored
Merge pull request #257 from drift-labs/sina/less-subscribers
Less websockets
2 parents 75ab18f + 172bd02 commit de27e00

File tree

4 files changed

+332
-44
lines changed

4 files changed

+332
-44
lines changed

src/driftpy/accounts/oracle.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -229,12 +229,17 @@ def decode_swb_price_info(data: bytes):
229229

230230

231231
def decode_prelaunch_price_info(data: bytes):
232-
prelaunch_oracle = DRIFT_CODER.accounts.decode(data)
232+
decoded_account = DRIFT_CODER.accounts.decode(data)
233+
234+
if not hasattr(decoded_account, "amm_last_update_slot"):
235+
raise ValueError(
236+
"Decoded account does not have amm_last_update_slot attribute, not a PrelaunchOracle"
237+
)
233238

234239
return OraclePriceData(
235-
price=prelaunch_oracle.price,
236-
slot=prelaunch_oracle.amm_last_update_slot,
237-
confidence=prelaunch_oracle.confidence,
240+
price=decoded_account.price,
241+
slot=decoded_account.amm_last_update_slot,
242+
confidence=decoded_account.confidence,
238243
has_sufficient_number_of_data_points=True,
239244
twap=None,
240245
twap_confidence=None,

src/driftpy/accounts/ws/drift_client.py

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
FullOracleWrapper,
1212
)
1313
from driftpy.accounts.ws.account_subscriber import WebsocketAccountSubscriber
14+
from driftpy.accounts.ws.multi_account_subscriber import WebsocketMultiAccountSubscriber
1415
from driftpy.addresses import (
1516
Pubkey,
1617
get_perp_market_public_key,
@@ -53,7 +54,8 @@ def __init__(
5354
self.state_subscriber = None
5455
self.spot_market_subscribers = {}
5556
self.perp_market_subscribers = {}
56-
self.oracle_subscribers = {}
57+
self.oracle_subscriber = WebsocketMultiAccountSubscriber(program, commitment)
58+
self.oracle_id_to_pubkey: dict[str, Pubkey] = {}
5759
self.spot_market_map = None
5860
self.perp_market_map = None
5961
self.spot_market_oracle_map: dict[int, Pubkey] = {}
@@ -112,6 +114,7 @@ async def subscribe(self):
112114
for full_oracle_wrapper in self.full_oracle_wrappers:
113115
await self.subscribe_to_oracle(full_oracle_wrapper)
114116

117+
await self.oracle_subscriber.subscribe()
115118
await spot_market_map.subscribe()
116119
await perp_market_map.subscribe()
117120

@@ -125,6 +128,8 @@ async def subscribe(self):
125128
for full_oracle_wrapper in self.full_oracle_wrappers:
126129
await self.subscribe_to_oracle_info(full_oracle_wrapper)
127130

131+
await self.oracle_subscriber.subscribe()
132+
128133
await self._set_perp_oracle_map()
129134
await self._set_spot_oracle_map()
130135

@@ -176,18 +181,15 @@ async def subscribe_to_oracle(self, full_oracle_wrapper: FullOracleWrapper):
176181
full_oracle_wrapper.pubkey,
177182
full_oracle_wrapper.oracle_source,
178183
)
179-
if oracle_id in self.oracle_subscribers:
184+
if full_oracle_wrapper.pubkey in self.oracle_subscriber.data_map:
180185
return
181186

182-
oracle_subscriber = WebsocketAccountSubscriber[OraclePriceData](
187+
await self.oracle_subscriber.add_account(
183188
full_oracle_wrapper.pubkey,
184-
self.program,
185-
self.commitment,
186189
get_oracle_decode_fn(full_oracle_wrapper.oracle_source),
187190
initial_data=full_oracle_wrapper.oracle_price_data_and_slot,
188191
)
189-
await oracle_subscriber.subscribe()
190-
self.oracle_subscribers[oracle_id] = oracle_subscriber
192+
self.oracle_id_to_pubkey[oracle_id] = full_oracle_wrapper.pubkey
191193

192194
async def subscribe_to_oracle_info(
193195
self, oracle_info: OracleInfo | FullOracleWrapper
@@ -202,18 +204,14 @@ async def subscribe_to_oracle_info(
202204
if oracle_info.pubkey == Pubkey.default():
203205
return
204206

205-
if oracle_id in self.oracle_subscribers:
207+
if oracle_info.pubkey in self.oracle_subscriber.data_map:
206208
return
207209

208-
oracle_subscriber = WebsocketAccountSubscriber[OraclePriceData](
210+
await self.oracle_subscriber.add_account(
209211
oracle_info.pubkey,
210-
self.program,
211-
self.commitment,
212212
get_oracle_decode_fn(source),
213213
)
214-
215-
await oracle_subscriber.subscribe()
216-
self.oracle_subscribers[oracle_id] = oracle_subscriber
214+
self.oracle_id_to_pubkey[oracle_id] = oracle_info.pubkey
217215

218216
def is_subscribed(self):
219217
return (
@@ -232,14 +230,12 @@ async def fetch(self):
232230
for spot_market_subscriber in self.spot_market_subscribers.values():
233231
tasks.append(spot_market_subscriber.fetch())
234232

235-
for oracle_subscriber in self.oracle_subscribers.values():
236-
tasks.append(oracle_subscriber.fetch())
233+
tasks.append(self.oracle_subscriber.fetch())
237234

238235
await asyncio.gather(*tasks)
239236

240237
async def add_oracle(self, oracle_info: OracleInfo):
241-
oracle_id = get_oracle_id(oracle_info.pubkey, oracle_info.source)
242-
if oracle_id in self.oracle_subscribers:
238+
if oracle_info.pubkey in self.oracle_subscriber.data_map:
243239
return True
244240

245241
if oracle_info.pubkey == Pubkey.default():
@@ -271,7 +267,10 @@ def get_spot_market_and_slot(
271267
def get_oracle_price_data_and_slot(
272268
self, oracle_id: str
273269
) -> Optional[DataAndSlot[OraclePriceData]]:
274-
return self.oracle_subscribers[oracle_id].data_and_slot
270+
pubkey = self.oracle_id_to_pubkey.get(oracle_id)
271+
if pubkey is None:
272+
return None
273+
return self.oracle_subscriber.get_data(pubkey)
275274

276275
async def unsubscribe(self):
277276
if self.is_subscribed():
@@ -284,8 +283,7 @@ async def unsubscribe(self):
284283
await spot_market_subscriber.unsubscribe()
285284
for perp_market_subscriber in self.perp_market_subscribers.values():
286285
await perp_market_subscriber.unsubscribe()
287-
for oracle_subscriber in self.oracle_subscribers.values():
288-
await oracle_subscriber.unsubscribe()
286+
await self.oracle_subscriber.unsubscribe()
289287

290288
def get_market_accounts_and_slots(self) -> list[DataAndSlot[PerpMarketAccount]]:
291289
if self.perp_market_map:
@@ -317,7 +315,7 @@ async def _set_perp_oracle_map(self):
317315
market_index = perp_market_account.market_index
318316
oracle = perp_market_account.amm.oracle
319317
oracle_id = get_oracle_id(oracle, perp_market_account.amm.oracle_source)
320-
if oracle_id not in self.oracle_subscribers:
318+
if oracle not in self.oracle_subscriber.data_map:
321319
await self.add_oracle(
322320
OracleInfo(oracle, perp_market_account.amm.oracle_source)
323321
)
@@ -334,7 +332,7 @@ async def _set_spot_oracle_map(self):
334332
market_index = spot_market_account.market_index
335333
oracle = spot_market_account.oracle
336334
oracle_id = get_oracle_id(oracle, spot_market_account.oracle_source)
337-
if oracle_id not in self.oracle_subscribers:
335+
if oracle not in self.oracle_subscriber.data_map:
338336
await self.add_oracle(
339337
OracleInfo(oracle, spot_market_account.oracle_source)
340338
)

0 commit comments

Comments
 (0)