diff --git a/tests/dlob_test_constants.py b/tests/dlob_test_constants.py index 74531d5b..58cd4459 100644 --- a/tests/dlob_test_constants.py +++ b/tests/dlob_test_constants.py @@ -1,3 +1,5 @@ +from unittest.mock import Mock + from solders.pubkey import Pubkey from driftpy.constants.config import devnet_spot_market_configs @@ -8,7 +10,6 @@ QUOTE_PRECISION, SPOT_CUMULATIVE_INTEREST_PRECISION, SPOT_MARKET_CUMULATIVE_INTEREST_PRECISION, - SPOT_MARKET_CUMULATIVE_INTEREST_PRECISION_EXP, SPOT_MARKET_WEIGHT_PRECISION, ) from driftpy.types import ( @@ -158,6 +159,8 @@ ) # Mock Perp Markets +mock_pubkey = Mock() + mock_perp_markets = [ PerpMarketAccount( status=MarketStatus.Initialized(), @@ -167,7 +170,7 @@ expiry_ts=0, expiry_price=0, market_index=0, - pubkey=Pubkey.default(), + pubkey=mock_pubkey, amm=mock_amm, number_of_users_with_base=0, number_of_users=0, @@ -204,7 +207,7 @@ expiry_ts=0, expiry_price=0, market_index=1, - pubkey=Pubkey.default(), + pubkey=mock_pubkey, amm=mock_amm, number_of_users_with_base=0, number_of_users=0, @@ -241,7 +244,7 @@ expiry_ts=0, expiry_price=0, market_index=2, - pubkey=Pubkey.default(), + pubkey=mock_pubkey, amm=mock_amm, number_of_users_with_base=0, number_of_users=0, diff --git a/tests/math/amm.py b/tests/math/amm.py index 1931c9c0..8533a466 100644 --- a/tests/math/amm.py +++ b/tests/math/amm.py @@ -1,4 +1,6 @@ import copy +from copy import deepcopy +from unittest.mock import Mock import pytest @@ -7,17 +9,134 @@ PEG_PRECISION, PRICE_PRECISION, QUOTE_PRECISION, + SPOT_MARKET_CUMULATIVE_INTEREST_PRECISION, + SPOT_MARKET_WEIGHT_PRECISION, ) from driftpy.dlob.orderbook_levels import get_vamm_l2_generator from driftpy.math.amm import calculate_market_open_bid_ask, calculate_updated_amm -from driftpy.types import OraclePriceData -from tests.dlob_test_constants import mock_perp_markets +from driftpy.types import ( + AssetTier, + MarketStatus, + OraclePriceData, + OracleSource, + PerpMarketAccount, + SpotMarketAccount, +) +from tests.dlob_test_constants import devnet_spot_market_configs, mock_perp_markets + +# Create a mock object for Pubkey +mock_pubkey = Mock() +mock_historical_oracle_data = Mock() +mock_historical_index_data = Mock() +mock_revenue_pool = Mock() +mock_spot_fee_pool = Mock() +mock_insurance_fund = Mock() + +mock_spot_markets = [ + SpotMarketAccount( + status=MarketStatus.Active(), + asset_tier=AssetTier.COLLATERAL, + name=[], + max_token_deposits=1000000 * QUOTE_PRECISION, + market_index=0, + pubkey=mock_pubkey, # Use mock object + mint=devnet_spot_market_configs[0].mint, # Replace with actual mint + vault=mock_pubkey, # Use mock object + oracle=mock_pubkey, # Use mock object + historical_oracle_data=mock_historical_oracle_data, + historical_index_data=mock_historical_index_data, + revenue_pool=mock_revenue_pool, + spot_fee_pool=mock_spot_fee_pool, + insurance_fund=mock_insurance_fund, + total_spot_fee=0, + deposit_balance=0, + borrow_balance=0, + cumulative_deposit_interest=SPOT_MARKET_CUMULATIVE_INTEREST_PRECISION, + cumulative_borrow_interest=SPOT_MARKET_CUMULATIVE_INTEREST_PRECISION, + total_social_loss=0, + total_quote_social_loss=0, + withdraw_guard_threshold=0, + deposit_token_twap=0, + borrow_token_twap=0, + utilization_twap=0, + last_interest_ts=0, + last_twap_ts=0, + expiry_ts=0, + order_step_size=0, + order_tick_size=0, + min_order_size=0, + max_position_size=0, + next_fill_record_id=0, + next_deposit_record_id=0, + initial_asset_weight=SPOT_MARKET_WEIGHT_PRECISION, + maintenance_asset_weight=SPOT_MARKET_WEIGHT_PRECISION, + initial_liability_weight=SPOT_MARKET_WEIGHT_PRECISION, + maintenance_liability_weight=SPOT_MARKET_WEIGHT_PRECISION, + imf_factor=0, + liquidator_fee=0, + if_liquidation_fee=0, + optimal_utilization=0, + optimal_borrow_rate=0, + max_borrow_rate=0, + decimals=6, + orders_enabled=True, + oracle_source=OracleSource.Pyth(), + paused_operations=0, + if_paused_operations=0, + fee_adjustment=0, + max_token_borrows_fraction=0, + ), + # ... other SpotMarketAccount instances ... +] + + +def custom_deepcopy_perp_market_account(perp_market_account): + # Manually copy each attribute, handling Pubkey separately + return PerpMarketAccount( + status=deepcopy(perp_market_account.status), + name=deepcopy(perp_market_account.name), + contract_type=deepcopy(perp_market_account.contract_type), + contract_tier=deepcopy(perp_market_account.contract_tier), + expiry_ts=perp_market_account.expiry_ts, + expiry_price=perp_market_account.expiry_price, + market_index=perp_market_account.market_index, + pubkey=perp_market_account.pubkey, # Directly assign or use a mock + amm=deepcopy(perp_market_account.amm), + number_of_users_with_base=perp_market_account.number_of_users_with_base, + number_of_users=perp_market_account.number_of_users, + margin_ratio_initial=perp_market_account.margin_ratio_initial, + margin_ratio_maintenance=perp_market_account.margin_ratio_maintenance, + next_fill_record_id=perp_market_account.next_fill_record_id, + pnl_pool=deepcopy(perp_market_account.pnl_pool), + if_liquidation_fee=perp_market_account.if_liquidation_fee, + liquidator_fee=perp_market_account.liquidator_fee, + imf_factor=perp_market_account.imf_factor, + next_funding_rate_record_id=perp_market_account.next_funding_rate_record_id, + next_curve_record_id=perp_market_account.next_curve_record_id, + unrealized_pnl_imf_factor=perp_market_account.unrealized_pnl_imf_factor, + unrealized_pnl_max_imbalance=perp_market_account.unrealized_pnl_max_imbalance, + unrealized_pnl_initial_asset_weight=perp_market_account.unrealized_pnl_initial_asset_weight, + unrealized_pnl_maintenance_asset_weight=perp_market_account.unrealized_pnl_maintenance_asset_weight, + insurance_claim=deepcopy(perp_market_account.insurance_claim), + paused_operations=perp_market_account.paused_operations, + quote_spot_market_index=perp_market_account.quote_spot_market_index, + fee_adjustment=perp_market_account.fee_adjustment, + fuel_boost_taker=perp_market_account.fuel_boost_taker, + fuel_boost_maker=perp_market_account.fuel_boost_maker, + fuel_boost_position=perp_market_account.fuel_boost_position, + high_leverage_margin_ratio_initial=perp_market_account.high_leverage_margin_ratio_initial, + high_leverage_margin_ratio_maintenance=perp_market_account.high_leverage_margin_ratio_maintenance, + pool_id=perp_market_account.pool_id, + padding=deepcopy(perp_market_account.padding), + ) @pytest.mark.asyncio async def test_orderbook_l2_gen_no_top_of_book_quote_amounts_10_num_orders_low_liq(): print() - mock_perps = copy.deepcopy(mock_perp_markets) + mock_perps = [ + custom_deepcopy_perp_market_account(market) for market in mock_perp_markets + ] mock_1 = mock_perps[0] cc = 38_104_569 @@ -78,7 +197,9 @@ async def test_orderbook_l2_gen_no_top_of_book_quote_amounts_10_num_orders_low_l @pytest.mark.asyncio async def test_orderbook_l2_gen_no_top_of_book_quote_amounts_10_num_orders(): print() - mock_perps = copy.deepcopy(mock_perp_markets) + mock_perps = [ + custom_deepcopy_perp_market_account(market) for market in mock_perp_markets + ] mock_1 = mock_perps[0] cc = 38_104_569 @@ -131,7 +252,9 @@ async def test_orderbook_l2_gen_no_top_of_book_quote_amounts_10_num_orders(): @pytest.mark.asyncio async def test_orderbook_l2_gen_4_top_of_book_quote_amounts_10_num_orders(): print() - mock_perps = copy.deepcopy(mock_perp_markets) + mock_perps = [ + custom_deepcopy_perp_market_account(market) for market in mock_perp_markets + ] mock_market1 = mock_perps[0] cc = 38_104_569 mock_market1.amm.base_asset_reserve = cc * BASE_PRECISION @@ -195,7 +318,9 @@ async def test_orderbook_l2_gen_4_top_of_book_quote_amounts_10_num_orders(): @pytest.mark.asyncio async def test_orderbook_l2_gen_4_top_quote_amounts_10_orders_low_bid_liquidity(): print() - mock_perps = copy.deepcopy(mock_perp_markets) + mock_perps = [ + custom_deepcopy_perp_market_account(market) for market in mock_perp_markets + ] mock_market1 = mock_perps[0] cc = 38_104_569 mock_market1.amm.base_asset_reserve = cc * BASE_PRECISION @@ -261,7 +386,9 @@ async def test_orderbook_l2_gen_4_top_quote_amounts_10_orders_low_bid_liquidity( @pytest.mark.asyncio async def test_orderbook_l2_gen_4_top_quote_amounts_10_orders_low_ask_liquidity(): print() - mock_perps = copy.deepcopy(mock_perp_markets) + mock_perps = [ + custom_deepcopy_perp_market_account(market) for market in mock_perp_markets + ] mock_market1 = mock_perps[0] cc = 38_104_569 mock_market1.amm.base_asset_reserve = cc * BASE_PRECISION @@ -332,7 +459,9 @@ async def test_orderbook_l2_gen_4_top_quote_amounts_10_orders_low_ask_liquidity( @pytest.mark.asyncio async def test_orderbook_l2_gen_no_top_of_book_quote_amounts_10_orders_no_liquidity(): print() - mock_perps = copy.deepcopy(mock_perp_markets) + mock_perps = [ + custom_deepcopy_perp_market_account(market) for market in mock_perp_markets + ] mock_market1 = mock_perps[0] cc = 38_104_569 mock_market1.amm.base_asset_reserve = cc * BASE_PRECISION diff --git a/tests/math/helpers.py b/tests/math/helpers.py index 2260f89a..a7ae14b2 100644 --- a/tests/math/helpers.py +++ b/tests/math/helpers.py @@ -1,5 +1,6 @@ import asyncio from copy import deepcopy +from unittest.mock import Mock from anchorpy import Wallet from solana.rpc.async_api import AsyncClient @@ -23,6 +24,8 @@ UserAccount, ) +mock_pubkey = Mock() + mock_perp_position = PerpPosition( base_asset_amount=0, last_cumulative_funding_rate=0, @@ -82,8 +85,8 @@ ) mock_user_account = UserAccount( - authority=Pubkey.default(), - delegate=Pubkey.default(), + authority=mock_pubkey, + delegate=mock_pubkey, sub_account_id=0, name=[1], spot_positions=[deepcopy(mock_spot_position) for _ in range(8)], diff --git a/tests/math/spot.py b/tests/math/spot.py index 9d546ddf..d22249d6 100644 --- a/tests/math/spot.py +++ b/tests/math/spot.py @@ -1,16 +1,48 @@ -import math -from pytest import mark from copy import deepcopy +from unittest.mock import Mock -from driftpy.math.margin import calculate_size_premium_liability_weight -from driftpy.constants.numeric_constants import * +from pytest import mark -from tests.dlob_test_constants import mock_spot_markets +from driftpy.constants.numeric_constants import ( + ONE_BILLION, + ONE_HUNDRED_THOUSAND, + ONE_MILLION, + SPOT_CUMULATIVE_INTEREST_PRECISION, + TEN_THOUSAND, +) +from driftpy.math.margin import calculate_size_premium_liability_weight from driftpy.math.spot_balance import ( calculate_borrow_rate, calculate_deposit_rate, calculate_spot_market_borrow_capacity, ) +from tests.dlob_test_constants import mock_spot_markets + +mock_pubkey = Mock() + + +# Create a function to recursively replace all Pubkey objects with mocks +def replace_pubkeys_with_mocks(obj): + from solders.pubkey import Pubkey + + if isinstance(obj, Pubkey): + return mock_pubkey + + if hasattr(obj, "__dict__"): + # For objects with attributes + for attr_name, attr_value in obj.__dict__.items(): + if isinstance(attr_value, Pubkey): + setattr(obj, attr_name, mock_pubkey) + elif hasattr(attr_value, "__dict__"): + # Handle nested objects + replace_pubkeys_with_mocks(attr_value) + + return obj + + +# Apply this to all mock_spot_markets +for market in mock_spot_markets: + replace_pubkeys_with_mocks(market) @mark.asyncio diff --git a/tests/math/user.py b/tests/math/user.py index 353f1b6e..4f3d08cd 100644 --- a/tests/math/user.py +++ b/tests/math/user.py @@ -1,16 +1,50 @@ -from pytest import mark from copy import deepcopy +from unittest.mock import Mock + +from pytest import mark -from driftpy.constants.numeric_constants import * +from driftpy.constants.numeric_constants import ( + BASE_PRECISION, + MARGIN_PRECISION, + PRICE_PRECISION, + QUOTE_PRECISION, + SPOT_BALANCE_PRECISION, + SPOT_CUMULATIVE_INTEREST_PRECISION, +) from driftpy.math.margin import MarginCategory from driftpy.math.perp_position import calculate_position_pnl - -from tests.dlob_test_constants import mock_perp_markets, mock_spot_markets from driftpy.math.spot_position import get_worst_case_token_amounts from driftpy.oracles.strict_oracle_price import StrictOraclePrice from driftpy.types import SpotBalanceType +from tests.dlob_test_constants import mock_perp_markets, mock_spot_markets + from .helpers import make_mock_user, mock_user_account +mock_pubkey = Mock() + + +def replace_pubkeys_with_mocks(obj): + from solders.pubkey import Pubkey + + if isinstance(obj, Pubkey): + return mock_pubkey + + if hasattr(obj, "__dict__"): + for attr_name, attr_value in obj.__dict__.items(): + if isinstance(attr_value, Pubkey): + setattr(obj, attr_name, mock_pubkey) + elif hasattr(attr_value, "__dict__"): + replace_pubkeys_with_mocks(attr_value) + + return obj + + +for market in mock_perp_markets: + replace_pubkeys_with_mocks(market) + +for market in mock_spot_markets: + replace_pubkeys_with_mocks(market) + @mark.asyncio async def test_empty():