Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 159 additions & 0 deletions examples/minimal_maker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
import asyncio
import logging
import os

import dotenv
from anchorpy.provider import Wallet
from solana.rpc.async_api import AsyncClient
from solana.rpc.core import RPCException

from driftpy.accounts.get_accounts import get_protected_maker_mode_stats
from driftpy.constants.numeric_constants import BASE_PRECISION, PRICE_PRECISION
from driftpy.constants.perp_markets import mainnet_perp_market_configs
from driftpy.drift_client import DriftClient
from driftpy.keypair import load_keypair
from driftpy.math.user_status import is_user_protected_maker
from driftpy.types import (
MarketType,
OrderParams,
OrderType,
PositionDirection,
TxParams,
)

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(message)s")
logger = logging.getLogger(__name__)


async def get_drift_client() -> DriftClient:
dotenv.load_dotenv()
rpc_url = os.getenv("RPC_TRITON")
private_key = os.getenv("PRIVATE_KEY")
if not rpc_url or not private_key:
raise Exception("Missing env vars")
kp = load_keypair(private_key)
drift_client = DriftClient(
connection=AsyncClient(rpc_url),
wallet=Wallet(kp),
env="mainnet",
tx_params=TxParams(700_000, 10_000),
)
await drift_client.subscribe()
logger.info("Drift client subscribed")

user = drift_client.get_user()
is_protected_maker = is_user_protected_maker(user.get_user_account())
if not is_protected_maker:
logger.warning("User is not a protected maker")
logger.warning("Attempting to make protected maker...")
stats = await get_protected_maker_mode_stats(drift_client.program)
logger.info(f"Protected maker stats: {stats}")
if stats["current_users"] >= stats["max_users"]:
logger.error("No room for a new protected maker")
print("---\nYour orders will not be protected. Continue anyway? (Y/n)")
if input().lower().startswith("n"):
exit(1)
return drift_client

try:
result = await drift_client.update_user_protected_maker_orders(0, True)
logger.info(result)
except RPCException as e:
logger.error(f"Failed to make protected maker: {e}")
print("---\nYour orders will not be protected. Continue anyway? (Y/n)")
if input().lower().startswith("n"):
exit(1)

logger.info("Drift client is ready.")

return drift_client


class OracleMaker:
def __init__(self, drift_client: DriftClient, targets: dict[str, dict[str, float]]):
self.client = drift_client
self.last_positions: dict[int, float] = {}
self.target_positions = {
symbol_to_market_index(k): v for k, v in targets.items()
}
logger.info(f"OracleMaker initialized with targets: {self.target_positions}")

def get_orders_for_market(self, market_index: int, spread: float):
pos = self.client.get_perp_position(market_index)
current_pos = pos.base_asset_amount / BASE_PRECISION if pos else 0.0
target_position = self.target_positions[market_index]["target"] or 0.0
pos_diff = abs(current_pos - target_position)
t = min(1.0, pos_diff / 2.0)
spread *= 1.0 + t # range: [spread..2*spread]

base_size = int(self.target_positions[market_index]["size"] * BASE_PRECISION)
offset_iu = int(spread * PRICE_PRECISION)
logger.info(f"Market={market_index}, current_pos={current_pos:.4f}")
logger.info(f"Spread: {spread:.5f} base_size: {base_size} offset: {offset_iu}")

bid = OrderParams(
order_type=OrderType.Oracle(), # type: ignore
market_type=MarketType.Perp(), # type: ignore
direction=PositionDirection.Long(), # type: ignore
base_asset_amount=base_size,
market_index=market_index,
oracle_price_offset=-offset_iu,
)
ask = OrderParams(
order_type=OrderType.Oracle(), # type: ignore
market_type=MarketType.Perp(), # type: ignore
direction=PositionDirection.Short(), # type: ignore
base_asset_amount=base_size,
market_index=market_index,
oracle_price_offset=offset_iu,
)

if current_pos > target_position:
logger.info("Skipping bid order - position above target")
bid = None
if current_pos < -target_position:
logger.info("Skipping ask order - position below -target")
ask = None

orders = [o for o in [bid, ask] if o]
name = mainnet_perp_market_configs[market_index].symbol
logger.info(f"Market {name}: Will place {len(orders)} orders")
return orders

async def place_orders_for_all_markets(self, spread: float):
all_orders = []
for m_idx in self.target_positions.keys():
all_orders.extend(self.get_orders_for_market(m_idx, spread))

