diff --git a/.github/workflows/pull_request.yaml b/.github/workflows/pull_request.yaml index 6e9cfea..4e916a0 100644 --- a/.github/workflows/pull_request.yaml +++ b/.github/workflows/pull_request.yaml @@ -25,3 +25,10 @@ jobs: black --check ./ - name: Type Check (mypy) run: mypy src + - name: Run Pytest + env: + MORALIS_API_KEY: ${{ secrets.MORALIS_API_KEY }} + CHAIN_SLEEP_TIME: ${{ secrets.CHAIN_SLEEP_TIME }} + NODE_URL: ${{ secrets.NODE_URL }} + CHAIN_NAME: ${{ secrets.CHAIN_NAME }} + run: pytest \ No newline at end of file diff --git a/src/price_providers/coingecko_pricing.py b/src/price_providers/coingecko_pricing.py index f296180..efbbbfc 100644 --- a/src/price_providers/coingecko_pricing.py +++ b/src/price_providers/coingecko_pricing.py @@ -14,8 +14,6 @@ COINGECKO_BUFFER_TIME, ) -coingecko_api_key = os.getenv("COINGECKO_API_KEY") - class CoingeckoPriceProvider(AbstractPriceProvider): """ @@ -24,18 +22,26 @@ class CoingeckoPriceProvider(AbstractPriceProvider): def __init__(self) -> None: self.web3 = get_web3_instance() - self.filtered_token_list = self.fetch_coingecko_list() - self.last_reload_time = time.time() # current time in seconds since epoch + self.last_reload_time = time.time() + self.coingecko_api_key = os.getenv("COINGECKO_API_KEY") + try: + self.filtered_token_list = self.fetch_coingecko_list() + self.last_reload_time = time.time() # current time in seconds since epoch + except Exception as e: + logger.warning(f"Failed to fetch initial token list: {e}") @property def name(self) -> str: return "Coingecko" - def fetch_coingecko_list(self) -> list[dict]: + def fetch_coingecko_list(self) -> list[dict] | None: """ Fetch and filter the list of tokens (currently filters only Ethereum) from the Coingecko API. """ + if not self.coingecko_api_key: + return None + url = ( f"https://pro-api.coingecko.com/api/v3/coins/" f"list?include_platform=true&status=active" @@ -43,9 +49,8 @@ def fetch_coingecko_list(self) -> list[dict]: headers = { "accept": "application/json", } - if coingecko_api_key: - headers["x-cg-pro-api-key"] = coingecko_api_key + headers["x-cg-pro-api-key"] = self.coingecko_api_key response = requests.get(url, headers=headers) tokens_list = json.loads(response.text) return [ @@ -71,6 +76,9 @@ def get_token_id_by_address(self, token_address: str) -> str | None: self.last_reload_time = ( time.time() ) # update the last reload time to current time + if not self.filtered_token_list: + return None + for token in self.filtered_token_list: if token["platforms"].get("ethereum") == token_address: return token["id"] @@ -82,7 +90,7 @@ def fetch_api_price( """ Makes call to Coingecko API to fetch price, between a start and end timestamp. """ - if not coingecko_api_key: + if not self.coingecko_api_key: logger.warning("Coingecko API key is not set.") return None # price of token is returned in ETH @@ -92,7 +100,7 @@ def fetch_api_price( ) headers = { "accept": "application/json", - "x-cg-pro-api-key": coingecko_api_key, + "x-cg-pro-api-key": self.coingecko_api_key, } try: response = requests.get(url, headers=headers) @@ -122,6 +130,12 @@ def get_price(self, price_params: dict) -> float | None: Function returns coingecko price for a token address, closest to and at least as large as the block timestamp for a given tx hash. """ + if not self.filtered_token_list: + logger.warning( + "Token list is empty, possibly the Coingecko API key isn't set." + ) + return None + token_address, block_number = extract_params(price_params, is_block=True) block_start_timestamp = self.web3.eth.get_block(block_number)["timestamp"] if self.price_not_retrievable(block_start_timestamp): diff --git a/src/price_providers/moralis_pricing.py b/src/price_providers/moralis_pricing.py index 9cdd968..6cecec5 100644 --- a/src/price_providers/moralis_pricing.py +++ b/src/price_providers/moralis_pricing.py @@ -53,6 +53,6 @@ def get_price(self, price_params: dict) -> float | None: self.logger.warning(f"Error: {e}") except Exception as e: self.logger.warning( - f"Price retrieval for token: {token_address} returned: {e}" + f"Price retrieval for token: {token_address} returned: {e}. Possibly the Moralis API key is missing." ) return None diff --git a/tests/test_fees.py b/tests/test_fees.py new file mode 100644 index 0000000..a5d4ce1 --- /dev/null +++ b/tests/test_fees.py @@ -0,0 +1,28 @@ +import pytest +from hexbytes import HexBytes +from src.fees.compute_fees import batch_fee_imbalances + + +def test_batch_fee_imbalances(): + """ + Test the batch_fee_imbalances function with a valid transaction hash. + """ + tx_hash = "0x714bb3b1a804af7a493bcfa991b9859e03c52387b027783f175255885fa97dbd" + protocol_fees, network_fees = batch_fee_imbalances(HexBytes(tx_hash)) + + # verify that the returned fees are dicts + assert isinstance(protocol_fees, dict), "Protocol fees should be a dict." + assert isinstance(network_fees, dict), "Network fees should be a dict." + + # Check that keys and values in the dict have the correct types + for token, fee in protocol_fees.items(): + assert isinstance(token, str), "Token address should be string." + assert isinstance(fee, int), "Fee amount should be int." + + for token, fee in network_fees.items(): + assert isinstance(token, str), "Token address should be string." + assert isinstance(fee, int), "Fee amount should be int." + + +if __name__ == "__main__": + pytest.main() diff --git a/tests/basic_test.py b/tests/test_imbalances.py similarity index 94% rename from tests/basic_test.py rename to tests/test_imbalances.py index b01ca9e..f5cd49d 100644 --- a/tests/basic_test.py +++ b/tests/test_imbalances.py @@ -53,7 +53,7 @@ def test_imbalances(tx_hash, expected_imbalances): Asserts imbalances match for main script with test values provided. """ chain_name = os.getenv("CHAIN_NAME") - rt = RawTokenImbalances(get_web3_instance(), chain_name) - imbalances = rt.compute_imbalances(tx_hash) + compute = RawTokenImbalances(get_web3_instance(), chain_name) + imbalances = compute.compute_imbalances(tx_hash) for token_address, expected_imbalance in expected_imbalances.items(): assert imbalances.get(token_address) == expected_imbalance diff --git a/tests/test_pricefeed.py b/tests/test_pricefeed.py new file mode 100644 index 0000000..76c198b --- /dev/null +++ b/tests/test_pricefeed.py @@ -0,0 +1,65 @@ +import pytest +from src.price_providers.price_feed import PriceFeed + + +@pytest.fixture +def price_feed(): + return PriceFeed() + + +def test_get_price_real(price_feed): + """Test with legitimate parameters.""" + + # Test parameters + tx_hash = "0x94af3d98b0af4ca6bf41e85c05ed42fccd71d5aaa04cbe01fab00d1b2268c4e1" + token_address = "0xd1d2Eb1B1e90B638588728b4130137D262C87cae" + block_number = 20630508 + price_params = { + "tx_hash": tx_hash, + "token_address": token_address, + "block_number": block_number, + } + + # Get the price + result = price_feed.get_price(price_params) + assert result is not None + + price, source = result + # Assert that the price is a positive float + assert isinstance(price, float) + assert price > 0 + assert source in ["Coingecko", "Dune", "Moralis", "AuctionPrices"] + + +def test_get_price_unknown_token(price_feed): + """Test with an unknown token address.""" + + tx_hash = "0x94af3d98b0af4ca6bf41e85c05ed42fccd71d5aaa04cbe01fab00d1b2268c4e1" + unknown_token = "0xd1d2Eb1B1e90B638588728b4130137D262C87cad" + price_params = { + "tx_hash": tx_hash, + "token_address": unknown_token, + "block_number": 20630508, + } + result = price_feed.get_price(price_params) + + # expect None for an unknown token + assert result is None + + +def test_get_price_future_block(price_feed): + """Test with a block number in the future.""" + future_block = 99999999 + price_params = { + "token_address": "0x6B175474E89094C44Da98b954EedeAC495271d0F", + "block_number": future_block, + } + + result = price_feed.get_price(price_params) + + # expect None for a future block + assert result is None + + +if __name__ == "__main__": + pytest.main()