diff --git a/examples/minimal_maker.py b/examples/minimal_maker.py new file mode 100644 index 00000000..06e7774a --- /dev/null +++ b/examples/minimal_maker.py @@ -0,0 +1,159 @@ +import asyncio +import logging +import os + +import dotenv +from anchorpy.provider import Wallet +from solana.rpc.async_api import AsyncClient +from solana.rpc.core import RPCException + +from driftpy.accounts.get_accounts import get_protected_maker_mode_stats +from driftpy.constants.numeric_constants import BASE_PRECISION, PRICE_PRECISION +from driftpy.constants.perp_markets import mainnet_perp_market_configs +from driftpy.drift_client import DriftClient +from driftpy.keypair import load_keypair +from driftpy.math.user_status import is_user_protected_maker +from driftpy.types import ( + MarketType, + OrderParams, + OrderType, + PositionDirection, + TxParams, +) + +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(message)s") +logger = logging.getLogger(__name__) + + +async def get_drift_client() -> DriftClient: + dotenv.load_dotenv() + rpc_url = os.getenv("RPC_TRITON") + private_key = os.getenv("PRIVATE_KEY") + if not rpc_url or not private_key: + raise Exception("Missing env vars") + kp = load_keypair(private_key) + drift_client = DriftClient( + connection=AsyncClient(rpc_url), + wallet=Wallet(kp), + env="mainnet", + tx_params=TxParams(700_000, 10_000), + ) + await drift_client.subscribe() + logger.info("Drift client subscribed") + + user = drift_client.get_user() + is_protected_maker = is_user_protected_maker(user.get_user_account()) + if not is_protected_maker: + logger.warning("User is not a protected maker") + logger.warning("Attempting to make protected maker...") + stats = await get_protected_maker_mode_stats(drift_client.program) + logger.info(f"Protected maker stats: {stats}") + if stats["current_users"] >= stats["max_users"]: + logger.error("No room for a new protected maker") + print("---\nYour orders will not be protected. Continue anyway? (Y/n)") + if input().lower().startswith("n"): + exit(1) + return drift_client + + try: + result = await drift_client.update_user_protected_maker_orders(0, True) + logger.info(result) + except RPCException as e: + logger.error(f"Failed to make protected maker: {e}") + print("---\nYour orders will not be protected. Continue anyway? (Y/n)") + if input().lower().startswith("n"): + exit(1) + + logger.info("Drift client is ready.") + + return drift_client + + +class OracleMaker: + def __init__(self, drift_client: DriftClient, targets: dict[str, dict[str, float]]): + self.client = drift_client + self.last_positions: dict[int, float] = {} + self.target_positions = { + symbol_to_market_index(k): v for k, v in targets.items() + } + logger.info(f"OracleMaker initialized with targets: {self.target_positions}") + + def get_orders_for_market(self, market_index: int, spread: float): + pos = self.client.get_perp_position(market_index) + current_pos = pos.base_asset_amount / BASE_PRECISION if pos else 0.0 + target_position = self.target_positions[market_index]["target"] or 0.0 + pos_diff = abs(current_pos - target_position) + t = min(1.0, pos_diff / 2.0) + spread *= 1.0 + t # range: [spread..2*spread] + + base_size = int(self.target_positions[market_index]["size"] * BASE_PRECISION) + offset_iu = int(spread * PRICE_PRECISION) + logger.info(f"Market={market_index}, current_pos={current_pos:.4f}") + logger.info(f"Spread: {spread:.5f} base_size: {base_size} offset: {offset_iu}") + + bid = OrderParams( + order_type=OrderType.Oracle(), # type: ignore + market_type=MarketType.Perp(), # type: ignore + direction=PositionDirection.Long(), # type: ignore + base_asset_amount=base_size, + market_index=market_index, + oracle_price_offset=-offset_iu, + ) + ask = OrderParams( + order_type=OrderType.Oracle(), # type: ignore + market_type=MarketType.Perp(), # type: ignore + direction=PositionDirection.Short(), # type: ignore + base_asset_amount=base_size, + market_index=market_index, + oracle_price_offset=offset_iu, + ) + + if current_pos > target_position: + logger.info("Skipping bid order - position above target") + bid = None + if current_pos < -target_position: + logger.info("Skipping ask order - position below -target") + ask = None + + orders = [o for o in [bid, ask] if o] + name = mainnet_perp_market_configs[market_index].symbol + logger.info(f"Market {name}: Will place {len(orders)} orders") + return orders + + async def place_orders_for_all_markets(self, spread: float): + all_orders = [] + for m_idx in self.target_positions.keys(): + all_orders.extend(self.get_orders_for_market(m_idx, spread)) + + await self.client.cancel_and_place_orders( + cancel_params=(None, None, None), place_order_params=all_orders + ) + + +def symbol_to_market_index(symbol): + return next( + m.market_index for m in mainnet_perp_market_configs if m.symbol == symbol + ) + + +async def main(): + logger.info("Starting OracleMaker") + drift_client = await get_drift_client() + target_sizes_map = { + "SOL-PERP": {"target": 2.0, "size": 2.0}, + "DRIFT-PERP": {"target": 20.0, "size": 20.0}, + } # add as many market indexes as you want to make for + maker = OracleMaker(drift_client, targets=target_sizes_map) + try: + while True: + await maker.place_orders_for_all_markets(spread=0.008) + await asyncio.sleep(10) + except KeyboardInterrupt: + logger.info("Interrupted by user. Exiting loop...") + finally: + await maker.client.unsubscribe() + logger.info("Unsubscribed from Drift client.") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/driftpy/accounts/get_accounts.py b/src/driftpy/accounts/get_accounts.py index 35a1e28f..bc330ed7 100644 --- a/src/driftpy/accounts/get_accounts.py +++ b/src/driftpy/accounts/get_accounts.py @@ -1,17 +1,33 @@ -from typing import cast, Optional, Callable -from solders.pubkey import Pubkey -from anchorpy import Program, ProgramAccount +from typing import Callable, Optional, cast + +from anchorpy.program.core import Program +from anchorpy.program.namespace.account import ProgramAccount from solana.rpc.commitment import Commitment +from solders.pubkey import Pubkey -from driftpy.types import * -from driftpy.addresses import * -from .types import DataAndSlot, T +from driftpy.accounts.types import DataAndSlot, T +from driftpy.addresses import ( + get_insurance_fund_stake_public_key, + get_perp_market_public_key, + get_protected_maker_mode_config_public_key, + get_spot_market_public_key, + get_state_public_key, + get_user_stats_account_public_key, +) +from driftpy.types import ( + InsuranceFundStakeAccount, + PerpMarketAccount, + SpotMarketAccount, + StateAccount, + UserAccount, + UserStatsAccount, +) async def get_account_data_and_slot( address: Pubkey, program: Program, - commitment: Commitment = "processed", + commitment: Commitment = Commitment("processed"), decode: Optional[Callable[[bytes], T]] = None, ) -> Optional[DataAndSlot[T]]: account_info = await program.provider.connection.get_account_info( @@ -67,14 +83,14 @@ async def get_user_stats_account( async def get_user_account_and_slot( program: Program, user_public_key: Pubkey, -) -> DataAndSlot[UserAccount]: +) -> Optional[DataAndSlot[UserAccount]]: return await get_account_data_and_slot(user_public_key, program) async def get_user_account( program: Program, user_public_key: Pubkey, -) -> UserAccount: +) -> Optional[UserAccount]: return (await get_user_account_and_slot(program, user_public_key)).data @@ -89,7 +105,7 @@ async def get_perp_market_account_and_slot( async def get_perp_market_account( program: Program, market_index: int -) -> PerpMarketAccount: +) -> Optional[PerpMarketAccount]: return (await get_perp_market_account_and_slot(program, market_index)).data @@ -99,7 +115,7 @@ async def get_all_perp_market_accounts(program: Program) -> list[ProgramAccount] async def get_spot_market_account_and_slot( program: Program, spot_market_index: int -) -> DataAndSlot[SpotMarketAccount]: +) -> Optional[DataAndSlot[SpotMarketAccount]]: spot_market_public_key = get_spot_market_public_key( program.program_id, spot_market_index ) @@ -108,9 +124,19 @@ async def get_spot_market_account_and_slot( async def get_spot_market_account( program: Program, spot_market_index: int -) -> SpotMarketAccount: +) -> Optional[SpotMarketAccount]: return (await get_spot_market_account_and_slot(program, spot_market_index)).data async def get_all_spot_market_accounts(program: Program) -> list[ProgramAccount]: return await program.account["SpotMarket"].all() + + +async def get_protected_maker_mode_stats(program: Program) -> dict[str, int | bool]: + config_pubkey = get_protected_maker_mode_config_public_key(program.program_id) + config = await program.account["ProtectedMakerModeConfig"].fetch(config_pubkey) + return { + "max_users": config.max_users, + "current_users": config.current_users, + "is_reduce_only": config.reduce_only > 0, + } diff --git a/src/driftpy/constants/perp_markets.py b/src/driftpy/constants/perp_markets.py index 0be6d2eb..9a02f5b6 100644 --- a/src/driftpy/constants/perp_markets.py +++ b/src/driftpy/constants/perp_markets.py @@ -663,4 +663,11 @@ class PerpMarketConfig: oracle=Pubkey.from_string("7vGHChuBJyFMYBqMLXRzBmRxWdSuwEmg8RvRm3RWQsxi"), oracle_source=OracleSource.PythPull(), # type: ignore ), + PerpMarketConfig( + symbol="AI16Z-PERP", + base_asset_symbol="AI16Z", + market_index=63, + oracle=Pubkey.from_string("3gdGkrmBdYR7B1MRRdRVysqhZCvYvLGHonr9b7o9WVki"), + oracle_source=OracleSource.PythPull(), # type: ignore + ), ] diff --git a/src/driftpy/constants/spot_markets.py b/src/driftpy/constants/spot_markets.py index e18640b8..3e0d36bb 100644 --- a/src/driftpy/constants/spot_markets.py +++ b/src/driftpy/constants/spot_markets.py @@ -321,4 +321,11 @@ class SpotMarketConfig: oracle_source=OracleSource.PythStableCoinPull(), # type: ignore mint=Pubkey.from_string("EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v"), ), + SpotMarketConfig( + symbol="AI16Z", + market_index=35, + oracle=Pubkey.from_string("3gdGkrmBdYR7B1MRRdRVysqhZCvYvLGHonr9b7o9WVki"), + oracle_source=OracleSource.PythPull(), # type: ignore + mint=Pubkey.from_string("HeLp6NuQkmYB4pYWo2zYs22mESHXPQYzXbB8n4V98jwC"), + ), ] diff --git a/src/driftpy/drift_client.py b/src/driftpy/drift_client.py index 5402d9de..ae5f0112 100644 --- a/src/driftpy/drift_client.py +++ b/src/driftpy/drift_client.py @@ -39,12 +39,10 @@ from driftpy.account_subscription_config import AccountSubscriptionConfig from driftpy.accounts import ( DataAndSlot, - OracleInfo, OraclePriceData, PerpMarketAccount, SpotMarketAccount, StateAccount, - TxParams, UserAccount, ) from driftpy.accounts.cache.drift_client import CachedDriftClientAccountSubscriber @@ -87,6 +85,7 @@ MakerInfo, MarketType, ModifyOrderParams, + OracleInfo, Order, OrderParams, OrderType, @@ -98,6 +97,7 @@ SerumV3FulfillmentConfigAccount, SpotPosition, SwapReduceOnly, + TxParams, is_variant, ) diff --git a/src/driftpy/drift_user.py b/src/driftpy/drift_user.py index ac6b91ee..a0483845 100644 --- a/src/driftpy/drift_user.py +++ b/src/driftpy/drift_user.py @@ -21,7 +21,13 @@ is_spot_position_available, ) from driftpy.oracles.strict_oracle_price import StrictOraclePrice -from driftpy.types import OraclePriceData +from driftpy.types import ( + OraclePriceData, + Order, + PerpPosition, + SpotPosition, + UserAccount, +) class DriftUser: diff --git a/src/driftpy/math/margin.py b/src/driftpy/math/margin.py index e65947cf..084835e9 100644 --- a/src/driftpy/math/margin.py +++ b/src/driftpy/math/margin.py @@ -1,11 +1,21 @@ import math - -from driftpy.math.spot_market import * - from enum import Enum -from driftpy.types import OraclePriceData -from driftpy.constants.numeric_constants import * +from driftpy.constants.numeric_constants import ( + AMM_RESERVE_PRECISION, + BASE_PRECISION, + MARGIN_PRECISION, + PRICE_TO_QUOTE_PRECISION_RATIO, + SPOT_IMF_PRECISION, + SPOT_WEIGHT_PRECISION, +) +from driftpy.math.spot_market import get_token_amount, get_token_value +from driftpy.types import ( + OraclePriceData, + PerpMarketAccount, + SpotBalanceType, + SpotMarketAccount, +) def calculate_size_discount_asset_weight( @@ -77,7 +87,9 @@ def calculate_scaled_initial_asset_weight( return spot_market.initial_asset_weight deposits = get_token_amount( - spot_market.deposit_balance, spot_market, SpotBalanceType.Deposit() + spot_market.deposit_balance, + spot_market, + SpotBalanceType.Deposit(), # type: ignore ) deposits_value = get_token_value(deposits, spot_market.decimals, oracle_price) @@ -139,7 +151,9 @@ def calculate_net_user_pnl_imbalance( user_pnl = calculate_net_user_pnl(perp_market, oracle_data) pnl_pool = get_token_amount( - perp_market.pnl_pool.scaled_balance, spot_market, SpotBalanceType.Deposit() + perp_market.pnl_pool.scaled_balance, + spot_market, + SpotBalanceType.Deposit(), # type: ignore ) imbalance = user_pnl - pnl_pool diff --git a/src/driftpy/math/perp_position.py b/src/driftpy/math/perp_position.py index 3cfcd8c8..57d49b63 100644 --- a/src/driftpy/math/perp_position.py +++ b/src/driftpy/math/perp_position.py @@ -1,7 +1,21 @@ -from driftpy.math.spot_market import * -from driftpy.types import OraclePriceData, is_variant -from driftpy.constants.numeric_constants import * +from driftpy.constants.numeric_constants import ( + AMM_RESERVE_PRECISION, + AMM_TIMES_PEG_TO_QUOTE_PRECISION_RATIO, + AMM_TO_QUOTE_PRECISION_RATIO, + BASE_PRECISION, + FUNDING_RATE_BUFFER, + MAX_PREDICTION_PRICE, + PRICE_PRECISION, +) from driftpy.math.amm import calculate_amm_reserves_after_swap, get_swap_direction +from driftpy.types import ( + AssetType, + OraclePriceData, + PerpMarketAccount, + PerpPosition, + PositionDirection, + is_variant, +) def calculate_base_asset_value_with_oracle( @@ -132,25 +146,28 @@ def is_available(position: PerpPosition): def calculate_base_asset_value( market: PerpMarketAccount, user_position: PerpPosition -) -> int: +) -> float: if user_position.base_asset_amount == 0: return 0 direction_to_close = ( - PositionDirection.Short() + PositionDirection.Short() # type: ignore if user_position.base_asset_amount > 0 - else PositionDirection.Long() + else PositionDirection.Long() # type: ignore ) new_quote_asset_reserve, _ = calculate_amm_reserves_after_swap( market.amm, - AssetType.BASE(), + AssetType.BASE(), # type: ignore abs(user_position.base_asset_amount), - get_swap_direction(AssetType.BASE(), direction_to_close), + get_swap_direction( + AssetType.BASE(), # type: ignore + direction_to_close, + ), ) result = None - if direction_to_close == PositionDirection.Short(): + if direction_to_close == PositionDirection.Short(): # type: ignore result = ( (market.amm.quote_asset_reserve - new_quote_asset_reserve) * market.amm.peg_multiplier diff --git a/src/driftpy/math/spot_market.py b/src/driftpy/math/spot_market.py index 3419e3cc..6b159b8d 100644 --- a/src/driftpy/math/spot_market.py +++ b/src/driftpy/math/spot_market.py @@ -1,8 +1,12 @@ from typing import Union -from driftpy.accounts import * from driftpy.math.utils import div_ceil -from driftpy.types import OraclePriceData +from driftpy.types import ( + OraclePriceData, + SpotBalanceType, + SpotMarketAccount, + is_variant, +) def get_signed_token_amount(amount, balance_type):