await self.client.cancel_and_place_orders(
cancel_params=(None, None, None), place_order_params=all_orders
)


def symbol_to_market_index(symbol):
return next(
m.market_index for m in mainnet_perp_market_configs if m.symbol == symbol
)


async def main():
logger.info("Starting OracleMaker")
drift_client = await get_drift_client()
target_sizes_map = {
"SOL-PERP": {"target": 2.0, "size": 2.0},
"DRIFT-PERP": {"target": 20.0, "size": 20.0},
} # add as many market indexes as you want to make for
maker = OracleMaker(drift_client, targets=target_sizes_map)
try:
while True:
await maker.place_orders_for_all_markets(spread=0.008)
await asyncio.sleep(10)
except KeyboardInterrupt:
logger.info("Interrupted by user. Exiting loop...")
finally:
await maker.client.unsubscribe()
logger.info("Unsubscribed from Drift client.")


if __name__ == "__main__":
asyncio.run(main())
50 changes: 38 additions & 12 deletions src/driftpy/accounts/get_accounts.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,33 @@
from typing import cast, Optional, Callable
from solders.pubkey import Pubkey
from anchorpy import Program, ProgramAccount
from typing import Callable, Optional, cast

from anchorpy.program.core import Program
from anchorpy.program.namespace.account import ProgramAccount
from solana.rpc.commitment import Commitment
from solders.pubkey import Pubkey

from driftpy.types import *
from driftpy.addresses import *
from .types import DataAndSlot, T
from driftpy.accounts.types import DataAndSlot, T
from driftpy.addresses import (
get_insurance_fund_stake_public_key,
get_perp_market_public_key,
get_protected_maker_mode_config_public_key,
get_spot_market_public_key,
get_state_public_key,
get_user_stats_account_public_key,
)
from driftpy.types import (
InsuranceFundStakeAccount,
PerpMarketAccount,
SpotMarketAccount,
StateAccount,
UserAccount,
UserStatsAccount,
)


async def get_account_data_and_slot(
address: Pubkey,
program: Program,
commitment: Commitment = "processed",
commitment: Commitment = Commitment("processed"),
decode: Optional[Callable[[bytes], T]] = None,
) -> Optional[DataAndSlot[T]]:
account_info = await program.provider.connection.get_account_info(
Expand Down Expand Up @@ -67,14 +83,14 @@ async def get_user_stats_account(
async def get_user_account_and_slot(
program: Program,
user_public_key: Pubkey,
) -> DataAndSlot[UserAccount]:
) -> Optional[DataAndSlot[UserAccount]]:
return await get_account_data_and_slot(user_public_key, program)


async def get_user_account(
program: Program,
user_public_key: Pubkey,
) -> UserAccount:
) -> Optional[UserAccount]:
return (await get_user_account_and_slot(program, user_public_key)).data


Expand All @@ -89,7 +105,7 @@ async def get_perp_market_account_and_slot(

async def get_perp_market_account(
program: Program, market_index: int
) -> PerpMarketAccount:
) -> Optional[PerpMarketAccount]:
return (await get_perp_market_account_and_slot(program, market_index)).data


Expand All @@ -99,7 +115,7 @@ async def get_all_perp_market_accounts(program: Program) -> list[ProgramAccount]

async def get_spot_market_account_and_slot(
program: Program, spot_market_index: int
) -> DataAndSlot[SpotMarketAccount]:
) -> Optional[DataAndSlot[SpotMarketAccount]]:
spot_market_public_key = get_spot_market_public_key(
program.program_id, spot_market_index
)
Expand All @@ -108,9 +124,19 @@ async def get_spot_market_account_and_slot(

async def get_spot_market_account(
program: Program, spot_market_index: int
) -> SpotMarketAccount:
) -> Optional[SpotMarketAccount]:
return (await get_spot_market_account_and_slot(program, spot_market_index)).data


async def get_all_spot_market_accounts(program: Program) -> list[ProgramAccount]:
return await program.account["SpotMarket"].all()


