Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions examples/spot_market_trade.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import asyncio
import logging
import os

from anchorpy.provider import Provider, Wallet
from dotenv import load_dotenv
from solana.rpc.async_api import AsyncClient

from driftpy.constants.spot_markets import mainnet_spot_market_configs
from driftpy.drift_client import DriftClient
from driftpy.keypair import load_keypair
from driftpy.types import TxParams

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

load_dotenv()


def get_market_by_symbol(symbol: str):
for market in mainnet_spot_market_configs:
if market.symbol == symbol:
return market
raise Exception(f"Market {symbol} not found")


async def make_spot_trade():
rpc = os.environ.get("RPC_TRITON")
secret = os.environ.get("PRIVATE_KEY")
kp = load_keypair(secret)
wallet = Wallet(kp)
logger.info(f"Using wallet: {wallet.public_key}")

connection = AsyncClient(rpc)
provider = Provider(connection, wallet)
drift_client = DriftClient(
provider.connection,
provider.wallet,
"mainnet",
tx_params=TxParams(
compute_units_price=85_000,
compute_units=1_000_000,
),
)
await drift_client.subscribe()
logger.info("Drift client subscribed")

in_decimals_result = drift_client.get_spot_market_account(
get_market_by_symbol("USDS").market_index
)
if not in_decimals_result:
logger.error("USDS market not found")
raise Exception("Market not found")

in_decimals = in_decimals_result.decimals
logger.info(f"USDS decimals: {in_decimals}")

swap_amount = int(1 * 10**in_decimals)
logger.info(f"Swapping {swap_amount} USDS to USDC")

try:
swap_ixs, swap_lookups = await drift_client.get_jupiter_swap_ix_v6(
out_market_idx=get_market_by_symbol("USDC").market_index,
in_market_idx=get_market_by_symbol("USDS").market_index,
amount=swap_amount,
swap_mode="ExactIn",
only_direct_routes=True,
)
logger.info("Got swap instructions")
print("[DEBUG] Got swap instructions of length", len(swap_ixs))

await drift_client.send_ixs(
ixs=swap_ixs,
lookup_tables=swap_lookups,
)
logger.info("Swap complete")
except Exception as e:
logger.error(f"Error during swap: {e}")
raise e
finally:
await drift_client.unsubscribe()
await connection.close()


if __name__ == "__main__":
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(make_spot_trade())
finally:
pending = asyncio.all_tasks(loop)
loop.run_until_complete(asyncio.gather(*pending))
loop.close()
56 changes: 35 additions & 21 deletions src/driftpy/drift_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import base64
import json
import os
import random
import string
Expand Down Expand Up @@ -3155,37 +3154,53 @@ async def get_jupiter_swap_ix_v6(
amount: int,
out_ata: Optional[Pubkey] = None,
in_ata: Optional[Pubkey] = None,
slippage_bps: Optional[int] = None,
quote=None,
slippage_bps: int = 50,
quote: Optional[dict] = None,
reduce_only: Optional[SwapReduceOnly] = None,
user_account_public_key: Optional[Pubkey] = None,
swap_mode: str = "ExactIn",
fee_account: Optional[Pubkey] = None,
platform_fee_bps: Optional[int] = None,
only_direct_routes: bool = False,
) -> Tuple[list[Instruction], list[AddressLookupTableAccount]]:
pre_instructions: list[Instruction] = []
JUPITER_URL = os.getenv("JUPITER_URL", "https://quote-api.jup.ag/v6")

out_market = self.get_spot_market_account(out_market_idx)
in_market = self.get_spot_market_account(in_market_idx)

if slippage_bps is None:
slippage_bps = 10
if not out_market or not in_market:
raise Exception("Invalid market indexes")

if quote is None:
url = f"{JUPITER_URL}/quote?inputMint={str(in_market.mint)}&outputMint={str(out_market.mint)}&amount={amount}&slippageBps={slippage_bps}"

params = {
"inputMint": str(in_market.mint),
"outputMint": str(out_market.mint),
"amount": str(amount),
"slippageBps": slippage_bps,
"swapMode": swap_mode,
"maxAccounts": 50,
}
if only_direct_routes:
params["onlyDirectRoutes"] = "true"
if platform_fee_bps:
params["platformFeeBps"] = platform_fee_bps

url = f"{JUPITER_URL}/quote?" + "&".join(
f"{k}={v}" for k, v in params.items()
)
quote_resp = requests.get(url)

if quote_resp.status_code != 200:
raise Exception("Couldn't get a Jupiter quote")
raise Exception(f"Jupiter quote failed: {quote_resp.text}")

quote = quote_resp.json()

if out_ata is None:
out_ata: Pubkey = self.get_associated_token_account_public_key(
out_ata = self.get_associated_token_account_public_key(
out_market.market_index
)

ai = await self.connection.get_account_info(out_ata)

if not ai.value:
pre_instructions.append(
self.create_associated_token_account_idempotent_instruction(
Expand All @@ -3197,12 +3212,10 @@ async def get_jupiter_swap_ix_v6(
)

if in_ata is None:
in_ata: Pubkey = self.get_associated_token_account_public_key(
in_ata = self.get_associated_token_account_public_key(
in_market.market_index
)

ai = await self.connection.get_account_info(in_ata)

if not ai.value:
pre_instructions.append(
self.create_associated_token_account_idempotent_instruction(
Expand All @@ -3213,23 +3226,24 @@ async def get_jupiter_swap_ix_v6(
)
)

data = {
swap_data = {
"quoteResponse": quote,
"userPublicKey": str(self.wallet.public_key),
"destinationTokenAccount": str(out_ata),
}
if fee_account:
swap_data["feeAccount"] = str(fee_account)

swap_ix_resp = requests.post(
f"{JUPITER_URL}/swap-instructions",
headers={"Accept": "application/json", "Content-Type": "application/json"},
data=json.dumps(data),
json=swap_data,
)

if swap_ix_resp.status_code != 200:
raise Exception("Couldn't get Jupiter swap ix")
raise Exception(f"Jupiter swap instructions failed: {swap_ix_resp.text}")

swap_ix_json = swap_ix_resp.json()

swap_ix = swap_ix_json.get("swapInstruction")
address_table_lookups = swap_ix_json.get("addressLookupTableAddresses")

Expand Down Expand Up @@ -3258,11 +3272,11 @@ async def get_jupiter_swap_ix_v6(
cleansed_ixs: list[Instruction] = []

for ix in ixs:
if type(ix) == list:
if isinstance(ix, list):
for i in ix:
if type(i) == dict:
if isinstance(i, dict):
cleansed_ixs.append(self._dict_to_instructions(i))
elif type(ix) == dict:
elif isinstance(ix, dict):
cleansed_ixs.append(self._dict_to_instructions(ix))
else:
cleansed_ixs.append(ix)
Expand Down
2 changes: 2 additions & 0 deletions src/driftpy/drift_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,8 @@ def get_unrealized_funding_pnl(
perp_market = self.drift_client.get_perp_market_account(
position.market_index
)
if not perp_market:
raise Exception("Perp market account not found")

unrealized_pnl += calculate_position_funding_pnl(perp_market, position)

Expand Down