1111 FullOracleWrapper ,
1212)
1313from driftpy .accounts .ws .account_subscriber import WebsocketAccountSubscriber
14+ from driftpy .accounts .ws .multi_account_subscriber import WebsocketMultiAccountSubscriber
1415from driftpy .addresses import (
1516 Pubkey ,
1617 get_perp_market_public_key ,
@@ -53,7 +54,8 @@ def __init__(
5354 self .state_subscriber = None
5455 self .spot_market_subscribers = {}
5556 self .perp_market_subscribers = {}
56- self .oracle_subscribers = {}
57+ self .oracle_subscriber = WebsocketMultiAccountSubscriber (program , commitment )
58+ self .oracle_id_to_pubkey : dict [str , Pubkey ] = {}
5759 self .spot_market_map = None
5860 self .perp_market_map = None
5961 self .spot_market_oracle_map : dict [int , Pubkey ] = {}
@@ -112,6 +114,7 @@ async def subscribe(self):
112114 for full_oracle_wrapper in self .full_oracle_wrappers :
113115 await self .subscribe_to_oracle (full_oracle_wrapper )
114116
117+ await self .oracle_subscriber .subscribe ()
115118 await spot_market_map .subscribe ()
116119 await perp_market_map .subscribe ()
117120
@@ -125,6 +128,8 @@ async def subscribe(self):
125128 for full_oracle_wrapper in self .full_oracle_wrappers :
126129 await self .subscribe_to_oracle_info (full_oracle_wrapper )
127130
131+ await self .oracle_subscriber .subscribe ()
132+
128133 await self ._set_perp_oracle_map ()
129134 await self ._set_spot_oracle_map ()
130135
@@ -176,18 +181,15 @@ async def subscribe_to_oracle(self, full_oracle_wrapper: FullOracleWrapper):
176181 full_oracle_wrapper .pubkey ,
177182 full_oracle_wrapper .oracle_source ,
178183 )
179- if oracle_id in self .oracle_subscribers :
184+ if full_oracle_wrapper . pubkey in self .oracle_subscriber . data_map :
180185 return
181186
182- oracle_subscriber = WebsocketAccountSubscriber [ OraclePriceData ] (
187+ await self . oracle_subscriber . add_account (
183188 full_oracle_wrapper .pubkey ,
184- self .program ,
185- self .commitment ,
186189 get_oracle_decode_fn (full_oracle_wrapper .oracle_source ),
187190 initial_data = full_oracle_wrapper .oracle_price_data_and_slot ,
188191 )
189- await oracle_subscriber .subscribe ()
190- self .oracle_subscribers [oracle_id ] = oracle_subscriber
192+ self .oracle_id_to_pubkey [oracle_id ] = full_oracle_wrapper .pubkey
191193
192194 async def subscribe_to_oracle_info (
193195 self , oracle_info : OracleInfo | FullOracleWrapper
@@ -202,18 +204,14 @@ async def subscribe_to_oracle_info(
202204 if oracle_info .pubkey == Pubkey .default ():
203205 return
204206
205- if oracle_id in self .oracle_subscribers :
207+ if oracle_info . pubkey in self .oracle_subscriber . data_map :
206208 return
207209
208- oracle_subscriber = WebsocketAccountSubscriber [ OraclePriceData ] (
210+ await self . oracle_subscriber . add_account (
209211 oracle_info .pubkey ,
210- self .program ,
211- self .commitment ,
212212 get_oracle_decode_fn (source ),
213213 )
214-
215- await oracle_subscriber .subscribe ()
216- self .oracle_subscribers [oracle_id ] = oracle_subscriber
214+ self .oracle_id_to_pubkey [oracle_id ] = oracle_info .pubkey
217215
218216 def is_subscribed (self ):
219217 return (
@@ -232,14 +230,12 @@ async def fetch(self):
232230 for spot_market_subscriber in self .spot_market_subscribers .values ():
233231 tasks .append (spot_market_subscriber .fetch ())
234232
235- for oracle_subscriber in self .oracle_subscribers .values ():
236- tasks .append (oracle_subscriber .fetch ())
233+ tasks .append (self .oracle_subscriber .fetch ())
237234
238235 await asyncio .gather (* tasks )
239236
240237 async def add_oracle (self , oracle_info : OracleInfo ):
241- oracle_id = get_oracle_id (oracle_info .pubkey , oracle_info .source )
242- if oracle_id in self .oracle_subscribers :
238+ if oracle_info .pubkey in self .oracle_subscriber .data_map :
243239 return True
244240
245241 if oracle_info .pubkey == Pubkey .default ():
@@ -271,7 +267,10 @@ def get_spot_market_and_slot(
271267 def get_oracle_price_data_and_slot (
272268 self , oracle_id : str
273269 ) -> Optional [DataAndSlot [OraclePriceData ]]:
274- return self .oracle_subscribers [oracle_id ].data_and_slot
270+ pubkey = self .oracle_id_to_pubkey .get (oracle_id )
271+ if pubkey is None :
272+ return None
273+ return self .oracle_subscriber .get_data (pubkey )
275274
276275 async def unsubscribe (self ):
277276 if self .is_subscribed ():
@@ -284,8 +283,7 @@ async def unsubscribe(self):
284283 await spot_market_subscriber .unsubscribe ()
285284 for perp_market_subscriber in self .perp_market_subscribers .values ():
286285 await perp_market_subscriber .unsubscribe ()
287- for oracle_subscriber in self .oracle_subscribers .values ():
288- await oracle_subscriber .unsubscribe ()
286+ await self .oracle_subscriber .unsubscribe ()
289287
290288 def get_market_accounts_and_slots (self ) -> list [DataAndSlot [PerpMarketAccount ]]:
291289 if self .perp_market_map :
@@ -317,7 +315,7 @@ async def _set_perp_oracle_map(self):
317315 market_index = perp_market_account .market_index
318316 oracle = perp_market_account .amm .oracle
319317 oracle_id = get_oracle_id (oracle , perp_market_account .amm .oracle_source )
320- if oracle_id not in self .oracle_subscribers :
318+ if oracle not in self .oracle_subscriber . data_map :
321319 await self .add_oracle (
322320 OracleInfo (oracle , perp_market_account .amm .oracle_source )
323321 )
@@ -334,7 +332,7 @@ async def _set_spot_oracle_map(self):
334332 market_index = spot_market_account .market_index
335333 oracle = spot_market_account .oracle
336334 oracle_id = get_oracle_id (oracle , spot_market_account .oracle_source )
337- if oracle_id not in self .oracle_subscribers :
335+ if oracle not in self .oracle_subscriber . data_map :
338336 await self .add_oracle (
339337 OracleInfo (oracle , spot_market_account .oracle_source )
340338 )
0 commit comments