async def get_protected_maker_mode_stats(program: Program) -> dict[str, int | bool]:
config_pubkey = get_protected_maker_mode_config_public_key(program.program_id)
config = await program.account["ProtectedMakerModeConfig"].fetch(config_pubkey)
return {
"max_users": config.max_users,
"current_users": config.current_users,
"is_reduce_only": config.reduce_only > 0,
}
7 changes: 7 additions & 0 deletions src/driftpy/constants/perp_markets.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,4 +663,11 @@ class PerpMarketConfig:
oracle=Pubkey.from_string("7vGHChuBJyFMYBqMLXRzBmRxWdSuwEmg8RvRm3RWQsxi"),
oracle_source=OracleSource.PythPull(), # type: ignore
),
PerpMarketConfig(
symbol="AI16Z-PERP",
base_asset_symbol="AI16Z",
market_index=63,
oracle=Pubkey.from_string("3gdGkrmBdYR7B1MRRdRVysqhZCvYvLGHonr9b7o9WVki"),
oracle_source=OracleSource.PythPull(), # type: ignore
),
]
7 changes: 7 additions & 0 deletions src/driftpy/constants/spot_markets.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,4 +321,11 @@ class SpotMarketConfig:
oracle_source=OracleSource.PythStableCoinPull(), # type: ignore
mint=Pubkey.from_string("EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v"),
),
SpotMarketConfig(
symbol="AI16Z",
market_index=35,
oracle=Pubkey.from_string("3gdGkrmBdYR7B1MRRdRVysqhZCvYvLGHonr9b7o9WVki"),
oracle_source=OracleSource.PythPull(), # type: ignore
mint=Pubkey.from_string("HeLp6NuQkmYB4pYWo2zYs22mESHXPQYzXbB8n4V98jwC"),
),
]
4 changes: 2 additions & 2 deletions src/driftpy/drift_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,10 @@
from driftpy.account_subscription_config import AccountSubscriptionConfig
from driftpy.accounts import (
DataAndSlot,
OracleInfo,
OraclePriceData,
PerpMarketAccount,
SpotMarketAccount,
StateAccount,
TxParams,
UserAccount,
)
from driftpy.accounts.cache.drift_client import CachedDriftClientAccountSubscriber
Expand Down Expand Up @@ -87,6 +85,7 @@
MakerInfo,
MarketType,
ModifyOrderParams,
OracleInfo,
Order,
OrderParams,
OrderType,
Expand All @@ -98,6 +97,7 @@
SerumV3FulfillmentConfigAccount,
SpotPosition,
SwapReduceOnly,
TxParams,
is_variant,
)

Expand Down
8 changes: 7 additions & 1 deletion src/driftpy/drift_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,13 @@
is_spot_position_available,
)
from driftpy.oracles.strict_oracle_price import StrictOraclePrice
from driftpy.types import OraclePriceData
from driftpy.types import (
OraclePriceData,
Order,
PerpPosition,
SpotPosition,
UserAccount,
)


class DriftUser:
Expand Down
28 changes: 21 additions & 7 deletions src/driftpy/math/margin.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
import math

from driftpy.math.spot_market import *

from enum import Enum

from driftpy.types import OraclePriceData
from driftpy.constants.numeric_constants import *
from driftpy.constants.numeric_constants import (
AMM_RESERVE_PRECISION,
BASE_PRECISION,
MARGIN_PRECISION,
PRICE_TO_QUOTE_PRECISION_RATIO,
SPOT_IMF_PRECISION,
SPOT_WEIGHT_PRECISION,
)
from driftpy.math.spot_market import get_token_amount, get_token_value
from driftpy.types import (
OraclePriceData,
PerpMarketAccount,
SpotBalanceType,
SpotMarketAccount,
)


def calculate_size_discount_asset_weight(
Expand Down Expand Up @@ -77,7 +87,9 @@ def calculate_scaled_initial_asset_weight(
return spot_market.initial_asset_weight

deposits = get_token_amount(
spot_market.deposit_balance, spot_market, SpotBalanceType.Deposit()
spot_market.deposit_balance,
spot_market,
SpotBalanceType.Deposit(), # type: ignore
)

deposits_value = get_token_value(deposits, spot_market.decimals, oracle_price)
Expand Down Expand Up @@ -139,7 +151,9 @@ def calculate_net_user_pnl_imbalance(
user_pnl = calculate_net_user_pnl(perp_market, oracle_data)

pnl_pool = get_token_amount(
perp_market.pnl_pool.scaled_balance, spot_market, SpotBalanceType.Deposit()
perp_market.pnl_pool.scaled_balance,
spot_market,
SpotBalanceType.Deposit(), # type: ignore
)

imbalance = user_pnl - pnl_pool
Expand Down
Loading
Loading